📜  Tensorflow.js tf.Variable 类(1)

📅  最后修改于: 2023-12-03 14:47:56.134000             🧑  作者: Mango

TensorFlow.js tf.Variable 类

在 TensorFlow.js 中,tf.Variable 类用于创建可训练的变量。变量是一个可以在训练过程中更新的张量。tf.Variable 类是 tf.Tensor 类的子类,它可以像普通张量一样进行操作,并支持在训练过程中进行更新。

创建 tf.Variable 对象

我们可以使用 tf.variable()tf.Variable() 方法来创建一个 tf.Variable 对象。这两个方法的作用是一样的,tf.variable()tf.Variable() 方法的别名。

// 使用 tf.variable() 方法创建一个 tf.Variable 对象
const a = tf.variable(tf.tensor([1, 2, 3]));

// 使用 tf.Variable() 方法创建一个 tf.Variable 对象
const b = tf.Variable(tf.tensor([4, 5, 6]));

上面的代码创建了两个 tf.Variable 对象 ab,它们分别包含了向量 [1, 2, 3][4, 5, 6]。注意,创建变量时必须要指定其初始值,这里我们使用 tf.tensor() 方法来创建张量对象。

对 tf.Variable 对象进行操作

tf.Variable 对象可以像普通张量一样进行操作,也可以在训练过程中更新其值。

// 创建两个变量
const a = tf.variable(tf.tensor([1, 2, 3]));
const b = tf.variable(tf.tensor([4, 5, 6]));

// 相加并输出结果
const c = tf.add(a, b);
c.print();  // 输出:[5, 7, 9]

// 更新变量 a 的值
a.assign(tf.tensor([4, 5, 6]));

// 相加并输出结果
const d = tf.add(a, b);
d.print();  // 输出:[8, 10, 12]

上面的代码中,我们创建了两个变量 ab,然后对它们进行了加法操作。我们可以通过 print() 方法输出 tf.Tensor 对象的值。

使用 assign() 方法可以更新变量的值,这里我们将变量 a 的值更新为 [4, 5, 6]

使用 tf.Variable 对象进行模型训练

在 TensorFlow.js 中,我们可以使用 tf.Variable 对象来定义模型的权重。在训练过程中,我们可以通过调整权重的值来优化模型的性能。

// 创建一个线性回归模型
const model = tf.sequential();
model.add(tf.layers.dense({
  units: 1,
  inputShape: [1],
}));

// 指定损失函数和优化器
model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd',
});

// 准备数据
const x = tf.tensor([[1], [2], [3], [4]]);
const y = tf.tensor([[1.2], [2.4], [3.6], [4.8]]);

// 训练模型
const epochs = 10;
const batchSize = 2;
await model.fit(x, y, {
  epochs: epochs,
  batchSize: batchSize,
  callbacks: {
    onEpochEnd: async (epoch, logs) => {
      console.log(`Epoch ${epoch}: loss = ${logs.loss.toFixed(4)}`);
    }
  }
});

// 输出模型的权重
model.getWeights()[0].print();

上面的代码中,我们创建了一个简单的线性回归模型,使用 tf.sequential() 方法创建一个顺序模型,然后使用 tf.layers.dense() 方法添加了一个全连接层,其中 units 参数指定了该层的神经元个数,inputShape 参数指定了输入数据的形状。

我们使用 compile() 方法对模型进行配置,指定了损失函数和优化器。

我们准备了一些数据,使用 fit() 方法对模型进行训练,在每次迭代中输出了损失函数的值。

最后,使用 getWeights() 方法获取模型的权重,并通过 print() 方法输出了第一层的权重矩阵。这里我们可以看到,模型的权重得到了训练,其值已经发生了变化。