Tensorflow.js tf.Variable 类 .assign() 方法
TensorFlow.js 是由 Google 设计的开源 JavaScript 库,用于开发机器学习模型和深度学习神经网络。 tf.Variable 类用于在 Tensorflow 中创建、跟踪和管理变量。 Tensorflow 变量表示可以通过对其运行操作来更改其值的张量。
assign()是 Variable 类中可用的方法,用于将新的 tf.Tensor 分配给变量。新值必须与旧变量值具有相同的形状和 dtype。
句法:
assign(newValue)
参数:
- newValue:它是一个 tf.Tensor 类型的对象。这个新值被分配给变量。
返回值:返回void。
示例 1:此示例首先创建一个具有初始值的新 tf.Variable 对象,然后为该变量分配一个具有相同形状和 dtype 的新值作为初始值。
Javascript
// Defining tf.Tensor object for initial value
initialValue = tf.tensor([[1, 2, 3]])
// Defining tf.Variable object
let x = new tf.Variable(initialValue);
// Printing variables dtype
console.log("dtype:",x.dtype)
// Printing variable shape
console.log("shape:",x.shape)
// Printing the tf.Variable object
x.print()
// Defining new tf.Tensor object of same
// shape and dtype as initial tf.Tensor
newValue = tf.tensor([[5, 8, 10]])
// Assigning new value to the variable
x.assign(newValue)
// Printing the tf.Variable object
x.print()
Javascript
// Defining tf.Tensor object for initial value
initialValue = tf.tensor([[1, 2],[3, 4]])
// Defining tf.Variable object
let x = new tf.Variable(initialValue);
// Printing variables dtype
console.log("dtype:",x.dtype)
// Printing variable shape
console.log("shape:",x.shape)
// Printing the tf.Variable object
x.print()
// Defining new tf.Tensor object of same
// shape as initial tf.Tensor
newValue = tf.tensor([[5, 6],[10, 11]])
// Assigning new value to the variable
x.assign(newValue)
// Printing the tf.Variable object
x.print()
输出:
dtype: float32
shape: 1,3
Tensor
[[1, 2, 3],]
Tensor
[[5, 8, 10],]
示例 2:此示例首先创建一个具有初始值的新 tf.Variable 对象,然后尝试为该变量分配一个具有不同形状的新值作为初始值。这将给出一个错误。
Javascript
// Defining tf.Tensor object for initial value
initialValue = tf.tensor([[1, 2],[3, 4]])
// Defining tf.Variable object
let x = new tf.Variable(initialValue);
// Printing variables dtype
console.log("dtype:",x.dtype)
// Printing variable shape
console.log("shape:",x.shape)
// Printing the tf.Variable object
x.print()
// Defining new tf.Tensor object of same
// shape as initial tf.Tensor
newValue = tf.tensor([[5, 6],[10, 11]])
// Assigning new value to the variable
x.assign(newValue)
// Printing the tf.Variable object
x.print()
输出:
dtype: float32
shape: 2,2
Tensor
[[1, 2],
[3, 4]]
Tensor
[[5 , 6 ],
[10, 11]]
参考: https ://js.tensorflow.org/api/latest/#tf.Variable.assign