Tensorflow.js tf.conv1d()函数
简介: Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.conv1d()函数用于根据所述输入张量确定一维卷积。
句法:
tf.conv1d(x, filter, stride, pad, dataFormat?, dilation?, dimRoundingMode?)
参数:
- x:指定的输入张量,其等级为 3 或等级 2,形状为:[batch, height, width, inChannels]。此外,如果等级为 2,则假定批次大小为 1。它可以是 tf.Tensor2D、tf.Tensor3D、TypedArray 或 Array 类型。
- filter:指定的 3 阶滤波器张量和形状:[filterHeight, filterWidth, depth]。它可以是 tf.Tensor3D、TypedArray 或 Array 类型。
- strides:规定的进气次数,在每一步的帮助下,规定的过滤器向右移动。它是数字类型。
- pad:用于填充的规定类型的算法。它的类型可以是 valid、same、number 或 conv_util.ExplicitPadding。
- 在这里,对于相同的步幅和步长 1,输出将具有与输入相同的大小,而与滤波器大小无关。
- 因为,“有效”输出应小于输入,以防过滤器大小大于 1*1×1。
- dataFormat:从“NWC”或“NCW”中指定的可选字符串。默认值为“NWC”,信息按 [batch, in_width, in_channels] 的顺序保存。而且,目前偏爱的只有“NWC”。它是可选的,类型为“NWC”或“NCW”。
- dilations:规定的扩张率,其中输入值在 atrous 卷积中进行采样。默认值为 1。如果大于 1,则步幅应为 1。它是可选的,并且是 number 类型。
- dimRoundingMode: “ceil”、“round”或“floor”中的指定字符串。如果未提供任何值,则默认值为 truncate。它是可选的,类型为天花板、圆形或地板。
返回值:返回 tf.Tensor2D 或 tf.Tensor3D。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining input tensor
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
// Defining filter tensor
const y = tf.tensor3d([1, 1, 0, 4], [1, 1, 4]);
// Calling conv1d() method
const result = tf.conv1d(x, y, 2, 'valid');
// Printing output
result.print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling conv1d() method with
// all its parameters
tf.tensor3d([1.1, 2.2, 3.3, 4.4], [2, 2, 1]).conv1d(
tf.tensor3d([1.3, 1.2, null, -4], [1, 1, 4]),
2, 0, 'NWC', 1, 'ceil').print();
输出:
Tensor
[ [[1, 1, 0, 4 ],],
[[3, 3, 0, 12],]]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling conv1d() method with
// all its parameters
tf.tensor3d([1.1, 2.2, 3.3, 4.4], [2, 2, 1]).conv1d(
tf.tensor3d([1.3, 1.2, null, -4], [1, 1, 4]),
2, 0, 'NWC', 1, 'ceil').print();
输出:
Tensor
[[[1.4299999, 1.3200001, 0, -4.4000001 ],
[0 , 0 , 0, 0 ]],
[[4.29 , 3.96 , 0, -13.1999998],
[0 , 0 , 0, 0 ]]]
参考: https://js.tensorflow.org/api/latest/#conv1d