Tensorflow迁移学习
首先,为了方便划分不同的网络模块(子网络)、分别导入权重、指定是否需要训练、指定是否需要复用,需要使用如下语句为网络权重设定scope;使用collection将隐含层输出保存为字典;在with xxx:语句内部定义的网络层也要定义scope
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 需要Reuse时,设定reuse=tf.AUTO_REUSE
with tf.variable_scope("model_name", "model_name", reuse=tf.AUTO_REUSE) as sc:
# 之后就可以在下面定义网络了
conv1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='layer_name')
# 还可以使用end_points字典保存隐含层输出
end_points = slim.utils.convert_collection_to_dict("collection_name")
end_points[sc.name + "/layer_name'] = conv1
然后,在优化器优化代码处定义需要训练的scope
loss = #定义好你的loss
# 从scope获得需要训练的变量表
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "model_name")
# 将var_list设定为刚才获得的变量表
opt = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss, var_list=train_vars)
最后,导入预训练的权重
# 导入权重前要进行变量初始化
sess.run(tf.global_variables_initializer())
# 错误?导入预训练的权重(导入后会被初始化覆盖?)
# tf.train.init_from_checkpoint("model_name.ckpt", {"model_name/":"model_name/"})
# 更好的导入权重的方法
saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "model_name"))
saver_sid.restore(sess, "model_name.ckpt")
评论已关闭