📅  最后修改于: 2023-12-03 15:05:33.083000             🧑  作者: Mango
在Tensorflow.js中,tf.maxPoolWithArgmax()函数用于在max pooling期间记录池化旁边的最大值的位置,并返回一个包含最大值索引的二维张量。这个函数可以方便地用于图像分割和对象检测等各种深度学习领域中。
这个函数可以异步执行,并返回一个Promise,所以我们可以使用async/await来等待结果。
tf.maxPoolWithArgmax(
x: Tensor4D,
filterSize: [number, number] | number,
strides: [number, number] | number,
pad: keyof typeof tf.PadMode = 'valid'
): {result: Tensor4D, indexes: Tensor4D}
Tensor4D
类型,输入张量。维数必须是四,表示形状为 [batch, height, width, channels]
。[number, number] | number
类型,过滤器在高度和宽度上期望的大小。如果是一个数字,则期望的高度和宽度相同。[number, number] | number
类型,期望滑动窗口 在高度和宽度上滑动的跨度。如果是一个数字,则期望高度和宽度上的步幅相同。keyof typeof tf.PadMode
类型,控制如何在输入周围填充边缘。默认是 'valid'
。函数的返回值是一个含有 result
和 indexes
属性的对象。其中,result
是输入张量 x
的最大池化结果张量,indexes
是表示每个元素最大值索引的二维张量。
import * as tf from '@tensorflow/tfjs';
// 构造一个5x5的随机图像
const x = tf.randomUniform([1, 5, 5, 1]);
// 创建一个 2x2的池化过滤器,步长为2
const filterSize = 2;
const strides = 2;
// 计算最大池化结果及其索引
const {result, indexes} = tf.maxPoolWithArgmax(x, filterSize, strides);
// 打印最大化结果及其索引
result.print();
indexes.print();
@tensorflow/tfjs-backend-webgl
包,并调用 tf.setBackend('webgl')
来使用它。tf.unstack()
函数将张量划分为可处理的大小。