📜  Tensorflow.js tf.Variable 类 .assign() 方法

📅  最后修改于: 2022-05-13 01:56:39.700000             🧑  作者: Mango

Tensorflow.js tf.Variable 类 .assign() 方法

TensorFlow.js 是由 Google 设计的开源 JavaScript 库,用于开发机器学习模型和深度学习神经网络。 tf.Variable 类用于在 Tensorflow 中创建、跟踪和管理变量。 Tensorflow 变量表示可以通过对其运行操作来更改其值的张量。

assign()是 Variable 类中可用的方法,用于将新的 tf.Tensor 分配给变量。新值必须与旧变量值具有相同的形状和 dtype。

句法:

assign(newValue)

参数:

  • newValue:它是一个 tf.Tensor 类型的对象。这个新值被分配给变量。

返回值:返回void。

示例 1:此示例首先创建一个具有初始值的新 tf.Variable 对象,然后为该变量分配一个具有相同形状和 dtype 的新值作为初始值。

Javascript
// Defining tf.Tensor object for initial value
initialValue = tf.tensor([[1, 2, 3]])
  
// Defining tf.Variable object 
let x = new tf.Variable(initialValue);
  
// Printing variables dtype
console.log("dtype:",x.dtype)
  
// Printing variable shape
console.log("shape:",x.shape)
  
// Printing the tf.Variable object
x.print()
  
// Defining new tf.Tensor object of same 
// shape and dtype as initial tf.Tensor
newValue = tf.tensor([[5, 8, 10]])
  
// Assigning new value to the variable
x.assign(newValue)
  
// Printing the tf.Variable object
x.print()


Javascript
// Defining tf.Tensor object for initial value
initialValue = tf.tensor([[1, 2],[3, 4]])
  
// Defining tf.Variable object 
let x = new tf.Variable(initialValue);
  
// Printing variables dtype
console.log("dtype:",x.dtype)
  
// Printing variable shape
console.log("shape:",x.shape)
  
// Printing the tf.Variable object
x.print()
  
// Defining new tf.Tensor object of same 
// shape as initial tf.Tensor
newValue = tf.tensor([[5, 6],[10, 11]])
  
// Assigning new value to the variable
x.assign(newValue)
  
// Printing the tf.Variable object
x.print()


输出:

dtype: float32
shape: 1,3
Tensor
     [[1, 2, 3],]
Tensor
     [[5, 8, 10],]

示例 2:此示例首先创建一个具有初始值的新 tf.Variable 对象,然后尝试为该变量分配一个具有不同形状的新值作为初始值。这将给出一个错误。

Javascript

// Defining tf.Tensor object for initial value
initialValue = tf.tensor([[1, 2],[3, 4]])
  
// Defining tf.Variable object 
let x = new tf.Variable(initialValue);
  
// Printing variables dtype
console.log("dtype:",x.dtype)
  
// Printing variable shape
console.log("shape:",x.shape)
  
// Printing the tf.Variable object
x.print()
  
// Defining new tf.Tensor object of same 
// shape as initial tf.Tensor
newValue = tf.tensor([[5, 6],[10, 11]])
  
// Assigning new value to the variable
x.assign(newValue)
  
// Printing the tf.Variable object
x.print()

输出:

dtype: float32
shape: 2,2
Tensor
    [[1, 2],
     [3, 4]]
Tensor
    [[5 , 6 ],
     [10, 11]]

参考: https ://js.tensorflow.org/api/latest/#tf.Variable.assign