📅  最后修改于: 2023-12-03 14:51:11.355000             🧑  作者: Mango
Tensorflow 提供了将模型保存到硬盘和从硬盘加载模型的功能。这个功能对于在训练模型时定期保存模型状态非常有用,这样在模型中断或失败时可以恢复模型。它还可以用于在不重新训练模型的情况下重新使用模型。
在 Tensorflow 中,可以使用 tf.train.Saver
实例来保存和加载模型。下面详细介绍这个功能。
要保存模型,首先必须创建 tf.train.Saver
实例。通常,这个实例在创建模型时被创建,例如:
import tensorflow as tf
# 假设这个变量是你训练模型的一部分
my_variable = tf.Variable([1.0, 2.0], name="my_variable")
# 现在创建一个 saver 实例
saver = tf.train.Saver()
注意,Tensorflow 中所有的变量都必须有一个名字,因为这个名字是用来保存和加载变量值的。
要保存模型,可以通过 Saver.save
方法执行:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 训练你的模型
# ...
# 现在保存你的模型
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
这个操作将序列化图形和所有变量值到文件 /tmp/model.ckpt
。在保存模型时,你必须在 sess.run
块中运行 tf.global_variables_initializer()
,以确保所有变量都被初始化。
save_path
是返回的字符串,其中包含完整的路径和文件名。这个字符串可以在稍后加载模型时使用。
要加载模型,必须按照与保存模型相同的方式定义变量。然后在定义 Saver 实例和会话时,通过给 Saver
构造函数传递变量的列表来指定要加载的变量。
例如,如果要加载前面例子中保存的模型,可以执行:
import tensorflow as tf
# 定义与保存模型相同的变量
my_variable = tf.Variable([0.0, 0.0], name="my_variable")
# 创建一个 Saver 实例
saver = tf.train.Saver([my_variable])
# 使用 saver.restore() 方法加载模型
with tf.Session() as sess:
# 加载以前保存的模型
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# 现在开始对模型进行评估
# ...
注意,必须在定义变量时给它们相同的名字才能正常加载模型。此外,被加载的变量的形状和类型也必须匹配。
在上面的例子中,我们指定了 my_variable
变量,但是实际上你可能有更多的变量,你可以将它们作为列表传递给构造函数,例如:
saver = tf.train.Saver([my_first_variable, my_second_variable, my_third_variable])
Tensorflow 还提供了一个方便的函数 tf.trainable_variables
来获取所有训练模型参数的列表。这可以用来方便地创建 Saver 实例。
例如:
import tensorflow as tf
# 创建变量
my_variable1 = tf.Variable([1.0, 2.0], name="my_variable1")
my_variable2 = tf.Variable([3.0, 4.0], name="my_variable2")
# 创建 Saver 实例
saver = tf.train.Saver(tf.trainable_variables())
# 保存模型
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
# 加载模型
with tf.Session() as sess:
# 定义与保存模型相同的变量
my_variable1 = tf.Variable([0.0, 0.0], name="my_variable1")
my_variable2 = tf.Variable([0.0, 0.0], name="my_variable2")
# 创建 Saver 实例
saver = tf.train.Saver(tf.trainable_variables())
# 加载以前保存的模型
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
在这个例子中,我们使用 tf.trainable_variables
函数来获取在变量上执行优化的所有变量的列表。然后我们将这个列表传递给 Saver
构造函数,这样就不需要手动指定每个变量了。
当我们加载模型时,我们也定义了一个与原来相同的变量,但是初始化为 0。然后我们使用 tf.trainable_variables
和 Saver
实例来加载模型参数。