Tensorflow.js tf.dilation2d()函数
Tensorflow.js 是由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.dilation2d()函数用于评估指定输入张量上的灰度膨胀。
句法:
tf.dilation2d(x, filter, strides, pad, dilations?, dataFormat?)
参数:
- x:指定的输入张量,其等级为 3 或等级 4,形状为:[batch, height, width, inChannels]。此外,如果等级为 3,则假定批次大小为 1。它可以是 tf.Tensor3D、tf.Tensor4D、TypedArray 或 Array 类型。
- filter:指定的 3 阶滤波器张量和形状:[filterHeight, filterWidth, depth]。它可以是 tf.Tensor3D、TypedArray 或 Array 类型。
- strides :给定输入张量的每个大小的滑动窗口的规定步幅:[strideHeight,strideWidth]。如果规定的步幅是一个数字,那么 strideHeight == strideWidth。它可以是 [number, number] 或 number 类型。
- pad:用于填充的规定类型的算法。它可以是类型 valid 或相同。
- 在这里,对于相同的步幅和步长 1,输出将具有与输入相同的大小,而与滤波器大小无关。
- 因为,“有效”输出应小于输入,以防过滤器大小大于 1*1×1。
- dilations:规定的扩张率:[dilationHeight, dilationWidth],因为输入值是在高度和宽度维度上采样的,有利于atrous morphological dilation。默认值为 [1, 1]。此外,如果 dilations 是单个数字,则 dilationHeight == dilationWidth。如果它大于 1,那么所有的步幅值都应该是 1。它是可选的并且是 [number, number], number 类型。
- dataFormat:指定输入和输出数据的数据格式。默认值为“NHWC”。而且,这里的数据存储顺序是:[batch, height, width, channels]。它是可选的,属于“NHWC”类型。
返回值:返回 tf.Tensor3D 或 tf.Tensor4D。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining input tensor
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
// Defining filter tensor
const y = tf.tensor3d([1, 1, 0, 4], [1, 1, 4]);
// Calling dilation2d() method
const result = tf.dilation2d(x, y, 2, 'valid');
// Printing output
result.print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling dilation2d() method with
// all its parameters
tf.tensor3d([1.1, 2.2, 3.3, 4.4], [2, 2, 1]).dilation2d(
tf.tensor3d([1.3, 1.2, null, -4], [1, 1, 4]),
2, 'valid', [3, 2], 'NHWC').print();
输出:
Tensor
[ [[2],]]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling dilation2d() method with
// all its parameters
tf.tensor3d([1.1, 2.2, 3.3, 4.4], [2, 2, 1]).dilation2d(
tf.tensor3d([1.3, 1.2, null, -4], [1, 1, 4]),
2, 'valid', [3, 2], 'NHWC').print();
输出:
Tensor
[ [[2.4000001],]]
参考: https://js.tensorflow.org/api/latest/#dilation2d