Tensorflow.js tf.variableGrads()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.variableGrads()函数用于计算和返回 f(x) 的梯度,与参数varList提供的可管理变量的所述列表进行比较。此外,如果未给出列表,则默认情况下它是所有可管理的变量。
句法:
tf.variableGrads(f, varList?)
参数:
- f:这是要执行的规定函数。其中,f() 必须返回一个标量。它的类型是 (() => tf.Scalar)。
- varList:它是用于计算梯度的变量列表,默认情况下它是所有可管理的变量。它是 tf.Variable[] 类型。
返回值:它返回类型为tf.Scalar的值,它还返回类型为 {[name: 字符串]: tf.Tensor} 的 grads。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining list of variables
const p = tf.variable(tf.tensor1d([9, 6]));
const q = tf.variable(tf.tensor1d([7, 8]));
// Defining tf.tensor1d
const r = tf.tensor1d([3, 4]);
// Defining the function that is to
// be executed
const fn = () => p.add(r.square()).mul(q.add(r)).sum();
// Calling tf.variableGrads method
const {val, grads} = tf.variableGrads(fn);
// Printing output
Object.keys(grads).forEach(
variable_Name => grads[variable_Name].print());
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining list of variables containing
// float values
const p = tf.variable(tf.tensor1d([3.1, 5.2]));
const q = tf.variable(tf.tensor1d([4.4, 6.7]));
// Defining tf.tensor1d with float values
const r = tf.tensor1d([7.1, 3.2]);
// Calling tf.variableGrads method
const {val, grads} = tf.variableGrads(
() => p.add(r.square()).mul(q.add(r)).sum());
// Printing output
Object.keys(grads).forEach(
variable_Name => grads[variable_Name].print());
输出:
Tensor
[10, 12]
Tensor
[18, 22]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining list of variables containing
// float values
const p = tf.variable(tf.tensor1d([3.1, 5.2]));
const q = tf.variable(tf.tensor1d([4.4, 6.7]));
// Defining tf.tensor1d with float values
const r = tf.tensor1d([7.1, 3.2]);
// Calling tf.variableGrads method
const {val, grads} = tf.variableGrads(
() => p.add(r.square()).mul(q.add(r)).sum());
// Printing output
Object.keys(grads).forEach(
variable_Name => grads[variable_Name].print());
输出:
Tensor
[11.5, 9.8999996]
Tensor
[53.5099983, 15.4400005]
参考: https://js.tensorflow.org/api/latest/#variableGrads