Tensorflow.js tf.broadcastTo()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.broadcastTo()函数用于将数组循环到 NumPy 风格的一致模型。
笔记:
- 这里,张量的形状等同于广播从最后到开始的形状。其中,张量形状的前缀为 1,只要它具有与广播形状相同的长度即可。
- 如果 input.shape[i]==shape[i],那么第 (i+1) 个轴以前与广播一致,如果 input.shape[i]==1 加上 shape[i] ==N,则指定的输入张量被该轴覆盖N次。
句法 :
tf.broadcastTo(x, shape)
参数:
- x:它是指定的张量输入,可以是 tf.Tensor、TypedArray 或 Array 类型。
- 形状:输入将被广播到的指定形状。
返回值:它返回 tf.Tensor 对象。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input elements
const y = tf.tensor1d([1, 2, 3, 4]);
// Calling broadcastTo() method and
// Printing output
tf.broadcastTo(y, [1, 4]).print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling broadcastTo() method and
// Printing output
tf.broadcastTo(tf.tensor1d([3.6, 5.8, 3.7, 1.4, 9.3, 10.5]),
[1, 6]).print();
输出:
Tensor
[[1, 2, 3, 4],]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling broadcastTo() method and
// Printing output
tf.broadcastTo(tf.tensor1d([3.6, 5.8, 3.7, 1.4, 9.3, 10.5]),
[1, 6]).print();
输出:
Tensor
[[3.5999999, 5.8000002, 3.7, 1.4, 9.3000002, 10.5],]
参考: https://js.tensorflow.org/api/latest/#broadcastTo