📅  最后修改于: 2023-12-03 14:47:55.017000             🧑  作者: Mango
TensorFlow.js是Google为JavaScript开发者提供的开源机器学习库,提供了许多用于构建和训练机器学习模型的API。其中,tf.GraphModel
类提供了一个易于使用的接口来加载和使用预训练的TensorFlow模型。本文将重点介绍tf.GraphModel
类的.save()
方法。
.save()
方法.save()
方法是tf.GraphModel
类中的一个方法,用于将模型保存到本地或云端。其语法如下:
await tf.GraphModel.prototype.save(
ioHandler: tf.io.IOHandler,
includeOptimizer: boolean = false
): Promise<tf.io.SaveResult>
其中,参数ioHandler
是用于写入和读取模型的ioHandler对象。它可以是以下类型的一种:
参数includeOptimizer
决定是否应该将优化器的状态一起保存。默认为false
。
该方法返回一个Promise对象,在保存完成后解析为一个tf.io.SaveResult
对象,其中包含了关于模型保存位置和成功/失败的信息。
以下代码段展示了如何使用.save()
方法将预训练模型保存在本地:
// 引入模型
const model = await tf.loadGraphModel('model.json');
// 保存模型
const saveResult = await model.save('downloads://my-model');
console.log(`模型保存成功,保存路径为 ${saveResult.modelArtifactsInfo['my-model'].weightDataUrl}`);
以上代码段中,我们首先使用tf.loadGraphModel()
方法加载了一个已经训练好的模型。然后使用.save()
方法将该模型保存到本地。需要注意的是,'downloads://my-model'
指定了模型保存的本地路径,其中'downloads://'
前缀表示浏览器的下载文件夹。保存成功后,我们可以通过saveResult.modelArtifactsInfo['my-model'].weightDataUrl
获取该模型在本地的绝对路径。
如果我们要将模型保存在云端存储中,可以使用tf.io.browserHTTPRequest()
创建一个HTTPRequestIOHandler对象,并将其作为参数传递给.save()
方法:
// 引入模型
const model = await tf.loadGraphModel('model.json');
// 将模型保存到云端存储
const URL = 'https://my-bucket.storage.googleapis.com/';
const ioHandler = tf.io.browserHTTPRequest(URL);
const saveResult = await model.save(ioHandler);
console.log(`模型保存成功,保存路径为 ${saveResult.modelArtifactsInfo['']}`);
使用以上代码,我们可以将模型保存在谷歌云存储(Google Cloud Storage)中,并得到该模型在云端存储中的绝对路径。
.save()
方法是tf.GraphModel
类中的一个方法,用于将模型保存到本地或云端。它需要一个ioHandler对象作为参数,返回一个Promise对象并在保存完成时解析为一个tf.io.SaveResult
对象。.save()
方法是一个非常有用的方法,它允许我们保存已经训练好的模型,并在需要的时候重新加载和使用它们。