📜  Tensorflow.js tf.stridedSlice()函数

📅  最后修改于: 2022-05-13 01:56:46.210000             🧑  作者: Mango

Tensorflow.js tf.stridedSlice()函数

Tensorflow.js 是由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.stridedSlice()函数用于拉出指定输入张量的跨步部分。

注意: < 此函数用于从指定的输入张量中提取指定长度步幅的切片。从 begin 给定的位置开始,切片通过在所有测量值大于end之前将步幅附加到指定索引来继续。此外,步幅也可以是负数,这会导致反向切片。

句法:

tf.stridedSlice(x, begin, end, strides?, beginMask?, endMask?, 
ellipsisMask?, newAxisMask?, shrinkAxisMask?)

参数:

  • x:为了跨越切片而声明的张量。它可以是 tf.Tensor、TypedArray 或 Array 类型。
  • begin:切片开始的指定坐标。它的类型为 number[]。
  • end:切片结束的指定坐标。它的类型为 number[]。
  • strides:切片的规定长度。它是可选的,类型为 number[]。
  • beginMask:它是一个可选参数,类型为 number。如果 beginMask 的第 i 位是固定的,则忽略 begin[i],反之利用该维度的最大可达到范围。
  • endMask:它是一个可选参数,类型为 number。如果 endMask 的第 i 位是固定的,则忽略 end[i],反之利用该维度的最大可达范围。
  • ellipsisMask:它是数字类型的可选参数。
  • newAxisMask:它是一个数字类型的可选参数。
  • shrinkAxisMask:指定的位掩码,其中位i指示第 i 个指定必须缩小容量。开始结束应该指示长度为 1 的切片。它是可选的,类型为 number。

返回值:返回 tf.Tensor。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input
const tn = tf.tensor3d([5, 5, 5 ,7, 7, 7, 13, 
    13, 13, 14, 14, 14, 1, 1, 1, 2, 2, 2],
    [3, 3, 2]);
  
// Calling stridedSlice() method and 
// Printing output
tn.stridedSlice([5, 0, 5], [7, 5, 7], [2, 2, 2]).print();


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input
const tn = tf.tensor3d([5, 5, 5 ,7, 7, 7, 13, 
    13, 13, 14, 14, 14, 1, 1, 1, 2, 2, 2],
    [3, 3, 2]);
  
// Defining all the parameters
const x = [-5, 5, 7];
const begin = [-7, 13, 14];
const end = [-7, 13, 13];
const strides = [1];
const beginMask = 2;
const endMask = 0;
const ellipsisMask = 2;
const shrinkAxisMask = 15;
  
// Calling stridedSlice() method and 
// Printing output
tn.stridedSlice(x, begin, end, strides, beginMask,
     endMask, ellipsisMask, shrinkAxisMask).print();


输出:

Tensor
    [[[1],
      [2]]]

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input
const tn = tf.tensor3d([5, 5, 5 ,7, 7, 7, 13, 
    13, 13, 14, 14, 14, 1, 1, 1, 2, 2, 2],
    [3, 3, 2]);
  
// Defining all the parameters
const x = [-5, 5, 7];
const begin = [-7, 13, 14];
const end = [-7, 13, 13];
const strides = [1];
const beginMask = 2;
const endMask = 0;
const ellipsisMask = 2;
const shrinkAxisMask = 15;
  
// Calling stridedSlice() method and 
// Printing output
tn.stridedSlice(x, begin, end, strides, beginMask,
     endMask, ellipsisMask, shrinkAxisMask).print();

输出:

Tensor
    2

参考: https://js.tensorflow.org/api/latest/#stridedSlice