Tensorflow.js tf.maxPoolWithArgmax()函数
Tensorflow.js是由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.maxPoolWithArgmax()函数用于确定图像的 2D 最大池化以及 argmax 列表(即索引)。其中,argmax 中的索引是水平的,以便位置 [b, y, x, c] 处的峰值变为压缩索引: (y * width + x) * channels + c in case, include_batch_in_index 为 false 并且如果 include_batch_in_index为真,则为 ((b * height + y) * width + x) * channels +c。此外,在展平之前,返回的索引始终在 [0, height) x [0, width) 中。
句法:
tf.maxPoolWithArgmax(x, filterSize,
strides, pad, includeBatchInIndex?)
参数:
- x:指定的输入张量,其等级为 4 或等级 3,形状为:[batch, height, width, inChannels]。此外,如果等级为 3,则假定批次大小为 1。它可以是 tf.Tensor4D、TypedArray 或 Array 类型。
- filterSize:形状的规定过滤器大小:[filterHeight,filterWidth]。如果过滤器大小是一个奇异数,那么 filterHeight == filterWidth。它可以是 [number, number] 或 number 类型。
- 步幅:形状池的规定步幅:[strideHeight, strideWidth]。如果 strides 是一个单数,那么 strideHeight == strideWidth。它可以是 [number, number] 或 number 类型。
- pad:用于填充的规定类型的算法。它可以是类型 valid、same 或 number。
- 在这里,对于相同和步长 1,输出将具有与输入相同的大小,而与滤波器大小无关。
- 因为,“有效”输出应小于输入,以防过滤器大小大于 1*1×1。
- includeBatchInIndex:它是可选的并且是布尔类型。
返回值:返回 {[name: 字符串]: tf.Tensor}。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining input tensor
const x = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1]);
// Calling maxPoolWithArgmax() method
const result = tf.maxPoolWithArgmax(x, 3, 2, 'same');
// Printing output
console.log(result)
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling maxPoolWithArgmax() method
console.log(tf.maxPoolWithArgmax(
tf.tensor4d([1.1, 2.1, 3.1, 4.1],
[1, 2, 2, 1]), [1, 2], [1, 1],
'valid', true
));
输出:
{
"result": {
"kept": false,
"isDisposedInternal": false,
"shape": [
2,
1,
1,
1
],
"dtype": "float32",
"size": 2,
"strides": [
1,
1,
1
],
"dataId": {
"id": 20
},
"id": 20,
"rankType": "4",
"scopeId": 14
},
"indexes": {
"kept": false,
"isDisposedInternal": false,
"shape": [
2,
1,
1,
1
],
"dtype": "float32",
"size": 2,
"strides": [
1,
1,
1
],
"dataId": {
"id": 21
},
"id": 21,
"rankType": "4",
"scopeId": 14
}
}
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling maxPoolWithArgmax() method
console.log(tf.maxPoolWithArgmax(
tf.tensor4d([1.1, 2.1, 3.1, 4.1],
[1, 2, 2, 1]), [1, 2], [1, 1],
'valid', true
));
输出:
{
"result": {
"kept": false,
"isDisposedInternal": false,
"shape": [
1,
2,
1,
1
],
"dtype": "float32",
"size": 2,
"strides": [
2,
1,
1
],
"dataId": {
"id": 80
},
"id": 80,
"rankType": "4",
"scopeId": 54
},
"indexes": {
"kept": false,
"isDisposedInternal": false,
"shape": [
1,
2,
1,
1
],
"dtype": "float32",
"size": 2,
"strides": [
2,
1,
1
],
"dataId": {
"id": 81
},
"id": 81,
"rankType": "4",
"scopeId": 54
}
}
参考: https://js.tensorflow.org/api/latest/#maxPoolWithArgmax