Tensorflow.js tf.LayersModel Class .compile() Method
Tensorflow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment.
The .compile() function configures and prepares the model for training and evaluation. Calling .compile() sets up the model's optimizer, loss, and metrics. The .compile() function takes an args object as its parameter.
Note: If you call the .fit() or .evaluate() function on an uncompiled model, the program will throw an error (a compile-then-fit sketch is shown after Example 1).
Syntax:
model.compile({optimizer, loss, metrics})
Parameters:
- optimizer: A mandatory parameter. It accepts either a tf.train.Optimizer object or the string name of an optimizer. The supported string names are "sgd", "adam", "adamax", "adadelta", "adagrad", "rmsprop", and "momentum". (Passing an Optimizer instance instead of a string is shown in the sketch after this list.)
- loss: A mandatory parameter. It accepts a string naming a loss, or an array of such strings. If the model has multiple outputs, a different loss can be used on each output by passing an array of losses; the value the model minimizes is then the sum of the individual losses. Examples of loss string names are "meanSquaredError", "meanAbsoluteError", etc.
- metrics: An optional parameter. It accepts a list of metrics to be evaluated by the model during training and testing. Typically metrics: ['accuracy'] is used. To specify different metrics for different outputs of a multi-output model, a dictionary can be passed instead (also illustrated in the sketch below).
Return value: Since it only prepares the model for training, it returns nothing (i.e. the return type is void).
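To illustrate these parameter variants, here is a minimal sketch, not from the original article, of a two-output model built with the functional API (tf.input()/tf.model()); the layer names outA/outB and the learning rate 0.001 are made up for illustration. It passes an Optimizer instance instead of a string, an array with one loss per output, and a dictionary of per-output metrics.
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// a two-output model built with the functional API
const input = tf.input({ shape: [10] });
const outA = tf.layers.dense({ units: 1, name: "outA" }).apply(input);
const outB = tf.layers.dense({ units: 1, name: "outB" }).apply(input);
const model = tf.model({ inputs: input, outputs: [outA, outB] });
model.compile({
  // an Optimizer instance instead of the string name "adam"
  optimizer: tf.train.adam(0.001),
  // one loss per output; the model minimizes their sum
  loss: ["meanSquaredError", "meanAbsoluteError"],
  // different metrics for different outputs, keyed by output layer name
  metrics: { outA: "mse", outB: "mae" },
});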
Example 1: In this example, we create a simple model and pass values for the optimizer and loss parameters. Here we use "adam" as the optimizer and "meanSquaredError" as the loss.
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// define the model
const model = tf.sequential({
  layers: [tf.layers.dense({ units: 1, inputShape: [10] })],
});
// compile the model
// using "adam" optimizer and "meanSquaredError" loss
model.compile({ optimizer: "adam", loss: "meanSquaredError" });
// evaluate the model which was compiled above
// computation is done in batches of size 4
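// no metrics were configured, so evaluate() returns a single Scalar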
const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
  batchSize: 4,
});
// print the result
result.print();
Output:
Tensor
2.6806342601776123
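Once a model is compiled, it can also be trained with .fit(); as noted above, calling .fit() on an uncompiled model throws. The following is a minimal sketch, not part of the original examples, using dummy tensors of ones and an arbitrary epoch count.
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// define the model
const model = tf.sequential({
  layers: [tf.layers.dense({ units: 1, inputShape: [10] })],
});
// compile() must run before fit(); omitting it makes fit() throw
model.compile({ optimizer: "sgd", loss: "meanSquaredError" });
// fit() returns a Promise that resolves to a History object
model
  .fit(tf.ones([8, 10]), tf.ones([8, 1]), { epochs: 2, batchSize: 4 })
  .then((history) => console.log(history.history.loss));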
Example 2: In this example, we create a simple model and pass values for the optimizer, loss, and metrics parameters. Here we use "adam" as the optimizer, "meanSquaredError" as the loss, and "accuracy" as the metric.
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// define the model
const model = tf.sequential({
  layers: [tf.layers.dense({ units: 1, inputShape: [10] })],
});
// compile the model
// using "adam" optimizer, "meanSquaredError" loss and "accuracy" metrics
model.compile({
  optimizer: "adam",
  loss: "meanSquaredError",
  metrics: ["accuracy"],
});
// evaluate the model which was compiled above
// computation is done in batches of size 4
// since metrics were configured, evaluate() returns
// an array of Scalars: [loss, metric]
const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
  batchSize: 4,
});
// print the loss value
result[0].print();
Output:
Tensor
1.4847172498703003
Example 3: In this example, we create a simple model and pass values for the optimizer, loss, and metrics parameters. Here we use "sgd" as the optimizer, "meanAbsoluteError" as the loss, and "precision" as the metric.
Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
// define the model
const model = tf.sequential({
  layers: [tf.layers.dense({ units: 1, inputShape: [10] })],
});
// compile the model
// using "sgd" optimizer, "meanAbsoluteError" loss and "precision" metrics
model.compile({
  optimizer: "sgd",
  loss: "meanAbsoluteError",
  metrics: ["precision"],
});
// evaluate the model which was compiled above
// computation is done in batches of size 4
// since metrics were configured, evaluate() returns
// an array of Scalars: [loss, metric]
const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
  batchSize: 4,
});
// print the loss value
result[0].print();
Output:
Tensor
1.507279634475708
Reference: https://js.tensorflow.org/api/latest/#tf.LayersModel.compile