📜  Tensorflow.js tf.truncatedNormal()函数

📅  最后修改于: 2022-05-13 01:56:50.474000             🧑  作者: Mango

Tensorflow.js tf.truncatedNormal()函数

Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.truncatedNormal()函数用于查找 tf.Tensor 以及从截断正态分布评估的值。此外,这里作为输出生成的值遵循正态分布,支持规定的平均值和标准偏差,不包括那些大小大于平均值 2 个标准偏差的值,将被丢弃并再次选择。

句法 :

tf.truncatedNormal(shape, mean?, stdDev?, dtype?, seed?)

参数:

  • shape:它是一个数组,包含描述输出张量形状的整数,类型为 number[]。
  • 均值:它是正态分布的规定均值,是数字类型。
  • stdDev:它是正态分布的规定标准偏差,是数字类型。
  • dtype:它是返回的输出张量的规定数据类型,可以是 float32 或 int32 类型。
  • 种子:它是指定的种子,有助于随机数生成器并且是数字类型。

返回值:它返回 tf.Tensor 对象。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling truncatedNormal() method and
// Printing output
tf.truncatedNormal([3, 4]).print();


Javascript
// Importing the tensorflow.js library 
import * as tf from "@tensorflow/tfjs"
  
// Defining shape
var sh = [3, 2];
var mean = 4;
var st_dev = 5;
var dtyp = 'int32';
  
// Calling truncatedNormal() method
var res = tf.truncatedNormal(sh, mean, st_dev, dtyp);
  
// Printing output
res.print();


输出:

Tensor
    [[-0.0277713, -0.4777073, -0.3911407, 1.85613   ],
     [-0.0667888, -0.0867875, 0.8295102 , -0.5933844],
     [0.5160138 , 0.7871808 , 0.6818511 , 1.2177598 ]]

示例 2:

Javascript

// Importing the tensorflow.js library 
import * as tf from "@tensorflow/tfjs"
  
// Defining shape
var sh = [3, 2];
var mean = 4;
var st_dev = 5;
var dtyp = 'int32';
  
// Calling truncatedNormal() method
var res = tf.truncatedNormal(sh, mean, st_dev, dtyp);
  
// Printing output
res.print();

输出:

Tensor
    [[-1, -5],
     [4 , 4 ],
     [11, 2 ]]

参考: https://js.tensorflow.org/api/latest/#truncatedNormal