Tensorflow.js tf.GraphModel 类 .save() 方法
简介: Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.save()函数用于保存所述GraphModel的结构和/或权重。
笔记:
- IOHandler是一个对象,它拥有与指定的准确签名有关的保存方法。
- save方法控制顺序数据的累积或传输,即描述模型拓扑的工件以及在特定介质上或通过特定介质的权重,如本地存储文件、文件下载、Web 浏览器中的IndexedDB以及 HTTP 请求服务器。
- TensorFlow.js 支持IOHandler实现,支持许多重复使用的保存介质,如tf.io.browserDownloads()和tf.io.browserLocalStorage 。
- 此外,此方法还允许我们应用特定类型的IOHandler ,例如类似 URL 的字符串技术,例如'localstorage://' 和 'indexeddb://' 。
句法:
save(handlerOrURL, config?)
参数:
- handlerOrURL:声明的IOHandler实例,或者类似于基于设计的字符串技术的 URL,支持IOHandler 。它的类型为io.IOHandler|字符串。
- config:指定的选项以保存指定的模型。它是可选的并且是对象类型。它下面有两个参数,如下所示:
- trainableOnly:它说明是否只保存所述模型的可训练权重,忽略不可训练的权重。它是布尔类型,默认为 false。
- includeOptimizer:它说明是否将存储所述优化器。它是布尔类型,默认为 false。
返回值:返回 io.SaveResult 的 promise。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model url
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
// Calling save() method
const output = await mymodel.save('downloads://mymodel');
// Printing output
console.log(output)
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json');
// Calling save() method with all its
// parameters
const output = await mymodel.save('downloads://mymodel', true, true);
// Printing output
console.log(JSON.stringify(output))
输出:
{
"modelArtifactsInfo": {
"dateSaved": "2021-08-19T12:00:15.603Z",
"modelTopologyType": "JSON",
"modelTopologyBytes": 90375,
"weightSpecsBytes": 15791,
"weightDataBytes": 13984940
}
}
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json');
// Calling save() method with all its
// parameters
const output = await mymodel.save('downloads://mymodel', true, true);
// Printing output
console.log(JSON.stringify(output))
输出:
{"modelArtifactsInfo":{"dateSaved":"2021-08-19T12:05:35.906Z",
"modelTopologyType":"JSON","modelTopologyBytes":90375,
"weightSpecsBytes":15791,"weightDataBytes":13984940}}
参考: https://js.tensorflow.org/api/latest/#tf.GraphModel.save