When running the inference code (inference_model) of an mmlab-series framework, here mmdetection, the following error was raised:
File "/home/user/Lab/mmdetection/mmdet/apis/inference.py", line 121, in inference_detector
data = scatter(data, [device])[0]
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 44, in scatter
return scatter_map(inputs)
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 34, in scatter_map
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 29, in scatter_map
return list(zip(*map(scatter_map, obj)))
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 31, in scatter_map
out = list(map(list, zip(*map(scatter_map, obj))))
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 31, in scatter_map
out = list(map(list, zip(*map(scatter_map, obj))))
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 34, in scatter_map
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 29, in scatter_map
return list(zip(*map(scatter_map, obj)))
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/scatter_gather.py", line 27, in scatter_map
return Scatter.forward(target_gpus, obj.data)
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/_functions.py", line 76, in forward
streams = [_get_stream(device) for device in target_gpus]
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/_functions.py", line 76, in <listcomp>
streams = [_get_stream(device) for device in target_gpus]
File "/home/user/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 115, in _get_stream
if _streams[device] is None:
TypeError: list indices must be integers or slices, not torch.device
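The main cause is that the test config uses the dict(type='DefaultFormatBundle') preprocessing step, which should not be used at test time. Replace it with dict(type='ImageToTensor', keys=['img']), as the official test configs do. DefaultFormatBundle is meant for training, where it keeps the batched data aligned. It wraps img in a DataContainer, and as the traceback shows, mmcv scatters a DataContainer through its own Scatter.forward, which passes the torch.device straight to _get_stream. Test pipelines normally use ImageToTensor instead.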
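A minimal sketch of what the fix looks like, assuming an mmdetection 2.x-style config. The test_pipeline below follows the common COCO defaults (img_scale=(1333, 800) and the usual img_norm_cfg values), which are assumptions; adjust them to your own config. The helper replace_format_bundle is hypothetical and not part of mmdetection or mmcv. It simply returns a copy of a pipeline with every DefaultFormatBundle step, including ones nested under transforms, swapped for ImageToTensor.

```python
import copy

# Assumed COCO-style normalization values, as in typical mmdetection 2.x configs.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# Test-time pipeline: ImageToTensor yields a plain tensor, not a DataContainer.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),  # instead of DefaultFormatBundle
            dict(type='Collect', keys=['img']),
        ]),
]


def replace_format_bundle(pipeline):
    """Hypothetical helper: return a deep copy of `pipeline` in which every
    DefaultFormatBundle step (also inside nested `transforms` lists) is
    replaced by ImageToTensor(keys=['img']). The input is left unchanged."""
    fixed = []
    for step in pipeline:
        step = copy.deepcopy(step)
        if step.get('type') == 'DefaultFormatBundle':
            step = dict(type='ImageToTensor', keys=['img'])
        elif 'transforms' in step:
            # Recurse into wrappers such as MultiScaleFlipAug.
            step['transforms'] = replace_format_bundle(step['transforms'])
        fixed.append(step)
    return fixed
```

Editing the test config by hand is the direct fix; the helper is only a convenience if you load a training-style config programmatically and want to patch its test pipeline (for example cfg.data.test.pipeline on an mmcv Config) before building the detector.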