tensorflow学习(十一)保存和恢复模型

保存/恢复对象

tf.train.Saver这个类负责保存和恢复模型。实际使用中,可以在模型训练的不同阶段保存多个checkpoints。为了防止checkpoints过多、写满硬盘,Saver只保留最近的n个checkpoints(由max_to_keep参数指定,默认为5)。

# save
v1 = tf.Variable(1.0, name="v1")
saver = tf.train.Saver()  # global: covers every saveable variable in the graph
saver_local = tf.train.Saver({'v1': v1})  # local: covers only v1, stored under the name 'v1'
sess = tf.Session()
# `training_op` stands for the model's own training operation (placeholder in this example).
for step in range(1000000):
    sess.run(training_op)
    if step % 1000 == 0:
        # global_step is appended to the prefix, e.g. 'my-model-1000' at step 1000.
        saver.save(sess, 'my-model', global_step=step)

# restore
saver.restore(sess, 'my-model-1000')
# latest_checkpoint is a module-level function in tf.train (Saver has no such method);
# it takes the directory holding the checkpoints and returns the newest checkpoint prefix.
saver.restore(sess, tf.train.latest_checkpoint('.'))

执行save方法之后,会生成一系列的文件,最好是单独放在一个路径中。

利用模型fine-tune

# Rebuild the pre-trained VGG16 graph from its meta file.
# NOTE: import_meta_graph only recreates the graph structure; the variable values
# still have to be loaded with vgg_saver.restore(sess, <checkpoint path>) before use.
vgg_saver = tf.train.import_meta_graph('vgg16.meta')
vgg_graph = tf.get_default_graph()

# Input tensor of the imported graph, looked up by name.
self.x_plh = vgg_graph.get_tensor_by_name('input:0')

# Pick the layer to keep as the feature extractor.
# Deeper alternatives: 'conv2_2:0' or 'conv3_3:0'.
output_conv = vgg_graph.get_tensor_by_name('conv1_2:0')

# Block back-propagation into the pre-trained layers so only the new layers are trained.
output_conv_sg = tf.stop_gradient(output_conv)

# Append a custom head: a 1x1 convolution with 32 output channels followed by ReLU.
output_conv_shape = output_conv_sg.get_shape().as_list()
W1 = tf.get_variable(
    'W1',
    shape=[1, 1, output_conv_shape[3], 32],
    initializer=tf.random_normal_initializer(stddev=1e-1),
)
b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))
z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1
a = tf.nn.relu(z1)

从caffe模型加载参数

tensorflow其实可以加载任意框架训练得到的模型参数,我们这里以caffe的模型参数为例。

from six.moves import cPickle

# Load the Caffe network in TEST mode so that its learned parameters can be read.
net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)

# net_skeleton: list of [variable name, shape] pairs.
# net_weights:  dict mapping variable name -> parameter array.
net_skeleton = []
net_weights = {}
# .items() works on both Python 2 and 3 (iteritems() exists only on Python 2).
for name, item in net.params.items():
    # Reorder the weight axes with transpose(2, 3, 1, 0).
    # Assumes 4-D convolution weights laid out (out, in, h, w) -- TODO confirm for
    # networks containing other layer types.
    weights = item[0].data.transpose(2, 3, 1, 0)
    biases = item[1].data
    # Record the shape of the array that is actually stored, so the skeleton and the
    # weights always agree (reversing the original shape only matches for square kernels).
    net_skeleton.append([name + '/w', weights.shape])
    net_skeleton.append([name + '/b', biases.shape])
    net_weights[name + '/w'] = weights
    net_weights[name + '/b'] = biases

with open(os.path.join(args.output_dir, 'net_skeleton.ckpt'), 'wb') as f:
    cPickle.dump(net_skeleton, f, protocol=cPickle.HIGHEST_PROTOCOL)
with open(os.path.join(args.output_dir, 'net_weights.ckpt'), 'wb') as f:
    cPickle.dump(net_weights, f, protocol=cPickle.HIGHEST_PROTOCOL)

# Release the Caffe network and the exported structures.
del net, net_skeleton, net_weights

保存好上述模型后,通过下面的方法加载即可

# Build a Saver limited to the variables in `trainable` (defined elsewhere) and
# restore their values from the file in args.output_dir.
# NOTE(review): 'net_weights.ckpt' as produced by the export snippet is a cPickle dump,
# not a checkpoint written by tf.train.Saver.save -- confirm that a conversion step
# exists, since Saver.restore expects a checkpoint written by Saver.save.
saver = tf.train.Saver(var_list=trainable)
saver.restore(sess, os.path.join(args.output_dir,'net_weights.ckpt'))