📜  Tensorflow.js tf.split()函数(1)

📅  最后修改于: 2023-12-03 15:35:17.987000             🧑  作者: Mango

Tensorflow.js tf.split()函数

介绍

Tensorflow.js的tf.split()函数用于在某个维度上将张量切割成多个子张量。该函数接受三个参数:源张量、切割次数(整数)和切割维度(整数)。返回一个张量数组,每个张量是源张量在指定维度上的一个切片。

语法
tf.split(x, numOrSizeSplits, axis)

参数说明:

  • x: 源张量。
  • numOrSizeSplits: 切割次数,可以是一个整数,表示等分成多少份,也可以是一个数组,表示要切割的大小。
  • axis: 切割维度,取值范围为[0, x.rank)。
示例
const x = tf.tensor2d([[1, 2], [3, 4], [5, 6], [7, 8]]);
const [a, b] = tf.split(x, 2, 0);
a.print();
b.print();

输出:

Tensor
    [[1, 2],
     [3, 4]]
Tensor
    [[5, 6],
     [7, 8]]

本例中,我们将一个2x4的矩阵分成了2份,分别为2x2的矩阵。

当切割次数为整数时,等分成多少份就由该参数决定。

const x = tf.tensor2d([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
const [a, b, c] = tf.split(x, 3, 1);
a.print();
b.print();
c.print();

输出:

Tensor
    [[1],
     [4],
     [7]]
Tensor
    [[2],
     [5],
     [8]]
Tensor
    [[3],
     [6],
     [9]]

在本例中,我们将一个3x3的矩阵分成了3份,分别为3x1的矩阵。

当切割次数为数组时,表示按照数组中的大小切割。

const x = tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]);
const [a, b, c] = tf.split(x, [1, 2, 1], 1);
a.print();
b.print();
c.print();

输出:

Tensor
    [[1],
     [5],
     [9]]
Tensor
    [[2, 3],
     [6, 7],
     [10, 11]]
Tensor
    [[4],
     [8],
     [12]]

在本例中,我们将一个3x4的矩阵切成了3份,分别为3x1、3x2、3x1的矩阵。

以上三个示例分别演示了切割次数为整数、数组的情况。在实际使用中,需要根据具体需求来选择切割次数。