📅  Last modified: 2021-01-11 10:44:04             🧑  Author: Mango
CIFAR-10 (Canadian Institute For Advanced Research) and CIFAR-100 are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Geoffrey Hinton, and Vinod Nair. The dataset is divided into five training batches and one test batch, each containing 10,000 images.
The test batch contains exactly 1,000 randomly selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than from another. Between them, the training batches contain exactly 5,000 images from each class.
The classes are completely mutually exclusive: there is no overlap between automobiles and trucks. "Automobile" includes sedans, SUVs, and similar vehicles, while "Truck" includes only big trucks; neither class includes pickup trucks. If we look through the CIFAR dataset, we realize that each class is not just a single kind of bird or cat. The bird and cat classes contain many different types of birds and cats, varying in size, color, magnification, angle, and pose.
With the MNIST dataset, digits such as a one or a two could be written in many different ways, but the data was simply not that diverse, and most importantly, MNIST images are grayscale. The CIFAR dataset, by contrast, consists of 32×32 color images, each with three color channels. Now the most important question is whether the LeNet model, which performed so well on MNIST, will be enough to classify the CIFAR dataset.
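To make the difference concrete, we can quickly inspect the dataset shapes. This sanity check uses tf.keras.datasets, which is separate from the tutorial's own data pipeline shown further below:

# Quick sanity check (not part of the tutorial's pipeline): load CIFAR-10
# through tf.keras.datasets and confirm the image dimensions.
import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

print(train_images.shape)  # (50000, 32, 32, 3) - 32x32 RGB images
print(test_images.shape)   # (10000, 32, 32, 3)
print(train_labels.shape)  # (50000, 1) - integer class ids 0..9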
CIFAR-100 is just like the CIFAR-10 dataset. The only difference is that it has 100 classes containing 600 images each: 100 testing images and 500 training images per class. The 100 classes are grouped into 20 superclasses, and each image comes with a "coarse" label (the superclass it belongs to) and a "fine" label (the class it belongs to).
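Both label granularities are exposed by tf.keras.datasets through the label_mode argument, which gives a quick way to verify the coarse/fine split:

# CIFAR-100 ships with both label granularities; tf.keras.datasets exposes
# them through the label_mode argument of load_data().
import tensorflow as tf

(_, coarse_train), _ = tf.keras.datasets.cifar100.load_data(label_mode="coarse")
(_, fine_train), _ = tf.keras.datasets.cifar100.load_data(label_mode="fine")

print(coarse_train.min(), coarse_train.max())  # 0 19  -> 20 superclasses
print(fine_train.min(), fine_train.max())      # 0 99  -> 100 classes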
The classes in the CIFAR-100 dataset are the following:
S. No | Superclass | Classes |
---|---|---|
1. | Flowers | Orchids, poppies, roses, sunflowers, tulips |
2. | Fish | Aquarium fish, flatfish, ray, shark, trout |
3. | Aquatic mammals | Beaver, dolphin, otter, seal, whale |
4. | Food containers | Bottles, bowls, cans, cups, plates |
5. | Household electrical devices | Clock, lamp, telephone, television, computer keyboard |
6. | Fruit and vegetables | Apples, mushrooms, oranges, pears, sweet peppers |
7. | Household furniture | Table, chair, couch, wardrobe, bed |
8. | Insects | Bee, beetle, butterfly, caterpillar, cockroach |
9. | Large natural outdoor scenes | Cloud, forest, mountain, plain, sea |
10. | Large human-made outdoor things | Bridge, castle, house, road, skyscraper |
11. | Large carnivores | Bear, leopard, lion, tiger, wolf |
12. | Medium-sized mammals | Fox, porcupine, possum, raccoon, skunk |
13. | Large omnivores and herbivores | Camel, cattle, chimpanzee, elephant, kangaroo |
14. | Non-insect invertebrates | Crab, lobster, snail, spider, worm |
15. | Reptiles | Crocodile, dinosaur, lizard, snake, turtle |
16. | Trees | Maple, oak, palm, pine, willow |
17. | People | Girl, man, woman, baby, boy |
18. | Small mammals | Hamster, rabbit, mouse, shrew, squirrel |
19. | Vehicles 1 | Bicycle, bus, motorcycle, pickup truck, train |
20. | Vehicles 2 | Lawn-mower, rocket, streetcar, tractor, tank |
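These names can also be read directly from the dataset itself. The snippet below assumes the python version of CIFAR-100 has been downloaded and extracted to ./cifar-100-python/; its meta file stores both name lists:

# List the superclass and class names from the CIFAR-100 meta file.
import pickle

with open("./cifar-100-python/meta", "rb") as f:
    meta = pickle.load(f, encoding="bytes")

coarse = [n.decode() for n in meta[b"coarse_label_names"]]  # 20 superclasses
fine = [n.decode() for n in meta[b"fine_label_names"]]      # 100 classes

print(len(coarse), "superclasses:", coarse)
print(len(fine), "classes:", fine)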
Now, let us train a network to classify images from the CIFAR-10 dataset using a convolutional neural network built with TensorFlow.
Consider the following flowchart to understand the working of the use case:
(Figure: flowchart of the CIFAR-10 training workflow)
First, install the required packages (the code below uses the TensorFlow 1.x API; pickle is part of the Python standard library and needs no installation):

pip3 install numpy tensorflow
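The training script imports two helpers, get_data_set and model, from an include package that ships with the tutorial's companion code and is not reproduced here. As a rough illustration only, a minimal get_data_set might look like the sketch below, assuming the python-version CIFAR-10 batches have been extracted to ./cifar-10-batches-py/ (the real helper may normalize and reshape differently):

# Hypothetical sketch of the get_data_set helper imported below.
import pickle
import numpy as np

def get_data_set(name="train", path="./cifar-10-batches-py"):
    # CIFAR-10 python version: five training batches and one test batch,
    # each a pickled dict holding 10,000 flattened 32x32x3 images.
    files = ["data_batch_{}".format(i) for i in range(1, 6)] if name == "train" else ["test_batch"]
    xs, ys = [], []
    for fname in files:
        with open("{}/{}".format(path, fname), "rb") as fo:
            batch = pickle.load(fo, encoding="bytes")
        xs.append(batch[b"data"].astype(np.float32) / 255.0)  # shape (10000, 3072)
        ys.append(np.array(batch[b"labels"]))
    x = np.concatenate(xs)
    y = np.eye(10, dtype=np.float32)[np.concatenate(ys)]  # one-hot labels
    return x, y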
import numpy as np
import tensorflow as tf
from time import time
import math
from include.data import get_data_set
from include.model import model, lr

train_x, train_y = get_data_set("train")
test_x, test_y = get_data_set("test")
tf.set_random_seed(21)
x, y, output, y_pred_cls, global_step, learning_rate = model()
global_accuracy = 0
epoch_start = 0

# PARAMS
_BATCH_SIZE = 128
_EPOCH = 60
_SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/"

# LOSS AND OPTIMIZER
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=output, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                   beta1=0.9,
                                   beta2=0.999,
                                   epsilon=1e-08).minimize(loss, global_step=global_step)

# PREDICTION AND ACCURACY CALCULATION
correct_prediction = tf.equal(y_pred_cls, tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# SAVER
merged = tf.summary.merge_all()
saver = tf.train.Saver()
sess = tf.Session()
train_writer = tf.summary.FileWriter(_SAVE_PATH, sess.graph)

try:
    print("\nTrying to restore last checkpoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
    saver.restore(sess, save_path=last_chk_path)
    print("Restored checkpoint from:", last_chk_path)
except ValueError:
    print("\nFailed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())


def train(epoch):
    global epoch_start
    epoch_start = time()
    # Number of batches needed to cover the whole training set.
    num_batches = int(math.ceil(len(train_x) / _BATCH_SIZE))
    i_global = 0

    for s in range(num_batches):
        batch_xs = train_x[s * _BATCH_SIZE: (s + 1) * _BATCH_SIZE]
        batch_ys = train_y[s * _BATCH_SIZE: (s + 1) * _BATCH_SIZE]

        start_time = time()
        i_global, _, batch_loss, batch_acc = sess.run(
            [global_step, optimizer, loss, accuracy],
            feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)})
        duration = time() - start_time

        if s % 10 == 0:
            # Render a simple progress bar for the current epoch.
            percentage = int(round((s / num_batches) * 100))
            bar_len = 29
            filled_len = int((bar_len * int(percentage)) / 100)
            bar = '=' * filled_len + '>' + '-' * (bar_len - filled_len)

            msg = "Global step: {:>5} - [{}] {:>3}% - acc: {:.4f} - loss: {:.4f} - {:.1f} sample/sec"
            print(msg.format(i_global, bar, percentage, batch_acc, batch_loss, _BATCH_SIZE / duration))

    test_and_save(i_global, epoch)


def test_and_save(_global_step, epoch):
    global global_accuracy
    global epoch_start

    # Run the whole test set through the network, batch by batch.
    i = 0
    predicted_class = np.zeros(shape=len(test_x), dtype=int)
    while i < len(test_x):
        j = min(i + _BATCH_SIZE, len(test_x))
        batch_xs = test_x[i:j, :]
        batch_ys = test_y[i:j, :]
        predicted_class[i:j] = sess.run(
            y_pred_cls,
            feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)})
        i = j

    correct = (np.argmax(test_y, axis=1) == predicted_class)
    acc = correct.mean() * 100
    correct_numbers = correct.sum()

    hours, rem = divmod(time() - epoch_start, 3600)
    minutes, seconds = divmod(rem, 60)
    mes = "\nEpoch {} - accuracy: {:.2f}% ({}/{}) - time: {:0>2}:{:0>2}:{:05.2f}"
    print(mes.format((epoch + 1), acc, correct_numbers, len(test_x), int(hours), int(minutes), seconds))

    # Save a checkpoint whenever the test accuracy improves.
    if global_accuracy != 0 and global_accuracy < acc:
        summary = tf.Summary(value=[
            tf.Summary.Value(tag="Accuracy/test", simple_value=acc),
        ])
        train_writer.add_summary(summary, _global_step)
        saver.save(sess, save_path=_SAVE_PATH, global_step=_global_step)

        mes = "This epoch received better accuracy: {:.2f} > {:.2f}. Saving session..."
        print(mes.format(acc, global_accuracy))
        global_accuracy = acc
    elif global_accuracy == 0:
        global_accuracy = acc

    print("###########################################################################################################")


def main():
    train_start = time()

    for i in range(_EPOCH):
        print("\nEpoch: {}/{}\n".format((i + 1), _EPOCH))
        train(i)

    hours, rem = divmod(time() - train_start, 3600)
    minutes, seconds = divmod(rem, 60)
    mes = "Best accuracy per session: {:.2f}, time: {:0>2}:{:0>2}:{:05.2f}"
    print(mes.format(global_accuracy, int(hours), int(minutes), seconds))


if __name__ == "__main__":
    main()
    sess.close()
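The model and lr helpers come from include/model.py in the companion code: model() builds the CNN graph and returns its placeholders and outputs, while lr(epoch) maps the epoch number to a learning-rate value. The exact schedule is defined in that file; a plausible minimal stand-in would be a step decay such as:

# Hypothetical stand-in for include.model.lr - a step-decay schedule.
# The companion code's actual schedule may differ.
def lr(epoch):
    if epoch < 20:
        return 1e-3   # start relatively high while the loss falls quickly
    elif epoch < 40:
        return 1e-4   # decay once progress slows
    return 1e-5       # fine-tuning phase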
The output of the training run:
Epoch: 60/60
Global step: 23070 - [>-----------------------------] 0% - acc: 0.9531 - loss: 1.5081 - 7045.4 sample/sec
Global step: 23080 - [>-----------------------------] 3% - acc: 0.9453 - loss: 1.5159 - 7147.6 sample/sec
Global step: 23090 - [=>----------------------------] 5% - acc: 0.9844 - loss: 1.4764 - 7154.6 sample/sec
Global step: 23100 - [==>---------------------------] 8% - acc: 0.9297 - loss: 1.5307 - 7104.4 sample/sec
Global step: 23110 - [==>---------------------------] 10% - acc: 0.9141 - loss: 1.5462 - 7091.4 sample/sec
Global step: 23120 - [===>--------------------------] 13% - acc: 0.9297 - loss: 1.5314 - 7162.9 sample/sec
Global step: 23130 - [====>-------------------------] 15% - acc: 0.9297 - loss: 1.5307 - 7174.8 sample/sec
Global step: 23140 - [=====>------------------------] 18% - acc: 0.9375 - loss: 1.5231 - 7140.0 sample/sec
Global step: 23150 - [=====>------------------------] 20% - acc: 0.9297 - loss: 1.5301 - 7152.8 sample/sec
Global step: 23160 - [======>-----------------------] 23% - acc: 0.9531 - loss: 1.5080 - 7112.3 sample/sec
Global step: 23170 - [=======>----------------------] 26% - acc: 0.9609 - loss: 1.5000 - 7154.0 sample/sec
Global step: 23180 - [========>---------------------] 28% - acc: 0.9531 - loss: 1.5074 - 6862.2 sample/sec
Global step: 23190 - [========>---------------------] 31% - acc: 0.9609 - loss: 1.4993 - 7134.5 sample/sec
Global step: 23200 - [=========>--------------------] 33% - acc: 0.9609 - loss: 1.4995 - 7166.0 sample/sec
Global step: 23210 - [==========>-------------------] 36% - acc: 0.9375 - loss: 1.5231 - 7116.7 sample/sec
Global step: 23220 - [===========>------------------] 38% - acc: 0.9453 - loss: 1.5153 - 7134.1 sample/sec
Global step: 23230 - [===========>------------------] 41% - acc: 0.9375 - loss: 1.5233 - 7074.5 sample/sec
Global step: 23240 - [============>-----------------] 43% - acc: 0.9219 - loss: 1.5387 - 7176.9 sample/sec
Global step: 23250 - [=============>----------------] 46% - acc: 0.8828 - loss: 1.5769 - 7144.1 sample/sec
Global step: 23260 - [==============>---------------] 49% - acc: 0.9219 - loss: 1.5383 - 7059.7 sample/sec
Global step: 23270 - [==============>---------------] 51% - acc: 0.8984 - loss: 1.5618 - 6638.6 sample/sec
Global step: 23280 - [===============>--------------] 54% - acc: 0.9453 - loss: 1.5151 - 7035.7 sample/sec
Global step: 23290 - [================>-------------] 56% - acc: 0.9609 - loss: 1.4996 - 7129.0 sample/sec
Global step: 23300 - [=================>------------] 59% - acc: 0.9609 - loss: 1.4997 - 7075.4 sample/sec
Global step: 23310 - [=================>------------] 61% - acc: 0.8750 - loss: 1.5842 - 7117.8 sample/sec
Global step: 23320 - [==================>-----------] 64% - acc: 0.9141 - loss: 1.5463 - 7157.2 sample/sec
Global step: 23330 - [===================>----------] 66% - acc: 0.9062 - loss: 1.5549 - 7169.3 sample/sec
Global step: 23340 - [====================>---------] 69% - acc: 0.9219 - loss: 1.5389 - 7164.4 sample/sec
Global step: 23350 - [====================>---------] 72% - acc: 0.9609 - loss: 1.5002 - 7135.4 sample/sec
Global step: 23360 - [=====================>--------] 74% - acc: 0.9766 - loss: 1.4842 - 7124.2 sample/sec
Global step: 23370 - [======================>-------] 77% - acc: 0.9375 - loss: 1.5231 - 7168.5 sample/sec
Global step: 23380 - [======================>-------] 79% - acc: 0.8906 - loss: 1.5695 - 7175.2 sample/sec
Global step: 23390 - [=======================>------] 82% - acc: 0.9375 - loss: 1.5225 - 7132.1 sample/sec
Global step: 23400 - [========================>-----] 84% - acc: 0.9844 - loss: 1.4768 - 7100.1 sample/sec
Global step: 23410 - [=========================>----] 87% - acc: 0.9766 - loss: 1.4840 - 7172.0 sample/sec
Global step: 23420 - [==========================>---] 90% - acc: 0.9062 - loss: 1.5542 - 7122.1 sample/sec
Global step: 23430 - [==========================>---] 92% - acc: 0.9297 - loss: 1.5313 - 7145.3 sample/sec
Global step: 23440 - [===========================>--] 95% - acc: 0.9297 - loss: 1.5301 - 7133.3 sample/sec
Global step: 23450 - [============================>-] 97% - acc: 0.9375 - loss: 1.5231 - 7135.7 sample/sec
Global step: 23460 - [=============================>] 100% - acc: 0.9250 - loss: 1.5362 - 10297.5 sample/sec
Epoch 60 - accuracy: 78.81% (7881/10000)
This epoch received better accuracy: 78.81 > 78.78. Saving session...
##################################################################################################
Once training is finished, we can run the trained network on the test dataset with a separate script:

import numpy as np
import tensorflow as tf

from include.data import get_data_set
from include.model import model

test_x, test_y = get_data_set("test")
x, y, output, y_pred_cls, global_step, learning_rate = model()

_BATCH_SIZE = 128
_CLASS_SIZE = 10
_SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/"

saver = tf.train.Saver()
sess = tf.Session()

try:
    print("\nTrying to restore last checkpoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
    saver.restore(sess, save_path=last_chk_path)
    print("Restored checkpoint from:", last_chk_path)
except ValueError:
    print("\nFailed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())


def main():
    # Predict the class of every test image, batch by batch.
    i = 0
    predicted_class = np.zeros(shape=len(test_x), dtype=int)
    while i < len(test_x):
        j = min(i + _BATCH_SIZE, len(test_x))
        batch_xs = test_x[i:j, :]
        batch_ys = test_y[i:j, :]
        predicted_class[i:j] = sess.run(y_pred_cls, feed_dict={x: batch_xs, y: batch_ys})
        i = j

    correct = (np.argmax(test_y, axis=1) == predicted_class)
    acc = correct.mean() * 100
    correct_numbers = correct.sum()
    print()
    print("Accuracy on Test-Set: {0:.2f}% ({1} / {2})".format(acc, correct_numbers, len(test_x)))


if __name__ == "__main__":
    main()
    sess.close()
Output:
Trying to restore last checkpoint ...
Restored checkpoint from: ./tensorboard/cifar-10-v1.0.0/-23460
Accuracy on Test-Set: 78.81% (7881 / 10000)
Here, we can see how much time 60 epochs take on different hardware:
Device | Batch | Time | Accuracy [%] |
---|---|---|---|
NVidia | 128 | 8m4s | 79.12 |
Intel i7-7700HQ | 128 | 3h30m | 78.91 |