网上的教程、官方的资料大多关注的都是如何使用模型,对于数据源的说明都比较简单。基本资料都是用的MNIST数据集做的实验。
供给数据概述
tensorflow的模型数据读取还是非常友好的,最适合大量的数据的变量操作的形式应该是feed机制,利用tf.placeholder()
为操作的变量创建占位符,在run()
或eval()
的时候调用参数,将变量x和其真实标签y_传入模型。代码如下:
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
...
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
...
for i in range(1000):
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x:batch[0], y_: batch[1]})
上述代码解析:
首先将图像数据转换到4维的uint8的ndarray类型数组[index, y, x, depth]
。标签数据根据是否为one_hot,转换为一个uint8的整型数值或者变成一个只有一个数值为1,其他维度为0的向量。然后将这样的数据标签对保存为tensorflow的DataSet类中。本例中这样的操作通过mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
实现。然后将这样的数据通过占位符feed进模型。
读取文件列表
我们可以通过一个文件列表来读入数据,之前都是通过python的glob来读取文件列表,也可以通过tf.train.match_filenames_once
产生文件列表,然后交给tf.train.string_input_producer
产生合适大小和是否乱序的文件名队列。
读取文件
图片文件
读入后为numpy的array
格式,范围[-1,1]
,通道为[H,W,channels]
,代码如下
def imread(self, file_name):
# image : an image with type np.float32 in range [-1, 1]
# of size (H x W x 3) in RGB or
# of size (H x W x 1) in grayscale.
img = skimage.img_as_float(skimage.io.imread(file_name, as_grey=False)).astype(np.float32)
img = img*2.0 - 1.0
img_resize = skimage.transform.resize(img, (width, height, 3))
img_resize = img_resize[:, :, (2,1,0)]
return img_resize
csv文件
使用tf.TextLineReader
读取文件的一行内容,然后通过tf.decode_csv
来解析这一行内容并将其转为张量列表
二进制文件读取固定长度
从二进制文件中读取固定长度纪录, 可以使用tf.FixedLengthRecordReader
的tf.decode_raw
操作。decode_raw
操作可以讲一个字符串转换为一个uint8的张量。例程中的cifar-10的数据集的读入就是通过这个方式。
标准TensorFlow格式TFRecords文件
将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,并且通过tf.python_io.TFRecordWriter
写入到TFRecords文件。通过传参images和labels将数据和标签数据保存到tfrecords中,代码如下:
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def convert_to(images, labels, name):
num_examples = labels.shape[0]
if images.shape[0] != num_examples:
raise ValueError("Images size %d does not match label size %d." % (images.shape[0], num_examples))
rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]
filename = os.path.join(FLAGS.directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw = images[index].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()
从TFRecords文件中读取数据,可以使用tf.TFRecordReader
的tf.parse_single_example
解析器。这个parse_single_example
操作可以将Example协议内存块(protocol buffer)解析为张量。通过train参数传入TFRecords文件返回值是图像标签对。代码如下:
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
# Convert from a scalar string tensor (whose single string has
# length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
# [mnist.IMAGE_PIXELS].
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape([mnist.IMAGE_PIXELS])
# Convert from [0, 255] -> [-0.5, 0.5] floats.
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
# Convert label from a scalar uint8 tensor to an int32 scalar.
label = tf.cast(features['label'], tf.int32)
return image, label
def inputs(train, batch_size, num_epochs):
if not num_epochs: num_epochs = None
filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATION_FILE)
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
image, label = read_and_decode(filename_queue)
return image, label
批处理
TBD
创建线程并使用QueueRunner对象来预取
TBD
训练模型保存和恢复
通过实例化一个tf.train.Saver
来实现模型的保存。
saver = tf.train.Saver()
saver.save(sess, FLAGS.train_dir, global_step=step)
在训练循环中,将定期调用saver.save()
方法,向训练文件夹中写入包含了当前所有可训练参数的检查点文件。
恢复检查点的时候使用saver.restore()
方法,重载模型的参数,继续训练。
saver.restore(sess, FLAGS.train_dir)
参考资料
http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html