Tensorflow.js tf.oneHot()函数
Tensorflow.js是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。
tf.oneHot()函数用于创建 one-hot tf.Tensor。由索引表示的位置取值为 1(默认值),也称为 onValue,而所有其他位置取值为 0(默认值),也称为 offValue。
句法:
tf.oneHot (indices, depth, onValue, offValue)
参数:该函数接受三个参数,如下图所示:
- 索引它可以 tf.Tensor(TypedArray 或 Array) 具有 dtype int32 的索引。
- depth 深度的数据类型是数字。它用于表示一个热维度的深度。
- onValue onValue的数据类型是数字。它用于在索引与位置匹配时填写输出。它是一个可选参数。
- offValue offValue的数据类型是数字。用于在索引与位置不匹配时填写输出。它也是一个可选参数。
返回:它返回一个 tf.Tensor。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Use of oneHot function.
var val = tf.oneHot(tf.tensor1d([0,1,2], 'int32'), 3);;
// Printing the tensor
val.print()
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating and initializing a new variable
var val = tf.oneHot(tf.tensor1d([0,1,2], 'int32'), 3,9,-1);
// Printing the tensor
val.print()
输出:
Tensor
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Creating and initializing a new variable
var val = tf.oneHot(tf.tensor1d([0,1,2], 'int32'), 3,9,-1);
// Printing the tensor
val.print()
输出:
Tensor
[[9 , -1, -1],
[-1, 9 , -1],
[-1, -1, 9 ]]
参考: https://js.tensorflow.org/api/latest/#oneHot