Tensorflow.js tf.scatterND()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.scatterND()函数用于根据所述索引通过对所述形状张量的零张量内的各个切片或值的分散更新来形成不同的张量。此外,此函数是 tf.gatherND()函数的否定,该函数从指定的张量中获取切片或值。
句法:
tf.scatterND(indices, updates, shape)
参数:
- 指数:它是规定的张量,它持有输出张量的指数,它可以是 tf.Tensor、TypedArray 或 Array 类型。
- 更新:保存索引值的是声明的张量,它可以是 tf.Tensor、TypedArray 或 Array 类型。
- 形状:它是输出张量的规定形状,类型为 number[]。
返回值:返回 tf.Tensor 对象。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining indices, updates and shape
const ind = tf.tensor2d([6, 5, 2], [3, 1], 'int32');
const updat = tf.tensor1d([1, 2, 3]);
const shp = [6];
// Calling tf.scatterND() method and
// Printing output
tf.scatterND(ind, updat, shp).print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.scatterND() method and
// Printing output
tf.scatterND(tf.tensor2d([5.4, 2.4], [2, 1], 'int32'),
tf.tensor1d([1.8, 4.2]),
[4]).print();
输出:
Tensor
[0, 0, 3, 0, 0, 2]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.scatterND() method and
// Printing output
tf.scatterND(tf.tensor2d([5.4, 2.4], [2, 1], 'int32'),
tf.tensor1d([1.8, 4.2]),
[4]).print();
输出:
Tensor
[0, 0, 4.1999998, 0]
参考: https://js.tensorflow.org/api/latest/#scatterND