Tensorflow.js tf.layers.multiply()函数
Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。
tf.layers.multiply()函数用于执行输入数组的元素乘法。
句法:
tf.layers.multiply()
参数:
- inputShape:如果定义了这个参数,它将创建另一个输入层插入到该层之前。
- batchInputShape:如果定义了这个参数,它将创建另一个输入层插入到该层之前。
- batchSize:如果尚未指定,则用于构造 batchInputShape。
- dtype:指定该层的数据类型。此参数的默认值为“float32”。
- name:指定该层的名称。
- 可更新:指定该层的权重是否可以通过拟合更新。
- trainable:指定该层的权重是否可通过拟合更新。
- weights:指定层的初始权重值。
- i nputDType: “float32”或“int32”或“bool”或“complex64”或“字符串”。
返回值:与输入张量相同类型的单个张量。
示例 1:
Javascript
// Import the library
import * as tf from "@tensorflow/tfjs"
const input1 = tf.input({shape: [3, 2]})
const input2 = tf.input({shape: [3, 2]})
const input3 = tf.input({shape: [3, 2]})
// Create a multiply layer
const multiplyLayer = tf.layers.multiply()
// Multiple array of inputs by applying multiplyLayer
const product = multiplyLayer.apply([input1, input2, input3])
// Print the shape of output tensor
console.log(JSON.stringify(product.shape))
Javascript
// Import the library
import * as tf from "@tensorflow/tfjs"
// Inputs
const input1 = tf.tensor([-2, 1, 0, 5]);
const input2 = tf.tensor([3, 2, 3, 2]);
const input3 = tf.tensor([4, 3, 1, 2]);
// Create multiply layer
const multiplyLayer = tf.layers.multiply();
// Multiply inputs
const product = multiplyLayer.apply(
[input1, input2, input3]);
// Print product
console.log(product);
输出:
[null,3,2]
注意:这里的null表示未确定的批量大小。
示例 2:
Javascript
// Import the library
import * as tf from "@tensorflow/tfjs"
// Inputs
const input1 = tf.tensor([-2, 1, 0, 5]);
const input2 = tf.tensor([3, 2, 3, 2]);
const input3 = tf.tensor([4, 3, 1, 2]);
// Create multiply layer
const multiplyLayer = tf.layers.multiply();
// Multiply inputs
const product = multiplyLayer.apply(
[input1, input2, input3]);
// Print product
console.log(product);
输出:
Tensor
[-24, 6, 0, 20]
参考: https://js.tensorflow.org/api/latest/#layers.multiply