📜  TensorFlow.js 后端完整参考(1)

📅  最后修改于: 2023-12-03 15:20:35.733000             🧑  作者: Mango

TensorFlow.js 后端完整参考

TensorFlow.js 是一个基于 JavaScript 平台的机器学习框架,其中的后端部分负责处理数据的输入输出和神经网络的训练。这篇文章将提供 TensorFlow.js 后端的完整参考,包括模型创建、训练、推理等方面。

模型创建

TensorFlow.js 后端支持通过代码或者 Keras 模型转换两种方式创建模型。

通过代码创建模型

可以通过 tf.sequential()tf.layers 方法来创建序贯模型和层级模型。

序贯模型

const model = tf.sequential();

model.add(tf.layers.dense({ units: 10, inputShape: [inputSize] }));
model.add(tf.layers.dense({ units: 1 }));

层级模型

const inputLayer = tf.layers.input({ shape: [inputSize] });
const dense1 = tf.layers.dense({ units: 10 })(inputLayer);
const dense2 = tf.layers.dense({ units: 1 })(dense1);

const model = tf.model({ inputs: inputLayer, outputs: dense2 });
通过 Keras 模型转换创建模型
const model = await tf.loadModel('https://example.com/model.json')
模型训练

TensorFlow.js 后端提供了 model.fit() 方法用于模型的训练,可以通过指定参数实现不同的训练方式。

model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });
await model.fit(inputs, labels, { epochs: 10, batchSize: 32, shuffle: true });
模型推理

TensorFlow.js 后端提供了 model.predict() 方法用于进行模型推理,可以对输入数据进行预测。

const output = model.predict(inputs);
模型保存

TensorFlow.js 后端提供了 model.save() 方法用于保存模型到本地或者远程。

await model.save('localstorage://my-model');
await model.save('downloads://my-model');
await model.save('indexeddb://my-model');

await model.save('https://example.com/save');
其他方法

TensorFlow.js 后端还提供了一些其他的方法,如下表所示。

| 方法名 | 描述 | | --- | --- | | tf.data.csv() | 从 CSV 文件读取数据 | | tf.data.generator() | 生成数据 | | tf.data.array() | 从数组中读取数据 | | tf.losses | 损失函数 | | tf.metrics | 评估函数 | | tf.serialization | 序列化函数 |

以上就是 TensorFlow.js 后端的完整参考,可以根据实际需要使用其中的方法。