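TensorFlow.js tf.LayersModel class .getLayer() Method
TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment.
The .getLayer() function retrieves a layer by its name (which must be unique) or by its index. Indices follow the order of a horizontal graph traversal, bottom-up. If both a name and an index are given, the index takes precedence.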
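Syntax:
getLayer(name?, index?)
Parameters:
- name: The name of the layer. It is optional and of type string.
- index: The index of the layer. It is optional and of type number.
Return value: It returns a tf.layers.Layer.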
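The lookup rule described above (index first, then name) can be illustrated with a minimal sketch in plain JavaScript. The function getLayerSketch, the stand-in layer objects, and their names are hypothetical and exist only for illustration; this is not the TensorFlow.js implementation, and the error types and messages are assumptions.

```javascript
// Hypothetical helper: NOT the TensorFlow.js source, only a sketch of the
// documented lookup rule (the index wins over the name when both are given).
function getLayerSketch(layers, name, index) {
  // An index, when provided, takes precedence over the name.
  if (index != null) {
    if (index < 0 || index >= layers.length) {
      throw new RangeError(`No layer at index ${index}`);
    }
    return layers[index];
  }
  if (name == null) {
    throw new Error("Provide either a layer name or a layer index");
  }
  // Names are expected to be unique, so the first match is the only match.
  const found = layers.find((layer) => layer.name === name);
  if (found === undefined) {
    throw new Error(`No such layer: ${name}`);
  }
  return found;
}

// Stand-in objects; real layers would be tf.layers.Layer instances.
const layers = [{ name: "dense_a" }, { name: "dense_b" }, { name: "dense_c" }];

console.log(getLayerSketch(layers, "dense_b").name);    // lookup by name: dense_b
console.log(getLayerSketch(layers, null, 2).name);      // lookup by index: dense_c
console.log(getLayerSketch(layers, "dense_a", 1).name); // index wins: dense_b
```

The last call shows why the first argument in the examples can be null, NaN, or a name that matches no layer: once an index is supplied, the name is never consulted.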
Example 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const model = tf.sequential();
// Adding a layer
model.add(tf.layers.dense({units: 4, inputShape: [1]}));
// Calling getLayer() method with an index only
// (null is passed in place of the name)
const layer_0 = model.getLayer(null, 0);
// Printing weights of the layer_0
// using getWeights() method
layer_0.getWeights()[0].print();
Output (the kernel is randomly initialized, so the values differ between runs):
Tensor
[[-0.0678914, 0.6647689, -0.3708572, -0.1764591],]
Example 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining model
const model = tf.sequential();
// Adding layers
model.add(tf.layers.dense({units: 4, inputShape: [1]}));
model.add(tf.layers.dense({units: 2, inputShape: [3], activation: 'relu6'}));
model.add(tf.layers.dense({units: 3, inputShape: [5], activation: 'sigmoid'}));
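// Calling getLayer() method
// In each call the index decides which layer is returned; the first
// argument (NaN, a name no layer has, undefined) does not affect the result
const layer_0 = model.getLayer(NaN, 0);
const layer_1 = model.getLayer('denselayer', 1);
const layer_2 = model.getLayer(undefined, 2);
// Printing the total number of scalar values in the weights
// of layer_0, layer_1, and layer_2
// using countParams() method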
console.log(layer_0.countParams());
console.log(layer_1.countParams());
console.log(layer_2.countParams());
Output:
8
10
9
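The three counts can be checked by hand. The sketch below assumes each dense layer has a bias term, so a layer holds one weight per (input, unit) pair plus one bias per unit; the helper denseParamCount is hypothetical and not part of TensorFlow.js. The printed values are consistent with each layer taking its input size from the previous layer's units (1, 4, 2) rather than from the inputShape values written on the second and third layers.

```javascript
// Parameter count of a dense layer, assuming a bias term is used:
// one weight per (input, unit) pair plus one bias per unit.
function denseParamCount(inputDim, units) {
  return inputDim * units + units;
}

// Layer sizes from Example 2: 1 -> 4 -> 2 -> 3.
const unitsPerLayer = [4, 2, 3];
let inputDim = 1; // inputShape of the first layer
const counts = [];
for (const units of unitsPerLayer) {
  counts.push(denseParamCount(inputDim, units));
  inputDim = units; // the next layer consumes this layer's output
}
console.log(counts); // [ 8, 10, 9 ]
```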
Reference: https://js.tensorflow.org/api/latest/#tf.LayersModel.getLayer