tensorflow学习(二)数据接口

网上的教程、官方的资料大多关注的都是如何使用模型,对于数据源的说明都比较简单。基本资料都是用的MNIST数据集做的实验。

供给数据概述

tensorflow的模型数据读取还是非常友好的,最适合大量的数据的变量操作的形式应该是feed机制,利用tf.placeholder()为操作的变量创建占位符,在run()eval()的时候调用参数,将变量x和其真实标签y_传入模型。代码如下:

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

...

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

...

for i in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x:batch[0], y_: batch[1]})

上述代码解析:
首先将图像数据转换到4维的uint8的ndarray类型数组[index, y, x, depth]。标签数据根据是否为one_hot,转换为一个uint8的整型数值或者变成一个只有一个数值为1,其他维度为0的向量。然后将这样的数据标签对保存为tensorflow的DataSet类中。本例中这样的操作通过mnist = input_data.read_data_sets('MNIST_data', one_hot=True)实现。然后将这样的数据通过占位符feed进模型。

读取文件列表

我们可以通过一个文件列表来读入数据,之前都是通过python的glob来读取文件列表,也可以通过tf.train.match_filenames_once产生文件列表,然后交给tf.train.string_input_producer产生合适大小和是否乱序的文件名队列。

读取文件

图片文件

读入后为numpy的array格式,范围[-1,1],通道为[H,W,channels],代码如下

def imread(self, file_name):
    #    image : an image with type np.float32 in range [-1, 1]
    #    of size (H x W x 3) in RGB or
    #    of size (H x W x 1) in grayscale.
    img = skimage.img_as_float(skimage.io.imread(file_name, as_grey=False)).astype(np.float32)
    img = img*2.0 - 1.0
    img_resize = skimage.transform.resize(img, (width, height, 3))
    img_resize = img_resize[:, :, (2,1,0)]
    return img_resize

csv文件

使用tf.TextLineReader读取文件的一行内容,然后通过tf.decode_csv来解析这一行内容并将其转为张量列表

二进制文件读取固定长度

从二进制文件中读取固定长度纪录, 可以使用tf.FixedLengthRecordReadertf.decode_raw操作。decode_raw操作可以讲一个字符串转换为一个uint8的张量。例程中的cifar-10的数据集的读入就是通过这个方式。

标准TensorFlow格式TFRecords文件

将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecords文件。通过传参images和labels将数据和标签数据保存到tfrecords中,代码如下:

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(images, labels, name):
    num_examples = labels.shape[0]
    if images.shape[0] != num_examples:
        raise ValueError("Images size %d does not match label size %d." % (images.shape[0], num_examples))
    rows = images.shape[1]
    cols = images.shape[2]
    depth = images.shape[3]

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        image_raw = images[index].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
    writer.close()

从TFRecords文件中读取数据,可以使用tf.TFRecordReadertf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量。通过train参数传入TFRecords文件返回值是图像标签对。代码如下:

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    # Convert from a scalar string tensor (whose single string has
    # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
    # [mnist.IMAGE_PIXELS].
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([mnist.IMAGE_PIXELS])

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    # Convert label from a scalar uint8 tensor to an int32 scalar.
    label = tf.cast(features['label'], tf.int32)

    return image, label


def inputs(train, batch_size, num_epochs):
    if not num_epochs: num_epochs = None
    filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATION_FILE)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
        image, label = read_and_decode(filename_queue)

        return image, label

批处理

TBD

创建线程并使用QueueRunner对象来预取

TBD

训练模型保存和恢复

通过实例化一个tf.train.Saver来实现模型的保存。

saver = tf.train.Saver()
saver.save(sess, FLAGS.train_dir, global_step=step)

在训练循环中,将定期调用saver.save()方法,向训练文件夹中写入包含了当前所有可训练参数的检查点文件。
恢复检查点的时候使用saver.restore()方法,重载模型的参数,继续训练。

saver.restore(sess, FLAGS.train_dir)

参考资料

http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html