📅  最后修改于: 2023-12-03 15:35:17.817000             🧑  作者: Mango
在使用 TensorFlow.js 进行机器学习建模时,常常需要对张量的形状进行转换。这个时候,就可以使用 tf.reshape()
函数进行操作。
tf.reshape(x, shape)
函数接受两个参数:
x
:需要进行形状转换的张量。shape
:目标形状,是一个由整数构成的数组或张量。这个函数会返回一个形状为 shape
的新张量,新张量与原张量的元素数量应该相等。
需要注意的是,新的张量与原张量共享存储空间,在进行视图修改(view modification)时会引起原张量的变化。
以一个二维张量 x
为例,形状为 (2, 3)
:
const x = tf.tensor([[1, 2, 3], [4, 5, 6]]);
console.log(x.shape); // 输出 [2, 3]
我们可以使用 tf.reshape()
函数将 x
转换为形状为 (3, 2)
的新张量 y
:
const y = tf.reshape(x, [3, 2]);
console.log(y.shape); // 输出 [3, 2]
console.log(y.arraySync()); // 输出 [[1, 2], [3, 4], [5, 6]]
需要注意的是,y
与 x
共享存储空间。我们可以修改 y
中的元素,从而导致 x
中相应位置的元素也发生变化:
y.buffer().set(999, 1, 1);
console.log(x.arraySync()); // 输出 [[1, 2, 3], [4, 999, 6]]
tf.reshape()
函数支持的形状限制很少,基本上可以使用任何非负整数数组作为目标形状。但是需要保证,目标形状中的元素数量应该与原张量的元素数量相等,否则会抛出错误。
此外,由于张量的形状对于一些操作(例如卷积)有特定要求,因此在进行形状转换时需要特别注意。一些常见的形状转换可以参考以下表格:
| 原形状 | 目标形状 |
| ---------- | ----------- |
| [a, b, c]
| [-1, c]
|
| [a, b, c]
| [a * b, c]
|
| [a, b, c]
| [a, b * c]
|
| [a, b, c]
| [b, a, c]
|
tf.reshape()
函数是 TensorFlow.js 中非常实用的形状转换函数,经常在机器学习建模过程中使用。使用这个函数时,需要注意形状的合法性以及共享存储空间的问题。