caffe学习(二)用自己的数据训练模型

数据准备

整理、去重

将图片文件按照类别放入不同的文件夹下面,同时去掉相同文件夹下的相同图片

打tag

将图片数据按8:1:1的比例分成train,val,test三个子集,并产生“文件路径 标签”的标签数据,如下

../data/train_img/104166241/143088862023832744.jpg 0
../data/train_img/104166241/143088862023832435.jpg 0
../data/train_img/121391188/142096822616212389.jpg 1

生成lmdb数据库文件

copy一份examples/imagenet/create_imagenet.sh文件并修改其中EXAMPLE、DATA、TOOLS、TRAIN_DATA_ROOT、VAL_DATA_ROOT的对应路径,并且修改RESIZE=true,因为我们的训练图片要归一化到256*256,我的修改如下

EXAMPLE=.
DATA=.
TOOLS=../build/tools

TRAIN_DATA_ROOT=./
VAL_DATA_ROOT=./

RESIZE=true

最后把其中的脚本语句中的convert_imageset命令中的标签文件改为自己的文件名(train.txt、val.txt),然后执行

./create_db.sh

生成均值文件

copy一份examples/imagenet/make_imagenet_mean.sh文件并修改其中的EXAMPLE、DATA、TOOLS路径位置,以及命令compute_image_mean中的训练数据集的路径和保存的均值文件的路径文件名,然后执行

./make_db_mean.sh

模型参数调整

选择一个模型,这里选择GoogleNet,将models/bvlc_googlenet中的solver.prototxttrain_val.prototxt复制到新文件夹下。

train_val.prototxt文件

  • 修改输入layer中的TRAIN source和TEST中的source到自己的数据库路径
  • 同时注意一下其中的batch_size参数,决定每次迭代送多少图片去训练
  • googleNet有3个弱分类器,修改3个最终的loss1/classifier层中的num_output为我们样本的实际类别数,imageNet类别为1000

solver.prototxt文件

  • test_iter按照caffe官方的推荐,应该满足batch_size * test_iter = val子集的数据数,比如有1万张验证集图片,batch_size是50,那么test_iter应该设置为200左右。
  • test_interval决定多次训练迭代后测试一次,由于googleNet每次迭代较快,设置为2000比较合适,大概半个小时test一次
  • display参数决定多少次迭代显示一次,由于googleNet每次迭代较快,这里我设置为100
  • stepsize决定多次次迭代后减少学习率,应该和max_iter最大迭代次数综合考虑
  • snapshot决定多次次迭代后保存一次模型
  • snapshot_prefix是保存的模型文件名
  • solver_mode是用GPU还是CPU
  • lr_policy: 这个参数可以选的参数如下:
    • fixed : lr永不变
    • step : lr = baselr * gamma^(iter/stepsize)
    • exp : lr = baselr * gamma^iter
    • inv : lr = baselr (1+gammaiter)^(-power)
    • multistep : 直接写iter在某个范围内时lr应该是多少,自定义学习曲线
    • poly : lr = baselr * (1 - iter/maxiter)^power
    • sigmoid : lr = baselr (1 / (1 + e^(-gamma(iter-stepsize))))
  • iter_size这个参数很有意思,可以控制train的时候,做多少次batch_size才计算一次梯度,也就是iter_size * batch_size。这个参数当网络很大,显存不够用,又想增大batch_size时候就非常给力

其他参数都是模型参数学习的相关参数,决定了模型的初始值和收敛速度,暂时先不调节

训练

简单,执行

../build/tools/caffe train --solver=solver.prototxt

看到了terminal不停的喷log了,那酸爽,记得看一下log中的test结果,可以看到top-1/top-5的验证结果。