最近下载了一个刚刚开源的货柜数据集,是coco格式的json数据,想来检测一下物体检测模型的效果。需要将json格式转换为xml形式(当然,如果只是要tfrecord,tensorflow的目标检测模型中也提供了直接将json转为tfrecord的代码文件 create_coco_tf_record.py),研究了一下coco api,然后写了一下转换代码,这里记录一下。
采用的数据集为 ‘旷视RPC大型商品数据集’(咦?编辑超链接选项咋没了呢,那就复制一下链接吧-_-!!! https://rpc-dataset.github.io/),下载完后是这样的格式(不是就做成这样,哈哈哈)
这里完全调用coco API来进行数据的提取,正好熟悉一下,当然根据代码的第一个注释也可发现,读取json然后用纯字典的查找也是可以实现的;其次,这里在写xml文件的时候是顺序进行节点的封装,精致的你们可以稍微进行封装一下,之前用数据的时候也写了各种cvs,json,MATLAB等格式的标注转换xml,其实都是一样的套路,这里稍微记录一下,后期再用的时候做个参考,希望对大家也有帮助。
代码:
from pycocotools.coco import COCO
import json
from lxml.etree import Element, ElementTree, SubElement
import os
import shutil
'''
with open("COCO_train.json","r+") as f:
data = json.load(f)
print("read ready")
for i in data:
print(i)
# info
# licenses
# categories
# __raw_Chinese_name_df
# images
# annotations
'''
if __name__ == "__main__":
json_name = 'instances_test2019.json'
xml_path = 'test2019_xml'
if os.path.exists(xml_path):
shutil.rmtree(xml_path)
os.mkdir(xml_path)
# 构建coco实例
coco=COCO(json_name)
# 得到所有图片id
images_id = coco.getImgIds()
# 通过id遍历得到图像的具体信息
for id in images_id:
img = coco.loadImgs(id)
jpg_name = img[0].get('file_name')
print(jpg_name)
xml_name = os.path.join(xml_path, jpg_name.strip('.jpg') + '.xml')
file = open(xml_name, 'wb+')
node_root = Element('annotation')
node_filename = SubElement(node_root, 'filename')
node_filename.text = str(jpg_name)
node_size = SubElement(node_root, 'size')
node_width = SubElement(node_size, 'width')
node_height = SubElement(node_size, 'height')
node_depth = SubElement(node_size, 'depth')
node_width.text = str(img[0].get('width'))
node_height.text = str(img[0].get('height'))
node_depth.text = '3'
# 根据图像id得到所有的标签id
annIds = coco.getAnnIds(imgIds=id)
# 根据标签的所有id得到具体的annotation信息
cats = coco.loadAnns(annIds)
for anno in cats:
node_object = SubElement(node_root, 'object')
node_name = SubElement(node_object, 'name')
'''
# 按照id排序得到的类别list
cats = coco.loadCats(coco.getCatIds())
nms=[cat['name'] for cat in cats]
'''
# 根据类别id直接得到真实类别
category_id = anno.get('category_id')
cats = coco.loadCats(category_id)[0].get('name')
node_name.text = str(cats)
node_bnbox = SubElement(node_object, 'bndbox')
node_xmin = SubElement(node_bnbox, 'xmin')
node_ymin = SubElement(node_bnbox, 'ymin')
node_xmax = SubElement(node_bnbox, 'xmax')
node_ymax = SubElement(node_bnbox, 'ymax')
bbox_info = anno.get('bbox')
# coco格式是左上角(x, y)与宽长(w, h)
node_xmin.text = str(int(bbox_info[0]))
node_ymin.text = str(int(bbox_info[1]))
node_xmax.text = str(int(bbox_info[0]+bbox_info[2]))
node_ymax.text = str(int(bbox_info[1]+bbox_info[3]))
doc = ElementTree(node_root)
doc.write(file, pretty_print=True)
print('xml make done!')