📜  Tensorflow.js tf.GraphModel 类 .save() 方法(1)

📅  最后修改于: 2023-12-03 14:47:55.017000             🧑  作者: Mango

TensorFlow.js中tf.GraphModel类 .save()方法

TensorFlow.js是Google为JavaScript开发者提供的开源机器学习库,提供了许多用于构建和训练机器学习模型的API。其中,tf.GraphModel类提供了一个易于使用的接口来加载和使用预训练的TensorFlow模型。本文将重点介绍tf.GraphModel类的.save()方法。

.save()方法

.save()方法是tf.GraphModel类中的一个方法,用于将模型保存到本地或云端。其语法如下:

await tf.GraphModel.prototype.save(
    ioHandler: tf.io.IOHandler, 
    includeOptimizer: boolean = false
): Promise<tf.io.SaveResult>

其中,参数ioHandler是用于写入和读取模型的ioHandler对象。它可以是以下类型的一种:

  • tf.io browser IOHandler对象,用于在浏览器中保存和读取模型
  • tf.io node IOHandler对象,用于在Node.js环境中保存和读取模型
  • tf.io HTTPRequestIOHandler对象,用于使用HTTP协议从服务器获取或存储模型

参数includeOptimizer决定是否应该将优化器的状态一起保存。默认为false

该方法返回一个Promise对象,在保存完成后解析为一个tf.io.SaveResult对象,其中包含了关于模型保存位置和成功/失败的信息。

示例

以下代码段展示了如何使用.save()方法将预训练模型保存在本地:

// 引入模型
const model = await tf.loadGraphModel('model.json');

// 保存模型
const saveResult = await model.save('downloads://my-model');
console.log(`模型保存成功,保存路径为 ${saveResult.modelArtifactsInfo['my-model'].weightDataUrl}`);

以上代码段中,我们首先使用tf.loadGraphModel()方法加载了一个已经训练好的模型。然后使用.save()方法将该模型保存到本地。需要注意的是,'downloads://my-model'指定了模型保存的本地路径,其中'downloads://'前缀表示浏览器的下载文件夹。保存成功后,我们可以通过saveResult.modelArtifactsInfo['my-model'].weightDataUrl获取该模型在本地的绝对路径。

如果我们要将模型保存在云端存储中,可以使用tf.io.browserHTTPRequest()创建一个HTTPRequestIOHandler对象,并将其作为参数传递给.save()方法:

// 引入模型
const model = await tf.loadGraphModel('model.json');

// 将模型保存到云端存储
const URL = 'https://my-bucket.storage.googleapis.com/';
const ioHandler = tf.io.browserHTTPRequest(URL);
const saveResult = await model.save(ioHandler);
console.log(`模型保存成功,保存路径为 ${saveResult.modelArtifactsInfo['']}`);

使用以上代码,我们可以将模型保存在谷歌云存储(Google Cloud Storage)中,并得到该模型在云端存储中的绝对路径。

总结

.save()方法是tf.GraphModel类中的一个方法,用于将模型保存到本地或云端。它需要一个ioHandler对象作为参数,返回一个Promise对象并在保存完成时解析为一个tf.io.SaveResult对象。.save()方法是一个非常有用的方法,它允许我们保存已经训练好的模型,并在需要的时候重新加载和使用它们。