Tensorflow.js tf.LayersModel 类 .save() 方法
Tensorflow.js 是由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.save()函数用于保存所述LayersModel的结构和/或权重。
笔记:
- IOHandler 是一个对象,它拥有与指定的准确签名有关的保存方法。
- save 方法控制顺序数据的累积或传输,即描述模型拓扑的工件以及在特定介质上或通过特定介质的权重,如本地存储文件、文件下载、Web 浏览器中的 IndexedDB 以及 HTTP 请求服务器。
- TensorFlow.js 支持 IOHandler 实现,支持许多重复使用的保存介质,例如 tf.io.browserDownloads() 和 tf.io.browserLocalStorage。
- 此外,此方法还允许我们应用特定类型的 IOHandler,例如类似 URL 的字符串技术,例如 'localstorage://' 和 'indexeddb://'。
句法:
save(handlerOrURL, config?)
参数:
- handlerOrURL:IOHandler的指定实例,或者类似于支持 IOHandler 的基于设计的字符串技术的 URL。它的类型为 io.IOHandler|字符串。
- config:指定的选项以保存指定的模型。它是可选的并且是对象类型。它下面有两个参数,如下所示:
- trainableOnly:它说明是否只保存所述模型的可训练权重,忽略不可训练的权重。它是布尔类型,默认为 false。
- includeOptimizer:它说明是否将存储所述优化器。它是布尔类型,默认为 false。
返回值:返回 io.SaveResult 的 promise。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating a model
const model = tf.sequential(
{layers: [tf.layers.dense({units: 2, inputShape: [2]})]});
// Calling save() method
const res = await model.save('localstorage://mymodel');
// Printing output
console.log(res)
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating a model
const model = tf.sequential(
{layers: [tf.layers.dense({units: 2, inputShape: [2]})]});
// Calling save() method
const res = await model.save('localstorage://mymodel', true, true);
// Printing output
console.log(JSON.stringify(res))
输出:
{
"modelArtifactsInfo": {
"dateSaved": "2021-08-23T09:53:28.198Z",
"modelTopologyType": "JSON",
"modelTopologyBytes": 612,
"weightSpecsBytes": 125,
"weightDataBytes": 24
}
}
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating a model
const model = tf.sequential(
{layers: [tf.layers.dense({units: 2, inputShape: [2]})]});
// Calling save() method
const res = await model.save('localstorage://mymodel', true, true);
// Printing output
console.log(JSON.stringify(res))
输出:
{"modelArtifactsInfo":{"dateSaved":"2021-08-23T09:55:33.044Z",
"modelTopologyType":"JSON","modelTopologyBytes":612,
"weightSpecsBytes":125,"weightDataBytes":24}}
参考: https://js.tensorflow.org/api/latest/#tf.LayersModel.save