Tensorflow.js tf.layers build() Method
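Tensorflow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.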
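The .build() function creates the weights of the layer. Every layer that holds weights must implement this method. It is called automatically the first time apply() is called on the layer, so that the weights can be sized from the shape of the input.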
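Syntax:
build(inputShape)
Parameters:
- inputShape: A Shape or an array of Shapes. The base Layer implementation does not use it; subclasses such as Dense use its last dimension to size their kernel.
Return value: void (nothing is returned).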
Example 1:
Javascript
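// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";
// Creating a sequential model
const model = tf.sequential();
// Adding a dense layer; because inputShape is given, the layer is
// built here and its kernel gets shape [3, 1]
model.add(tf.layers.dense({units: 1, inputShape: [3]}));
// Defining a symbolic input (its shape is [null, 6, 2, 6])
const input = tf.input({shape: [6, 2, 6]});
// Calling build() with an array holding the input's shape; Dense keeps
// its existing kernel, so the kernel shape stays [3, 1]
model.layers[0].build([input.shape]);
// Printing the input shape and the layer's kernel
console.log(JSON.stringify(input.shape));
model.layers[0].getWeights()[0].print();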
Output:
[null,6,2,6]
Tensor
[[-0.3726568],
[0.7343086 ],
[-0.2459907]]
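Here, getWeights()[0] returns the layer's kernel, which print() then displays. The kernel is randomly initialized, so the values differ between runs. Its shape is [3, 1] because the layer was already built with inputShape: [3] when it was added to the model; Dense.build() only creates a kernel if none exists yet, so the [6, 2, 6] shape passed to build() does not change it.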
Example 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";
// Creating a sequential model
const model = tf.sequential();
// Adding layers: the first is built with a [2, 1] kernel, the second
// (fed by the first layer's single output unit) with a [1, 2] kernel
model.add(tf.layers.dense({units: 1, inputShape: [2]}));
model.add(tf.layers.dense({units: 2}));
// Defining symbolic inputs whose shapes are printed below
const input1 = tf.input({shape: [1, 2]});
const input2 = tf.input({shape: [1.7, 2.7, 6.5]});
// Calling build() on each layer with an array holding one input shape;
// both layers are already built, so their existing kernels are kept
model.layers[0].build([input1.shape]);
model.layers[1].build([input2.shape]);
// Printing the input shapes and each layer's kernel
console.log(JSON.stringify(input1.shape));
console.log(JSON.stringify(input2.shape));
model.layers[0].getWeights()[0].print();
model.layers[1].getWeights()[0].print();
Output:
[null,1,2]
[null,1.7,2.7,6.5]
Tensor
[[0.6224715],
[1.2144204]]
Tensor
[[0.8342852, 0.4770206],]
Reference: https://js.tensorflow.org/api/latest/#tf.layers.Layer.build