Tensorflow.js tf.squeeze()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.squeeze()函数用于丢弃长度为 1 的尺寸,使其脱离指定的 tf.Tensor 的形状。
句法 :
tf.squeeze(x, axis?)
参数:
- x:是指定的被压缩的张量输入,可以是tf.Tensor、TypedArray或Array类型。
- 轴:它是一个可选参数,包含一个数字列表。如果指定,它只会压缩指定的尺寸。而且,这里的维度索引是从零开始的,压缩长度不是一的维度是个缺陷。
返回值:它返回 tf.Tensor 对象。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining tensor input elements
const y = tf.tensor([11, 76, -4, 6], [2, 2, 1]);
// Calling squeeze() method and
// Printing output
y.squeeze().print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling squeeze() method with
// all its parameter
var res = tf.squeeze(tf.tensor(
[2.1, 5.6, 8.6, 7.6],
[4, 1]), [1]
);
// Printing output
res.print();
输出:
Tensor
[[11, 76],
[-4, 6 ]]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling squeeze() method with
// all its parameter
var res = tf.squeeze(tf.tensor(
[2.1, 5.6, 8.6, 7.6],
[4, 1]), [1]
);
// Printing output
res.print();
输出:
Tensor
[2.0999999, 5.5999999, 8.6000004, 7.5999999]
参考: https://js.tensorflow.org/api/latest/#squeeze