📜  tensorflow - Python code example

📅  Last modified: 2022-03-11 14:45:33.915000     🧑  Author: Mango

Code example 2
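import tensorflow as tf

# MNIST: 28x28 grayscale images of handwritten digits, labels 0-9
mnist = tf.keras.datasets.mnist

(xTrain, yTrain), (xTest, yTest) = mnist.load_data()

# Scale pixel values from the range 0-255 down to 0.0-1.0
xTrain, xTest = xTrain / 255.0, xTest / 255.0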

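# Flatten each image to 784 values, one hidden layer, then 10 raw class scores
network = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  # No activation: the loss below is built with from_logits=True and expects raw logits
  tf.keras.layers.Dense(10)
])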

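# Raw scores (logits) for the first training image from the untrained network
predictions = network(xTrain[:1]).numpy()

# Cross-entropy for integer labels; softmax is applied to the logits internally
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initial loss; for an untrained model it should be close to -ln(1/10) ≈ 2.3
print(loss(yTrain[:1], predictions).numpy())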

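# Keyword arguments avoid positional-order differences between Keras versions
network.compile(optimizer="adam", loss=loss, metrics=["accuracy"])

# Train for five epochs over the training set
network.fit(xTrain, yTrain, epochs=5)

# Report loss and accuracy on the held-out test set
network.evaluate(xTest, yTest, verbose=2)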
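For reference, here is a minimal pure-Python sketch of the per-example quantity that `SparseCategoricalCrossentropy(from_logits=True)` computes. Softmax is applied to the raw scores internally, and the loss is the negative log-probability of the true label, which is why the final `Dense` layer should output unnormalized logits rather than sigmoid values when `from_logits=True`. The helpers `sparse_ce_from_logits` and `mean_sparse_ce` are hypothetical names for illustration, not part of the Keras API; the batch mean mirrors Keras's default reduction.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example from raw scores and an integer label.

    Hypothetical helper: reproduces the math of
    SparseCategoricalCrossentropy(from_logits=True) for a single sample.
    """
    # log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label]) == logsumexp(logits) - logits[label]
    return log_sum_exp - logits[label]


def mean_sparse_ce(batch_logits, labels):
    """Hypothetical helper: average per-example losses over a batch."""
    losses = [sparse_ce_from_logits(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Equal logits mean a uniform 1/10 guess, so the loss equals ln(10)
print(sparse_ce_from_logits([0.0] * 10, 3))  # ≈ 2.302585
```

Near-equal logits from an untrained network therefore give a loss close to 2.3, while a large logit on the correct class drives the loss toward zero.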