tensorflow学习(六)CNN网络Inception

这次我们来上手一个大型网络,googLeNet。这篇讲一下Inception v3网络的训练和测试以及数据集的加载。

数据准备

将图片按照类别放入不同的文件夹下。格式如下:

data_dir/train/label_0/a.jpg
data_dir/train/label_0/b.jpg
...
data_dir/train/label_1/f.jpg
...
data_dir/validation/label_0/c.jpg
data_dir/validation/label_0/d.jpg
...
data_dir/labels.txt

同时将所有label的类别整理成一个txt放在data_dir/labels.txt文件中,行数代表该类别对应的整数值,从0开始计数,其文本内容如下:

label_0
label_1
label_2
...
label_n

转化成为TFRecord格式

首先通过_find_iamge_files函数得到全量图片的文件名、标签、标签文本信息的对应关系,并打乱排序。然后按线程数将全量文件分块,分别执行_process_image_files_batch。将数据按分块分别写入对应的TFRecordWriter。这个操作的代码我整理保存在了build_image_tfrecord_data.py。其中两个重要的参数,一个是进程数量的控制参数num_threads和Record数据分块的参数num_shards。我这里有个不同,原来代码是读文件之后直接压入Record,我是用OpenCV解码之后,将解码后数据变成string压入Record。代码调用格式如下:

python build_image_tfrecord_data.py \
  --train_directory=mydata/train \
  --validation_directory=mydata/validation \
  --output_directory=mydata/tf_record \
  --labels_file=mydata/labels.txt

网络模型

单机训练

分布式训练

测试评估

参考文献

https://github.com/tensorflow/models/tree/master/inception
http://arxiv.org/abs/1512.00567