这次我们来上手一个大型网络,googLeNet。这篇讲一下Inception v3网络的训练和测试以及数据集的加载。
数据准备
将图片按照类别放入不同的文件夹下。格式如下:
data_dir/train/label_0/a.jpg
data_dir/train/label_0/b.jpg
...
data_dir/train/label_1/f.jpg
...
data_dir/validation/label_0/c.jpg
data_dir/validation/label_0/d.jpg
...
data_dir/labels.txt
同时将所有label的类别整理成一个txt放在data_dir/labels.txt文件中,行数代表该类别对应的整数值,从0开始计数,其文本内容如下:
label_0
label_1
label_2
...
label_n
转化成为TFRecord格式
首先通过_find_iamge_files
函数得到全量图片的文件名、标签、标签文本信息的对应关系,并打乱排序。然后按线程数将全量文件分块,分别执行_process_image_files_batch
。将数据按分块分别写入对应的TFRecordWriter。这个操作的代码我整理保存在了build_image_tfrecord_data.py。其中两个重要的参数,一个是进程数量的控制参数num_threads
和Record数据分块的参数num_shards
。我这里有个不同,原来代码是读文件之后直接压入Record,我是用OpenCV解码之后,将解码后数据变成string压入Record。代码调用格式如下:
python build_image_tfrecord_data.py \
--train_directory=mydata/train \
--validation_directory=mydata/validation \
--output_directory=mydata/tf_record \
--labels_file=mydata/labels.txt
网络模型
单机训练
分布式训练
测试评估
参考文献
https://github.com/tensorflow/models/tree/master/inception
http://arxiv.org/abs/1512.00567