保存/恢复对象
tf.train.Saver 这个类负责保存和恢复模型。实际使用中,可以在模型训练的不同阶段保存多个 checkpoint。为了防止保存的 checkpoint 过多而把硬盘写满,Saver 只保留最近的 n 个 checkpoint(由 max_to_keep 参数控制,默认值为 5)。
# Save: build the variables, then write a checkpoint periodically while training.
v1 = tf.Variable(1.0, name="v1")
saver = tf.train.Saver()  # no var_list given: covers every saveable variable in the graph
saver_local = tf.train.Saver({'v1': v1})  # covers only v1, stored under the key 'v1'
sess = tf.Session()
for step in range(1000000):
    sess.run(training_op)  # training_op is a placeholder for your own training op
    if step % 1000 == 0:
        # global_step is appended to the prefix, giving e.g. 'my-model-1000'.
        saver.save(sess, 'my-model', global_step=step)

# Restore: either name a checkpoint prefix explicitly...
saver.restore(sess, 'my-model-1000')
# ...or let TensorFlow look up the most recent checkpoint in the directory.
# latest_checkpoint is a module-level function of tf.train (not a Saver method)
# and takes the directory that holds the checkpoint files ('.' here, because
# 'my-model' above was saved relative to the working directory).
saver.restore(sess, tf.train.latest_checkpoint('.'))
执行save方法之后,会生成一系列的文件,最好是单独放在一个路径中。
利用模型fine-tune
# Rebuild the pretrained VGG16 graph from its exported meta-graph file.
# NOTE: import_meta_graph only recreates the graph structure; the returned
# saver still has to restore the variable values from a checkpoint
# (vgg_saver.restore(sess, <checkpoint prefix>)) before the weights are usable.
vgg_saver = tf.train.import_meta_graph('vgg16.meta')
vgg_graph = tf.get_default_graph()

# Input tensor of the imported graph.
# NOTE(review): `self` implies this snippet lives inside a method — confirm.
self.x_plh = vgg_graph.get_tensor_by_name('input:0')

# Pick the pretrained layer whose output should be reused.
output_conv = vgg_graph.get_tensor_by_name('conv1_2:0')
# output_conv = vgg_graph.get_tensor_by_name('conv2_2:0')
# output_conv = vgg_graph.get_tensor_by_name('conv3_3:0')

# Block gradients here so that training the new layers does not update
# the pretrained layers below this point.
output_conv_sg = tf.stop_gradient(output_conv)

# Append your own network on top: a 1x1 convolution with 32 output channels,
# followed by a ReLU.
output_conv_shape = output_conv_sg.get_shape().as_list()
W1 = tf.get_variable(
    'W1',
    # Index 3 is used as the input-channel count (assumes NHWC — TODO confirm).
    shape=[1, 1, output_conv_shape[3], 32],
    initializer=tf.random_normal_initializer(stddev=1e-1),
)
b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))
z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1
a = tf.nn.relu(z1)
从caffe模型加载参数
TensorFlow 其实可以加载任意来源(其他框架或算法)训练得到的模型参数,我们这里以 caffe 的模型参数为例。
from six.moves import cPickle
# Load the Caffe network (TEST phase) from its prototxt definition and weights.
net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)

# Skeleton: a list of [name, shape] pairs, two per layer
# ('<layer>/w' for the weights and '<layer>/b' for the bias).
net_skeleton = []
# .items() works on both Python 2 and 3 (iteritems() exists only on Python 2).
for name, item in net.params.items():
    # NOTE(review): the shape recorded here reverses all axes, while the array
    # stored below uses axis order (2, 3, 1, 0); the two agree only when
    # axes 2 and 3 have the same length — confirm this is intended.
    net_skeleton.append([name + '/w', item[0].data.shape[::-1]])
    net_skeleton.append([name + '/b', item[1].data.shape])

# Despite the '.ckpt' suffix, this file is a pickle, not a TensorFlow checkpoint.
with open(os.path.join(args.output_dir, 'net_skeleton.ckpt'), 'wb') as f:
    cPickle.dump(net_skeleton, f, protocol=cPickle.HIGHEST_PROTOCOL)

# Weights: a dict mapping the same names to the parameter arrays.
net_weights = {}
for name, item in net.params.items():
    # Reorder the weight axes to (2, 3, 1, 0); presumably Caffe
    # (out, in, h, w) -> TensorFlow (h, w, in, out) — TODO confirm.
    net_weights[name + '/w'] = item[0].data.transpose(2, 3, 1, 0)
    net_weights[name + '/b'] = item[1].data

# Also a pickle file, despite the '.ckpt' suffix.
with open(os.path.join(args.output_dir, 'net_weights.ckpt'), 'wb') as f:
    cPickle.dump(net_weights, f, protocol=cPickle.HIGHEST_PROTOCOL)

# Drop the references to the network and the exported structures.
del net, net_skeleton, net_weights
保存好上述模型后,通过下面的方法加载即可
# Restore the variables listed in `trainable` from the exported weights file.
# NOTE(review): this path is the file written with cPickle.dump earlier in this
# document, whereas Saver.restore reads checkpoints written by Saver.save —
# confirm the pickled weights are converted to a real checkpoint first.
weights_path = os.path.join(args.output_dir, 'net_weights.ckpt')
saver = tf.train.Saver(var_list=trainable)
saver.restore(sess, weights_path)