📜  Tensorflow.js tf.layers setWeights() 方法

📅  最后修改于: 2022-05-13 01:56:21.133000             🧑  作者: Mango

Tensorflow.js tf.layers setWeights() 方法

Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.setWeights() 方法用于根据给定的张量设置所述层的权重。

句法:

setWeights(weights)

参数:

  • 权重:它是输入张量的规定列表。它是 tf.Tensor[] 类型。其中,数组的数量及其形状应等于所用层的规定权重的尺寸数量。换句话说,它必须等于getWeights()方法的结果。

返回值:返回void。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding a layer
model.add(tf.layers.dense({units: 2, inputShape: [11]}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]);
  
// Compiling the model
model.compile({loss: 'categoricalCrossentropy', optimizer: 'sgd'});
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units: 1, 
    inputShape: [5], batchSize: 1, dtype: 'int32'}));
model.add(tf.layers.dense({units: 2, inputShape: [6], batchSize: 5}));
model.add(tf.layers.dense({units: 3, inputShape: [7], batchSize: 8}));
model.add(tf.layers.dense({units: 4, inputShape: [8], batchSize: 12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();


输出:

Tensor
    [[-0.5969906, -0.1883931],
     [0.8569255 , -0.49416  ],
     [0.1157023 , 0.1150239 ],
     [-0.4052143, 1.9936075 ],
     [0.3090054 , 0.7212474 ],
     [0.4626641 , -0.7287846],
     [0.4352857 , -0.5195332],
     [0.4626429 , 0.0216295 ],
     [-0.1110666, -0.5997615],
     [-0.5083916, -0.3582681],
     [-0.2847465, 1.184485  ]]

这里, truncatedNormal()方法用于创建 tf.Tensor 以及从截断正态分布中采样的值, zeros()方法用于创建 tf.Tensor 以及设置为 0 和getWeights的所有元素()方法用于打印使用setWeights()方法设置的权重。

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units: 1, 
    inputShape: [5], batchSize: 1, dtype: 'int32'}));
model.add(tf.layers.dense({units: 2, inputShape: [6], batchSize: 5}));
model.add(tf.layers.dense({units: 3, inputShape: [7], batchSize: 8}));
model.add(tf.layers.dense({units: 4, inputShape: [8], batchSize: 12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();

输出:

Tensor
    [[1],
     [1],
     [1],
     [1],
     [1]]
Tensor
    [0]
Tensor
     [[1, 1],]
Tensor
    [0, 0]

在这里, ones()方法用于创建一个 tf.Tensor 以及所有设置为 1 的元素。

参考: https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights