Tensorflow.js tf.multinomial()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.multinomial()函数用于生成 tf.Tensor 以及从多项分布中拖出的输入。
句法:
tf.multinomial(logits, numSamples, seed?, normalized?)
参数:
- logits:它是一个声明的一维数组以及杂乱无章的对数期望,或者是一个具有形状 [batchSize, numOutcomes] 的二维数组,它可以是 tf.Tensor1D、tf.Tensor2D、TypedArray 或 Array 类型。
- numSamples:这是规定的每个行部分要拖动的样本数。它是数字类型。
- 种子:它是规定的种子编号,是类型编号的可选参数。
- 标准化:它检查给定的 logits 是否是有组织的真实预期,即(总和为 1)。默认值为 false 并且是布尔类型的可选参数。
返回值:返回 tf.Tensor1D 或 tf.Tensor2D。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining logits
const logits = tf.tensor([35, 158]);
// Calling tf.multinomial() method and
// Printing output
tf.multinomial(logits, 4).print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.multinomial() method and
// Printing output
tf.multinomial(tf.tensor(
[5.7, 8.7, NaN, 'a', null, 0]), 6).print();
输出:
Tensor
[1, 1, 1, 1]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.multinomial() method and
// Printing output
tf.multinomial(tf.tensor(
[5.7, 8.7, NaN, 'a', null, 0]), 6).print();
输出:
Tensor
[5, 5, 5, 5, 5, 5]
参考: https://js.tensorflow.org/api/latest/#multinomial