📜  Tensorflow.js tf.input()函数

📅  最后修改于: 2022-05-13 01:56:46.608000             🧑  作者: Mango

Tensorflow.js tf.input()函数

深度学习中的模型是连接层的集合,可以训练、评估并用于预测某些事情。要执行此操作,您需要实例化模型的输入。在这篇文章中,我们将了解输入工厂函数的工作原理。

tf.input()函数在使用 tf.model()函数创建模型时使用。

句法:

tf.input(Args) 

参数: Args对象包含以下道具。

  • 形状:它表示预期输入将是 32 维向量的批次。
  • batchShape:表示形状元组,包括批量大小。
  • name:表示图层的名称。
  • dtype:用于表示输入的类型。
  • sparse:一个布尔值表示创建的占位符是稀疏的。

返回:它返回 tf.SymbolicTensor。

示例 1:在此示例中,我们将使用默认参数形状。

Javascript
// Importing the tensorflow.Js library
const tf = require("@tensorflow/tfjs")
 
// Define input
const inp = tf.input({ shape: [64] });
 
// Define op
const op = tf.layers.dense({ units: 8, activation: 'softmax' }).apply(inp);
 
// Create model and pass inp and op
const model = tf.model({ inputs: inp, outputs: op });
 
// Predict something
model.predict(tf.ones([2, 64])).print();


Javascript
// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Define input and pass all parameters
const inp = tf.input({ shape: [16] }, { name: 'abc' },
    { dtype: 'float32' }, { sparse: false });
 
// Define op
const op = tf.layers.dense({ units: 2, activation: 'softmax' }).apply(inp);
 
// Create model and pass inp and op
const model = tf.model({ inputs: inp, outputs: op });
 
// Predict something
model.summary();


输出:

Tensor
   [[0.0285837, 0.1409771, 0.1021329, 0.0912676, 0.2361873, 0.0262359, 
   0.2991393, 0.0754762],
    [0.0285837, 0.1409771, 0.1021329, 0.0912676, 0.2361873, 0.0262359, 
    0.2991393, 0.0754762]]

示例 2:在此示例中,我们将使用所有参数 shape、name、type 和 sparse。

Javascript

// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Define input and pass all parameters
const inp = tf.input({ shape: [16] }, { name: 'abc' },
    { dtype: 'float32' }, { sparse: false });
 
// Define op
const op = tf.layers.dense({ units: 2, activation: 'softmax' }).apply(inp);
 
// Create model and pass inp and op
const model = tf.model({ inputs: inp, outputs: op });
 
// Predict something
model.summary();

输出:

Layer (type)                 Output shape              Param #    
=================================================================
input8 (InputLayer)          [null,16]                 0          
_________________________________________________________________
dense_Dense8 (Dense)         [null,2]                  34        
=================================================================
Total params: 34
Trainable params: 34
Non-trainable params: 0
_________________________________________________________________

参考资料: https://js.tensorflow.org/api/latest/#input