📜  Tensorflow.js tf.squeeze()函数(1)

📅  最后修改于: 2023-12-03 15:20:35.577000             🧑  作者: Mango

Tensorflow.js tf.squeeze()函数介绍

在Tensorflow.js中,tf.squeeze()函数可以将张量中尺寸为1的维度进行压缩,从而得到一个尺寸更小的张量。

函数定义
tf.squeeze<T extends Tensor>(x: T, axis?: number[]): T;
  • x: 输入的张量。
  • axis: 可选参数,表示要进行压缩的维度。如果不传入该参数,则默认将所有尺寸为1的维度进行压缩。
使用示例
const input = tf.tensor([[[[1], [2]], [[3], [4]]]]);
const output = tf.squeeze(input);
console.log(output.shape); // [2, 2]

const input2 = tf.tensor([[[[1], [2]], [[3], [4]]]]);
const output2 = tf.squeeze(input2, [0, 3]);
console.log(output2.shape); // [2, 2]

const input3 = tf.tensor([[[[1], [2]], [[3], [4]]]]);
const output3 = tf.squeeze(input3, [2]);
console.log(output3.shape); // [1, 2, 2]
注意事项
  • 如果被压缩的维度不为1,则会报错。
  • 如果没有被压缩的维度,则会返回原张量,而不是一个标量。
  • 当传入axis参数时,该参数应当只包括要进行压缩的维度。若包含其他维度,则会报错。
  • 当传入axis参数时,压缩后的张量中不应当再包括axis中的维度。否则会报错。