Tensorflow.js tf.layers.embedding()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
tf.layers.embedding()函数用于将正整数映射为固定大小的密集向量。
句法:
tf.layers.embedding(args)
参数:此函数接受args作为参数,该参数可以具有以下属性:
- inputDim:用于指定词汇量大小。
- outputDim:用于指定密集嵌入的维度。
- embeddingsInitializer:用于指定嵌入矩阵的初始化器。
- embeddingsRegularizer:用于指定将哪个正则化函数应用于嵌入矩阵。
- activityRegularizer:用于指定将哪个正则化函数应用于激活。
- embeddingsConstraint :用于指定将哪个约束函数应用于嵌入矩阵。
- maskZero:用于检查输入值 0 是否为特殊填充值。
- inputLength:用于指定输入序列的长度。
- inputShape:用于创建一个输入层以插入到该层之前。
- batchInputShape:用于创建一个输入层以插入到该层之前。
- batchSize:如果指定了inputShape,没有指定batchInputShape,则用于构造batchInputShape。
- dtype:用于表示该层的数据类型。
- name:用于表示该层的名称。
- trainable:用于表示该层的权重是否可以通过fit更新。
- weights:用于表示层的初始权重值。
- inputDType:仅用于遗留支持,不适用于新代码。
返回值:返回Embedding。
示例 1:
Javascript
// Import library
import * as tf from "@tensorflow/tfjs"
// Create embedding layer
const embeddingLayer = tf.layers.embedding({
inputDim: 10,
outputDim: 3,
inputLength: 2
});
const input = tf.ones([2, 2]);
// Apply embedding to input
const output = embeddingLayer.apply(input);
// Print the output
console.log(output)
Javascript
// Import the library
import * as tf from "@tensorflow/tfjs"
// Create embedding layer
const embeddingLayer = tf.layers.embedding({
inputDim: 100,
outputDim: 4,
inputLength: 3
});
const input = tf.ones([3, 3]);
// Apply embedding to input
const output = embeddingLayer.apply(input);
// Print the output
console.log(output)
输出:
Tensor
[[[0.0179072, 0.0069226, 0.0202718],
[0.0179072, 0.0069226, 0.0202718]],
[[0.0179072, 0.0069226, 0.0202718],
[0.0179072, 0.0069226, 0.0202718]]]
示例 2:
Javascript
// Import the library
import * as tf from "@tensorflow/tfjs"
// Create embedding layer
const embeddingLayer = tf.layers.embedding({
inputDim: 100,
outputDim: 4,
inputLength: 3
});
const input = tf.ones([3, 3]);
// Apply embedding to input
const output = embeddingLayer.apply(input);
// Print the output
console.log(output)
输出:
Tensor
[[[0.0443502, -0.0342815, 0.0228792, 0.0198386],
[0.0443502, -0.0342815, 0.0228792, 0.0198386],
[0.0443502, -0.0342815, 0.0228792, 0.0198386]],
[[0.0443502, -0.0342815, 0.0228792, 0.0198386],
[0.0443502, -0.0342815, 0.0228792, 0.0198386],
[0.0443502, -0.0342815, 0.0228792, 0.0198386]],
[[0.0443502, -0.0342815, 0.0228792, 0.0198386],
[0.0443502, -0.0342815, 0.0228792, 0.0198386],
[0.0443502, -0.0342815, 0.0228792, 0.0198386]]]
参考: https://js.tensorflow.org/api/latest/#layers.embedding