Tensorflow.js tf.expandDims()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.expandDims()函数用于查找 tf.Tensor 对象,通过在其中插入维度来借助张量的形状扩展其秩。
句法:
tf.expandDims(x, axis?)
参数:
- x:指定张量输入,需要扩展大小,可以是tf.Tensor、TypedArray或Array类型。
- 轴:这是要插入形状 1 的规定尺寸索引。默认值为零,即第一个维度,它是数字类型。
返回值:返回 tf.Tensor 对象。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input and axis
const y = tf.tensor1d([5, 6, 7, 8]);
const axs = 1;
// Calling tf.expandDims() method and
// Printing output
y.expandDims(axs).print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.expandDims() method and
// Printing output
tf.expandDims(tf.tensor1d([1.2, 3.5, 7.6, 9.7]), 1).print();
输出:
Tensor
[[5],
[6],
[7],
[8]]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.expandDims() method and
// Printing output
tf.expandDims(tf.tensor1d([1.2, 3.5, 7.6, 9.7]), 1).print();
输出:
Tensor
[[1.2 ],
[3.5 ],
[7.5999999],
[9.6999998]]
参考: https://js.tensorflow.org/api/latest/#expandDims