📅  最后修改于: 2023-12-03 14:47:56.187000             🧑  作者: Mango
TensorFlow.js 是由 Google 开发的用于在浏览器中训练和部署机器学习模型的库。它有许多强大的功能,包括对现有 TensorFlow 模型的加载和使用,使用 JavaScript 进行模型的训练和推理,以及对图像和音频数据的处理和分析。
在使用 TensorFlow.js 之前,需要使用 npm 或 yarn 进行安装。可以在 Node.js 应用程序中使用它,也可以将其引入到浏览器中。
以下是安装 TensorFlow.js 的示例代码:
# 使用 npm 安装 TensorFlow.js
npm install @tensorflow/tfjs
# 使用 yarn 安装 TensorFlow.js
yarn add @tensorflow/tfjs
要在浏览器中使用 TensorFlow.js,可以使用以下 CDN 引入代码:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.7.0"></script>
TensorFlow.js 支持从 Keras 或 TensorFlow SavedModel 加载训练好的模型。以下是加载 TensorFlow 模型的示例代码:
import * as tf from '@tensorflow/tfjs';
const model = await tf.loadLayersModel('path/to/model.json');
TensorFlow.js 提供了一组 API,可以在 JavaScript 中训练模型。例如,通过以下代码,可以创建一个简单的线性回归模型并训练它:
import * as tf from '@tensorflow/tfjs';
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
await model.fit(xs, ys, {
epochs: 10,
callbacks: {
onEpochEnd: (_, logs) => console.log(`Loss: ${logs.loss}`),
},
});
TensorFlow.js 提供了可以在浏览器中执行模型推理的 API。例如,以下代码演示了如何使用 MNIST 手写数字识别模型对图像进行分类:
import * as tf from '@tensorflow/tfjs';
import { loadGraphModel } from '@tensorflow/tfjs-converter';
const model = await loadGraphModel('path/to/mnist/model.json');
const image = tf.browser.fromPixels(document.getElementById('image'));
const logits = model.predict(image);
const prediction = tf.argMax(logits, -1).dataSync()[0];
console.log(`Prediction: ${prediction}`);
TensorFlow.js 是一种非常有用的工具,可以在浏览器中构建和部署机器学习模型。它提供了许多功能,可让您轻松加载现有模型、训练新模型并执行模型推理。