# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from argparse import ArgumentParser
import cv2
import mmcv
from mmtrack.apis import inference_sot, init_model
def main():
parser = ArgumentParser()
parser.add_argument('config', help='Config file')
parser.add_argument('--input', help='input video file')
parser.add_argument('--output', help='output video file (mp4 format)')
parser.add_argument('--checkpoint', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show',
action='store_true',
default=False,
help='whether to show visualizations.')
parser.add_argument(
'--color', default=(0, 255, 0), help='Color of tracked bbox lines.')
parser.add_argument(
'--thickness', default=3, type=int, help='Thickness of bbox lines.')
parser.add_argument('--fps', type=int, help='FPS of the output video')
parser.add_argument('--gt_bbox_file', help='The path of gt_bbox file')
args = parser.parse_args()
# load images
if osp.isdir(args.input):
imgs = sorted(
filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
os.listdir(args.input)),
key=lambda x: int(x.split('.')[0]))
IN_VIDEO = False
else:
imgs = mmcv.VideoReader(args.input)
IN_VIDEO = True
OUT_VIDEO = False
# define output
if args.output is not None:
if args.output.endswith('.mp4'):
OUT_VIDEO = True
out_dir = tempfile.TemporaryDirectory()
out_path = out_dir.name
_out = args.output.rsplit(os.sep, 1)
if len(_out) > 1:
os.makedirs(_out[0], exist_ok=True)
else:
out_path = args.output
os.makedirs(out_path, exist_ok=True)
fps = args.fps
if args.show or OUT_VIDEO:
if fps is None and IN_VIDEO:
fps = imgs.fps
if not fps:
raise ValueError('Please set the FPS for the output video.')
fps = int(fps)
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
prog_bar = mmcv.ProgressBar(len(imgs))
# test and show/save the images
for i, img in enumerate(imgs):
if isinstance(img, str):
img_path = osp.join(args.input, img)
img = mmcv.imread(img_path)
if i == 0:
if args.gt_bbox_file is not None:
bboxes = mmcv.list_from_file(args.gt_bbox_file)
init_bbox = list(map(float, bboxes[0].split(',')))
else:
init_bbox = list(cv2.selectROI(args.input, img, False, False))
# convert (x1, y1, w, h) to (x1, y1, x2, y2)
init_bbox[2] += init_bbox[0]
init_bbox[3] += init_bbox[1]
result = inference_sot(model, img, init_bbox, frame_id=i)
if args.output is not None:
if IN_VIDEO or OUT_VIDEO:
out_file = osp.join(out_path, f'{
i:06d}.jpg')
else:
out_file = osp.join(out_path, img_path.rsplit(os.sep, 1)[-1])
else:
out_file = None
model.show_result(
img,
result,
show=args.show,
wait_time=int(1000. / fps) if fps else 0,
out_file=out_file,
thickness=args.thickness)
prog_bar.update()
if args.output and OUT_VIDEO:
print(
f'\nmaking the output video at {
args.output} with a FPS of {
fps}')
mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
out_dir.cleanup()
if __name__ == '__main__':
main()
全部源码如下,下面开始解析。
这个demo_sot.py是在mmtracking项目文件夹下面的demo文件夹下的演示代码
import os
import os.path as osp
import tempfile
from argparse import ArgumentParser
import cv2
import mmcv
from mmtrack.apis import inference_sot, init_model
parser = ArgumentParser()
parser.add_argument('config', help='Config file')
parser.add_argument('--input', help='input video file')
parser.add_argument('--output', help='output video file (mp4 format)')
parser.add_argument('--checkpoint', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--show',
action='store_true',
default=False,
help='whether to show visualizations.')
parser.add_argument(
'--color', default=(0, 255, 0), help='Color of tracked bbox lines.')
parser.add_argument(
'--thickness', default=3, type=int, help='Thickness of bbox lines.')
parser.add_argument('--fps', type=int, help='FPS of the output video')
parser.add_argument('--gt_bbox_file', help='The path of gt_bbox file')
args = parser.parse_args()
使用 argparse 库解析命令行参数。这些参数包括:
config:配置文件
–input:输入视频文件或图像目录
–output:输出视频文件(mp4 格式)
–checkpoint:检查点文件
–device:用于推理的设备(默认为 ‘cuda:0’)
–show:是否显示可视化结果
–color:跟踪边界框线条的颜色
–thickness:边界框线条的粗细
–fps:输出视频的帧率
–gt_bbox_file:ground truth 边界框文件的路径
# load images
if osp.isdir(args.input):
imgs = sorted(
filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
os.listdir(args.input)),
key=lambda x: int(x.split('.')[0]))
IN_VIDEO = False
else:
imgs = mmcv.VideoReader(args.input)
IN_VIDEO = True
OUT_VIDEO = False
首先判断输入的参数args.input是否是一个目录,
如果是一个目录,os.listdir(args.input)就列出该目录下面所有的文件和子目录,然后使用filter(lambda x: x.endswith((‘.jpg’, ‘.png’, ‘.jpeg’)), …)过滤出以.jpg, .png, 或 .jpeg 结尾的文件
lambda x:A,y表示对y使用以x:A的函数,这里就是对于os.listdir(args.input)列出的所有输入目录下面的子文件,使用ilter(lambda x: x.endswith((‘.jpg’, ‘.png’, ‘.jpeg’)),
sorted(…, key=lambda x: int(x.split(‘.’)[0])): 对过滤后的图像文件名进行排序。排序的依据是文件名(不包括扩展名)转化为整数后的值。这意味着如果文件名为 “img10.jpg”, “img2.jpg”, “img1.jpg”,它们会被正确地按数字顺序排序。
IN_VIDEO = False: 设置 IN_VIDEO 为 False,表示输入的不是视频。
如果不是一个目录,即一个视频文件,
mmcv.VideoReader(args.input): 使用 mmcv(一个常用于计算机视觉任务的库)的 VideoReader 函数读取视频文件。这个函数会返回一个迭代器,每次迭代都会返回一个视频帧。
IN_VIDEO = True: 设置 IN_VIDEO 为 True,表示输入的是视频。
OUT_VIDEO = False: 设置 OUT_VIDEO 为 False。这行代码表示在后续代码中,程序有可能会输出或保存为一个视频文件,但目前还没有设置要输出视频。
# define output
if args.output is not None:
if args.output.endswith('.mp4'):
OUT_VIDEO = True
out_dir = tempfile.TemporaryDirectory()
out_path = out_dir.name
_out = args.output.rsplit(os.sep, 1)
if len(_out) > 1:
os.makedirs(_out[0], exist_ok=True)
else:
out_path = args.output
os.makedirs(out_path, exist_ok=True)
fps = args.fps
if args.show or OUT_VIDEO:
if fps is None and IN_VIDEO:
fps = imgs.fps
if not fps:
raise ValueError('Please set the FPS for the output video.')
fps = int(fps)
if args.output is not None:如果参数设置的输出不为空
if args.output.endwith(‘.mp4’) 如果用户希望输出一个MP4视频文件
OUT_VIDEO = True 设置全局变量为True,表示要输出视频
out_dir = tempfile.TemporaryDirectory(): 创建一个临时目录,并将其路径存储在out_dir变量中。
out_path = out_dir.name: 获取临时目录的路径,并将其存储在out_path变量中。
_out = args.output.rsplit(os.sep, 1): 使用rsplit方法将args.output字符串从右边开始分割,分割符为os.sep(这通常是文件路径中的分隔符,如/或\),并且只分割一次。结果存储在_out变量中。
代码获取args.fps的值,即每秒帧数,用于后续的视频输出。
如果args.show或OUT_VIDEO为True(即用户希望显示或输出视频),代码会检查fps的值。
如果fps为None且IN_VIDEO为True(似乎是一个未在代码段中定义的变量,可能表示输入也是一个视频),则使用输入视频的fps。
如果fps仍然为None或False,则抛出一个ValueError,要求用户设置输出视频的FPS。
最后,将fps转换为整数。
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
prog_bar = mmcv.ProgressBar(len(imgs))
调用了一个 init_model 的函数,并将结果赋值给变量 model。这个函数来自
mmtracki.apis.inference.py
args.config: 是模型的配置文件路径,通常包含模型的结构、优化器设置、训练参数等。
args.checkpoint: 是模型的检查点文件路径,通常包含模型的权重或其他训练过程中的状态。
device=args.device: 指定模型应该在哪个设备上运行,例如 CPU 或 GPU。args.device 可能是一个字符串,如 “cpu” 或 “cuda:0”。
我们首先来看这个init_model函数
def init_model(config,
checkpoint=None,
device='cuda:0',
cfg_options=None,
verbose_init_params=False):
"""Initialize a model from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. Default as None.
cfg_options (dict, optional): Options to override some settings in
the used config. Default to None.
verbose_init_params (bool, optional): Whether to print the information
of initialized parameters to the console. Default to False.
Returns:
nn.Module: The constructed detector.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {
type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
if 'detector' in config.model:
config.model.detector.pretrained = None
model = build_model(config.model)
if not verbose_init_params:
# Creating a temporary file to record the information of initialized
# parameters. If not, the information of initialized parameters will be
# printed to the console because of the call of
# `mmcv.runner.BaseModule.init_weights`.
tmp_file = tempfile.NamedTemporaryFile(delete=False)
file_handler = logging.FileHandler(tmp_file.name, mode='w')
model.logger.addHandler(file_handler)
# We need call `init_weights()` to load pretained weights in MOT
# task.
model.init_weights()
file_handler.close()
model.logger.removeHandler(file_handler)
tmp_file.close()
os.remove(tmp_file.name)
else:
# We need call `init_weights()` to load pretained weights in MOT task.
model.init_weights()
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
if not hasattr(model, 'CLASSES'):
if hasattr(model, 'detector') and hasattr(model.detector, 'CLASSES'):
model.CLASSES = model.detector.CLASSES
else:
print("Warning: The model doesn't have classes")
model.CLASSES = None
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
函数接受5个输入:
①config:这个是这是配置文件的路径或配置对象。是一个str字符串或者 mmcv.Config。mmcv.Config是一个对象,用于管理和解析配置文件。如果传入的是字符串,则它应该是指向配置文件的路径。
②checkpoint (str, optional): 检查点(checkpoint)路径,这是一个可选参数。检查点通常包含模型的权重和可能的其他状态信息。如果在初始化模型时要加载预训练的权重,则会用到这个参数。
③device (str, optional): 设备字符串,指定模型应该在哪个设备上运行。默认值是’cuda:0’,意味着模型将在第一个GPU上运行。如果要在CPU上运行,可以传入’cpu’。
④cfg_options (dict, optional): 一个字典,用于覆盖配置文件中的某些设置。这是一个可选参数,默认值为None。
⑤verbose_init_params (bool, optional): 一个布尔值,决定是否将初始化的参数信息打印到控制台。默认值为False。
接下来
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {
type(config)}')
检查 config 的类型:
如果 config 是一个字符串(str),那么它很可能是一个配置文件的路径。这种情况下,代码使用 mmcv.Config.fromfile(config) 来从这个路径加载配置文件,并将其解析为 mmcv.Config 对象。
如果 config 不是一个 mmcv.Config 对象,那么代码会抛出一个 TypeError,提示 config 必须是一个文件名或 Config 对象。
if cfg_options is not None:
config.merge_from_dict(cfg_options)
合并配置选项:
如果 cfg_options 不为 None,代码会将其作为一个字典合并到 config 中。这意味着 cfg_options 中的任何设置都会覆盖 config 中的相应设置。
if 'detector' in config.model:
config.model.detector.pretrained = None
处理预训练模型:
如果 config.model 包含一个 ‘detector’ 键,代码会将其 pretrained 属性设置为 None。这通常意味着在构建模型时不会使用预训练的权重。
model = build_model(config.model)
构建模型:
最后,代码使用 build_model(config.model) 来根据 config.model 中的配置构建一个模型。这里假设 build_model 是一个已经定义好的函数,它可以根据提供的配置信息来创建和返回一个模型对象。
if not verbose_init_params:
# Creating a temporary file to record the information of initialized
# parameters. If not, the information of initialized parameters will be
# printed to the console because of the call of
# `mmcv.runner.BaseModule.init_weights`.
tmp_file = tempfile.NamedTemporaryFile(delete=False)
file_handler = logging.FileHandler(tmp_file.name, mode='w')
model.logger.addHandler(file_handler)
# We need call `init_weights()` to load pretained weights in MOT
# task.
model.init_weights()
file_handler.close()
model.logger.removeHandler(file_handler)
tmp_file.close()
os.remove(tmp_file.name)
else:
# We need call `init_weights()` to load pretained weights in MOT task.
model.init_weights()
判断verbose_init_params的值:如果verbose_init_params为False,则执行以下的代码块。
创建临时文件:使用tempfile.NamedTemporaryFile创建一个临时文件,并且设置delete=False,这意味着当文件关闭后,它不会被自动删除。
设置文件处理器:为model.logger添加一个文件处理器,这样model.logger输出的日志信息就会被写入到之前创建的临时文件中。
初始化模型权重:无论verbose_init_params的值如何,都会执行此行代码来加载预训练的权重。
关闭文件处理器和临时文件
删除临时文件
如果verbose_init_params为True:如果verbose_init_params为True,则不会创建临时文件,而是直接调用model.init_weights()来加载预训练的权重。在这种情况下,由于mmcv.runner.BaseModule.init_weights的调用,初始化的参数信息将被直接打印到控制台。
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
if not hasattr(model, 'CLASSES'):
if hasattr(model, 'detector') and hasattr(model.detector, 'CLASSES'):
model.CLASSES = model.detector.CLASSES
else:
print("Warning: The model doesn't have classes")
model.CLASSES = None
检查检查点是否为空:检查checkpoint变量是否不为None。如果checkpoint有一个有效的值(即不是None),则进入该if语句块。
加载检查点:如果checkpoint不为空,这行代码会调用load_checkpoint函数,尝试加载检查点。load_checkpoint函数接受三个参数:model(模型对象)、checkpoint(检查点路径或对象)和map_location=‘cpu’(指定模型应加载到CPU上)。
从检查点中提取类别信息:这部分代码首先检查checkpoint字典中是否有一个meta键,并且meta字典中是否有一个CLASSES键。如果两者都存在,那么它会将CLASSES信息从检查点中提取出来,并赋值给模型的CLASSES属性。
检查模型是否有CLASSES属性:这行代码检查模型对象model是否没有CLASSES属性。如果没有,它会进入该if语句块。
从模型的detector属性中提取CLASSES:这部分代码首先检查模型对象model是否有一个detector属性,并且detector属性是否有一个CLASSES属性。如果两者都存在,那么它会将CLASSES信息从detector中提取出来,并赋值给模型的CLASSES属性。
警告:模型没有类别:如果上述所有条件都不满足(即模型及其detector属性都没有CLASSES属性),则会打印一条警告消息,并将模型的CLASSES属性设置为None。
这段代码的主要目的是从检查点或模型的detector属性中加载CLASSES信息,并确保模型具有CLASSES属性。如果模型或其detector属性中没有CLASSES信息,则会发出警告并将CLASSES设置为None。
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
将config对象赋值给model的cfg属性,将模型移动到指定的设备上,将模型设置为评估模式,返回了配置并移动到指定设备的模型。
# test and show/save the images
for i, img in enumerate(imgs):
if isinstance(img, str):
img_path = osp.join(args.input, img)
img = mmcv.imread(img_path)
if i == 0:
if args.gt_bbox_file is not None:
bboxes = mmcv.list_from_file(args.gt_bbox_file)
init_bbox = list(map(float, bboxes[0].split(',')))
else:
init_bbox = list(cv2.selectROI(args.input, img, False, False))
# convert (x1, y1, w, h) to (x1, y1, x2, y2)
init_bbox[2] += init_bbox[0]
init_bbox[3] += init_bbox[1]
result = inference_sot(model, img, init_bbox, frame_id=i)
if args.output is not None:
if IN_VIDEO or OUT_VIDEO:
out_file = osp.join(out_path, f'{
i:06d}.jpg')
else:
out_file = osp.join(out_path, img_path.rsplit(os.sep, 1)[-1])
else:
out_file = None
model.show_result(
img,
result,
show=args.show,
wait_time=int(1000. / fps) if fps else 0,
out_file=out_file,
thickness=args.thickness)
prog_bar.update()
if args.output and OUT_VIDEO:
print(
f'\nmaking the output video at {
args.output} with a FPS of {
fps}')
mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
out_dir.cleanup()
遍历图像序列:for i, img in enumerate(imgs): 遍历 imgs 列表(或其他可迭代对象),i 是索引,img 是每个图像或图像路径。
处理图像路径:如果 img 是一个字符串(可能是文件路径),代码将其转换为绝对路径,并使用 mmcv.imread 读取图像。
初始化边界框:
如果 args.gt_bbox_file 存在,则从文件中读取边界框坐标。
如果不存在,使用 cv2.selectROI 手动选择图像上的感兴趣区域(ROI)作为初始边界框。
将边界框从 (x1, y1, w, h) 格式转换为 (x1, y1, x2, y2) 格式。
执行模型推理:调用 inference_sot(model, img, init_bbox, frame_id=i) 函数,对图像执行目标跟踪推理,并返回结果。
处理输出结果:
根据 args.output 和其他条件确定输出文件的路径。
使用 model.show_result 显示结果,可以选择是否显示图像、设置等待时间、保存输出文件等。
更新进度条(prog_bar.update())。
生成输出视频:如果 args.output 存在且 OUT_VIDEO 为真,将输出目录中的帧转换为视频文件,并清理输出目录。
更多【机器学习-【目标跟踪】【MMTracking的部署与开发】03 demo_sot.py源码解析】相关视频教程:www.yxfzedu.com