📜  Tensorflow.js tf.oneHot()函数

📅  最后修改于: 2022-05-13 01:56:50.205000             🧑  作者: Mango

Tensorflow.js tf.oneHot()函数

Tensorflow.js是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

tf.oneHot()函数用于创建 one-hot tf.Tensor。由索引表示的位置取值为 1(默认值),也称为 onValue,而所有其他位置取值为 0(默认值),也称为 offValue。

句法:

tf.oneHot (indices, depth, onValue, offValue)

参数:该函数接受三个参数,如下图所示:

  • 索引它可以 tf.Tensor(TypedArray 或 Array) 具有 dtype int32 的索引。
  • depth 深度的数据类型是数字。它用于表示一个热维度的深度。
  • onValue onValue的数据类型是数字。它用于在索引与位置匹配时填写输出。它是一个可选参数。
  • offValue offValue的数据类型是数字。用于在索引与位置不匹配时填写输出。它也是一个可选参数。

返回:它返回一个 tf.Tensor。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
 
// Use of oneHot function.
var val = tf.oneHot(tf.tensor1d([0,1,2], 'int32'), 3);;
 
// Printing the tensor
val.print()


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
 
 
// Creating and initializing a new variable
var val = tf.oneHot(tf.tensor1d([0,1,2], 'int32'), 3,9,-1);
 
// Printing the tensor
val.print()


输出:

Tensor
    [[1, 0, 0],
    [0, 1, 0],
    [0, 0, 1]]

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
 
 
// Creating and initializing a new variable
var val = tf.oneHot(tf.tensor1d([0,1,2], 'int32'), 3,9,-1);
 
// Printing the tensor
val.print()

输出:

Tensor
    [[9 , -1, -1],
    [-1, 9 , -1],
    [-1, -1, 9 ]]

参考: https://js.tensorflow.org/api/latest/#oneHot