上一篇文章简单介绍了本机构建运行tensorflow中object detection api过程常遇到的几个问题,也是基础编译的问题,作者近期在实际项目需求中涉及使用该api,对其将数据构造tfrecord作为输入过程进行了修改,由XML直接进行数据构造,解决项目中的构建tfrecord数据的等待过程,将思路进行总结,供有需要的同学参考。
初级思路
由于项目的需求,需要使用api进行更新数据库的计算,但是由于在更新数据的过程中涉及到需要重新构造tfrecord文件,这会导致物体检测过程出现等待,因此初始想法为:构造tfrecord的时候不写入图像的二进制内容,而是将图像的路径写入tfrecord文件,这样会加速tfrecord的构建并减少tfrecord的存储;在实际api的调用过程中,在解析tfrecord时,将解析出来的图像路劲进行转换,根据解析路径提取对应图像,赋值给对应的变量,完成后续的计算。
这个思路相对简单,但是也需要对api进行整体的了解才能找到在哪里能够将图像路径转换为图像,这里不再赘述,整个过程涉及的两点一个为制作tfrecord保存图像替换为路径,一个为在api中对应位置进行提取并赋值给对应变量,对应的操作分别如下:
将制作tfrecord数据过程中的 encoded_jpg 变量,也就是图像的二进制变量更换为图像的路径,例如encoded_jpg = path+’/’+group.filename,此时encoded_jpg 代表一个图像路径(如’D:/images/0001.jpg’),这里注意变为路径之后要和其他变量例如filename一样,添加‘.encode(‘utf8′)’操作,即 encoded_jpg = (path+’/’+group.filename).encode(‘utf8’) , 具体位置如下,相对简单,不在赘述
在api中解析图像路径,并提取对应的图像到 tf.contrib.slim.tfexample_decoder 路径找到tfexample_decoder.py文件,在Image类下tensors_to_item函数中修改如下部分:
- ***tf.read_flie()函数可以解析二进制文件,当初还采用了tf.decode_image函数进行得到路径二进制的解析,一直报错,偶然用tf.read_file函数成功了,对函数不了解耽误了很多时间,还是要了解基础才行***
- 通过上面两步操作,实现了由tfrecord存储图像改为存储图像路径,并完成在API中的转换,具体思路这里不好解释,需要对整个数据处理流程进行捋顺。简单来说,在api的tfrecord解析过程中,涉及到图像解析的部分为Image类,在这里进行修改即完成目标
进阶思路
实现第一个初级的想法之后,我们也对物体检测api中的Dataset API ,tf.data.Dataset函数有了进一步接触。API中输入tfrecord路径之后采用Dataset进行解析,那么考虑可不可以直接用Dataset直接解析XML文件得到图像的信息,跳过制作tfrecord文件在进行读取转换。
经过又一系列的瞎蒙,最终实现了,实现代码如下:
import tensorflow as tf
import numpy as np
import glob
import xml.etree.ElementTree as ET
import os
def get_indices_new(obj_num):
fir_array = np.zeros([obj_num, 2], dtype=np.int64)
for num in range(obj_num):
fir_array[num, 0] = num
fir_array[num, 1] = 0
return fir_array
def get_tensor_dict_new_(xml_path, img_path):
img_path_list = []
width_list = []
height_list = []
labels_list = []
box_num = []
boxs_lists = []
path = glob.glob(xml_path + '*.xml')
for name in path:
tree = ET.parse(name)
root = tree.getroot()
image_path = os.path.join(img_path + root.find('filename').text)
width_list.append(int(root.find('size')[0].text))
height_list.append(int(root.find('size')[1].text))
img_path_list.append(image_path)
one_img_label_list = []
boxs_list = []
i = 0
for member in root.findall('object'):
i = i + 1
box_list = []
one_img_label_list.append(member[0].text)
box_list.append(float(member[4][0].text))
box_list.append(float(member[4][1].text))
box_list.append(float(member[4][2].text))
box_list.append(float(member[4][3].text))
boxs_list.append(box_list)
boxs_lists.append(tf.convert_to_tensor(boxs_list))
box_num.append(i)
labels_list.append(tf.convert_to_tensor(one_img_label_list))
max_num = max(box_num)
for i, label in enumerate(labels_list):
pad_num = max_num - box_num[i]
labels_list[i] = tf.pad(label, [[0, pad_num]], "CONSTANT")
for i, box in enumerate(boxs_lists):
pad_num = max_num - box_num[i]
boxs_lists[i] = tf.pad(box, [[0, pad_num], [0, 0]], "CONSTANT")
xml_list = {
"filename": img_path_list,
#"width": width_list,
#"height": height_list,
"label": labels_list,
"bboxs": boxs_lists,
"box_num": box_num
}
return xml_list
def parse_fun(value):
boxs = value['bboxs']
num = value['box_num']
labels = value['label']
label = tf.slice(labels, [0], [num])
box = tf.slice(boxs, [0, 0], [num, 4])
new_value = {
"filename": value['filename'],
#"width": value['width'],
#"height": value['height'],
"bboxs": box,
"labels": label
}
return new_value
if __name__ == '__main__':
xml_path = 'D:/Test/1/xmls/'
img_path = 'D:/Test/1/images'
file_dict = get_tensor_dict_new_(xml_path, img_path)
tensor_dict = tf.data.Dataset.from_tensor_slices(file_dict)
tensor_dicts = tensor_dict.map(parse_fun)
iterator = tensor_dicts.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
sess.run(one_element)
except tf.errors.OutOfRangeError:
print("end!")
以上是探索过程中的测试版本,本地测试可以走通,在实际的项目中,如果要以这种方式进行数据解析,需要对物体检测api的结构进行相应的调整。当然,这是结合具体项目的,如果是简单的使用该api则不会涉及到这个程度。构造的过程也是根据API中原始构造数据的结构进行设计,希望给有相应需求的同学提供参考。讲解过程比较粗糙,如果有错误的的地方,欢迎批评指出。各路大神请自动飘过~