📜  Tensorflow.js tf.truncatedNormal()函数(1)

📅  最后修改于: 2023-12-03 15:05:33.516000             🧑  作者: Mango

TensorFlow.js中的tf.truncatedNormal()函数

在TensorFlow.js中,tf.truncatedNormal()函数用于创建具有截断正态分布的张量。截断正态分布是指具有正态分布形态,但将具有均值加减两倍标准差之外的值截断为标准差之内的值。

tf.truncatedNormal()函数有以下参数:

  • shape:张量的形状,如[2, 4]表示一个2行4列的张量;
  • mean:正态分布的均值,默认值为0;
  • stddev:正态分布的标准差,默认值为1;
  • dtype:数据类型,默认值为float32;
  • seed:随机数生成器的种子值。

下面是一个使用tf.truncatedNormal()函数创建张量的示例:

const shape = [2, 3];
const mean = 0;
const stddev = 1;
const dtype = "float32";
const seed = 12345;

const tensor = tf.truncatedNormal(shape, mean, stddev, dtype, seed);
console.log(tensor.toString());

输出的结果如下:

Tensor
    [[-0.2284743, -0.26878512,  0.9872056 ],
     [-0.6348249, -1.2398715,   0.69449687]]

注意:tf.truncatedNormal()函数中的mean和stddev参数只影响正态分布的形态,不影响截断操作的结果。如果需要自定义截断范围,可以使用tf.clipByValue()函数。

使用tf.truncatedNormal()函数创建的张量可以用作模型的初始权重,也可以用于生成数据集。显然,由于其截断性质,与纯正态分布相比,创建具有截断正态分布的张量,往往更符合现实世界的数据分布。