TensorFlow Learning (12): Loading a Custom Image Dataset

A previous post covered the official TFRecords serialization approach, which felt overly complicated. This time we walk through a data-loading scheme for a custom dataset and look at the key steps.

Reading the data

  • Get the list of image file names in the dataset and their corresponding labels
  • Convert the file name list into tensors
  • Read the images listed in the file name list, convert them into tensors, and preprocess them
  • Produce a batch of data

The first three steps are implemented during the initialization of ImageReader, and the last step is implemented in its dequeue method. An example follows:

import numpy as np
import tensorflow as tf

def read_labeled_image_list(data_dir, data_list):
    """Reads the list file and returns lists of image paths and integer labels."""
    images = []
    labels = []
    with open(data_list, 'r') as f:
        for line in f:
            # Each line is expected to contain "<relative image path> <label>".
            image, label = line.strip('\n').split(' ')
            images.append(data_dir + image)
            labels.append(int(label))
    return images, labels

def read_images_from_disk(input_queue, input_size, random_scale):
    """Reads one image/label pair from the input queue and preprocesses the image."""
    img_contents = tf.read_file(input_queue[0])
    img = tf.image.decode_jpeg(img_contents, channels=3)
    if input_size is not None:
        h = input_size
        w = input_size
        if random_scale:
            # Randomly rescale the image by a factor in [0.75, 1.25].
            scale = tf.random_uniform([1], minval=0.75, maxval=1.25, dtype=tf.float32, seed=None)
            h_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[0]), scale))
            w_new = tf.to_int32(tf.multiply(tf.to_float(tf.shape(img)[1]), scale))
            new_shape = tf.squeeze(tf.stack([h_new, w_new]), axis=[1])
            img = tf.image.resize_images(img, new_shape)
        # Crop or pad to a fixed input_size x input_size so the images can be batched.
        img = tf.image.resize_image_with_crop_or_pad(img, h, w)
    # RGB -> BGR (Caffe-style channel order).
    img_r, img_g, img_b = tf.split(img, num_or_size_splits=3, axis=2)
    img = tf.cast(tf.concat([img_b, img_g, img_r], axis=2), dtype=tf.float32)
    # Subtract the per-channel ImageNet mean (in BGR order).
    img -= np.array((104.008, 116.669, 122.675), dtype=np.float32)

    label = tf.cast(input_queue[1], tf.int32)
    return img, label

class ImageReader(object):
    """Reads images and labels from disk and builds the TensorFlow input queue."""

    def __init__(self, data_dir, data_list, input_size, random_scale):
        self.data_dir = data_dir
        self.data_list = data_list
        self.input_size = input_size

        self.image_list, self.label_list = read_labeled_image_list(self.data_dir, self.data_list)
        self.images = tf.convert_to_tensor(self.image_list, dtype=tf.string)
        self.labels = tf.convert_to_tensor(self.label_list, dtype=tf.int32)  # labels are integers, not strings
        self.queue = tf.train.slice_input_producer([self.images, self.labels],
                                                   shuffle=input_size is not None)  # no shuffling for the validation set
        self.image, self.label = read_images_from_disk(self.queue, self.input_size, random_scale)

    def dequeue(self, num_elements):
        """Packs single image/label pairs into a batch of num_elements."""
        image_batch, label_batch = tf.train.batch([self.image, self.label], num_elements)
        return image_batch, label_batch
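
For reference, read_labeled_image_list implies the format of the data_list file: one image path (relative to data_dir) and one integer label per line, separated by a single space. A hypothetical train.txt could look like this (the file names and labels below are made up for illustration):

images/cat_0001.jpg 0
images/cat_0002.jpg 0
images/dog_0001.jpg 1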

Threads and queues

The Coordinator class coordinates multiple threads and lets them terminate together, while tf.train.start_queue_runners creates the set of threads that fill the input queues. The recommended usage template is:

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        sess.run(train_op)
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    coord.request_stop()

coord.join(threads)
sess.close()
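
One detail worth noting: the OutOfRangeError caught above is only raised when the input producer runs for a finite number of epochs. With the ImageReader shown earlier it never fires, because slice_input_producer is called without num_epochs. Below is a minimal sketch of how an epoch limit could be enabled (the value 10 is arbitrary); the epoch counter is kept as a local variable, so local variables must be initialized as well:

# Sketch: pass num_epochs so the queue eventually raises OutOfRangeError.
queue = tf.train.slice_input_producer([images, labels], num_epochs=10, shuffle=True)

init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())  # num_epochs lives in the local variables
sess.run(init_op)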

Usage example

reader = ImageReader(data_dir, data_list, input_size, random_scale)  # e.g. ImageReader('./data/', './list/train.txt', 224, True)
# Equivalent to calling tf.train.batch([image, label], 32) directly.
image_batch, label_batch = reader.dequeue(32)
# `prediction` is assumed to be the logits produced by your network from image_batch.
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction, labels=label_batch))

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    for step in range(args.num_steps):
        loss_value = sess.run(loss)
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    coord.request_stop()

coord.join(threads)
sess.close()
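
The usage example above leaves prediction and args undefined. For illustration only, here is a self-contained sketch in which a deliberately tiny linear model stands in for a real network; the paths, class count, and step count are all hypothetical:

import tensorflow as tf

NUM_CLASSES = 2      # hypothetical number of classes
NUM_STEPS = 1000     # stands in for args.num_steps
BATCH_SIZE = 32

reader = ImageReader(data_dir='./data/', data_list='./list/train.txt',
                     input_size=224, random_scale=True)
image_batch, label_batch = reader.dequeue(BATCH_SIZE)

# Tiny stand-in model: flatten the images and apply one linear layer.
# A real network would replace these three lines.
flat = tf.reshape(image_batch, [BATCH_SIZE, 224 * 224 * 3])
w = tf.Variable(tf.zeros([224 * 224 * 3, NUM_CLASSES]))
prediction = tf.matmul(flat, w) + tf.Variable(tf.zeros([NUM_CLASSES]))

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction,
                                                   labels=label_batch))
train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    for step in range(NUM_STEPS):
        _, loss_value = sess.run([train_op, loss])
        if step % 100 == 0:
            print('step {}: loss = {}'.format(step, loss_value))
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    coord.request_stop()

coord.join(threads)
sess.close()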