📅  最后修改于: 2023-12-03 15:35:17.987000             🧑  作者: Mango
Tensorflow.js的tf.split()函数用于在某个维度上将张量切割成多个子张量。该函数接受三个参数:源张量、切割次数(整数)和切割维度(整数)。返回一个张量数组,每个张量是源张量在指定维度上的一个切片。
tf.split(x, numOrSizeSplits, axis)
参数说明:
x
: 源张量。numOrSizeSplits
: 切割次数,可以是一个整数,表示等分成多少份,也可以是一个数组,表示要切割的大小。axis
: 切割维度,取值范围为[0, x.rank)。const x = tf.tensor2d([[1, 2], [3, 4], [5, 6], [7, 8]]);
const [a, b] = tf.split(x, 2, 0);
a.print();
b.print();
输出:
Tensor
[[1, 2],
[3, 4]]
Tensor
[[5, 6],
[7, 8]]
本例中,我们将一个2x4的矩阵分成了2份,分别为2x2的矩阵。
当切割次数为整数时,等分成多少份就由该参数决定。
const x = tf.tensor2d([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
const [a, b, c] = tf.split(x, 3, 1);
a.print();
b.print();
c.print();
输出:
Tensor
[[1],
[4],
[7]]
Tensor
[[2],
[5],
[8]]
Tensor
[[3],
[6],
[9]]
在本例中,我们将一个3x3的矩阵分成了3份,分别为3x1的矩阵。
当切割次数为数组时,表示按照数组中的大小切割。
const x = tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]);
const [a, b, c] = tf.split(x, [1, 2, 1], 1);
a.print();
b.print();
c.print();
输出:
Tensor
[[1],
[5],
[9]]
Tensor
[[2, 3],
[6, 7],
[10, 11]]
Tensor
[[4],
[8],
[12]]
在本例中,我们将一个3x4的矩阵切成了3份,分别为3x1、3x2、3x1的矩阵。
以上三个示例分别演示了切割次数为整数、数组的情况。在实际使用中,需要根据具体需求来选择切割次数。