📜  识别 BTS 的成员——一个图像分类器

📅  最后修改于: 2022-05-13 01:55:01.405000             🧑  作者: Mango

识别 BTS 的成员——一个图像分类器

BTS 是一支由 7 名成员组成的著名 K-Pop 乐队。本文着眼于一个图像分类器,它可以从图片中识别乐队成员的姓名。图像分类器将使用 fastai 构建。它是一个深度学习库,旨在使深度学习民主化。它建立在 PyTorch 之上,并拥有大量可立即使用的优化权重的模型。该应用程序将托管在 Binder 上,最终产品将如下所示:

准备数据集

与任何图像分类器的情况一样,模型需要在数据集上进行训练,从中可以推断和提取与特定类别对应的特征。 BTS 图像分类器将包含 7 个类别(成员总数)。可以通过手动收集不同成员的图像,然后将它们放在该类别的文件夹中来准备数据集。为了加快这个过程,可以使用Python脚本来创建数据集。该脚本将从 Google 图片搜索中获取图片。 (免责声明:使用这些图像可能会导致侵犯版权,因此请自行承担风险)。

一个名为simple_images的文件夹将出现在脚本所在的位置。在simple_images文件夹中,将存在与 7 个成员中的每个成员对应的文件夹,其中包含 150 张图像。

是时候对分类器进行编码了。建议使用 Google Collab(训练时 GPU 会派上用场)并将数据集上传到 Google Drive。

Python3
# Import fastbook
from fastbook import *
from fastai.vision.widgets import *
from google.colab import drive 
drive.mount('/content/drive')
  
import fastbook
fastbook.setup_book()
  
  
  
class DataLoaders(GetAttr):
    def __init__(self, *loaders): self.loaders = loaders
  
    def __getitem__(self, i): return self.loaders[i]
    train, valid = add_props(lambda i, self: self[i])


Python3
# Import the required function to download from the Simple Image Download library.
from simple_image_download import simple_image_download as simp
# Create a response instance.
response = simp.simple_image_download
# The following lines would look up Google Images and download the number of images specified.
# The first argument is the term to search, and the second argument is the number of images to be downloaded.
response.download('BTS Jin', 150)
response.download('BTS Jimin', 150)
response.download('BTS RM', 150)
response.download('BTS J-Hope', 150)
response.download('BTS Suga', 150)
response.download('BTS Jungkook', 150)


Python3
bts = bts.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())
dls = bts.dataloaders(path)


Python3
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(8)


Python3
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()


Python3
interp.plot_top_losses(5, nrows=5)


Python3
learn.export()
path = Path()
path.ls(file_exts='.pkl')


DataLoaders是一个负责为模型提供有效和训练数据集的类。



蟒蛇3

# Import the required function to download from the Simple Image Download library.
from simple_image_download import simple_image_download as simp
# Create a response instance.
response = simp.simple_image_download
# The following lines would look up Google Images and download the number of images specified.
# The first argument is the term to search, and the second argument is the number of images to be downloaded.
response.download('BTS Jin', 150)
response.download('BTS Jimin', 150)
response.download('BTS RM', 150)
response.download('BTS J-Hope', 150)
response.download('BTS Suga', 150)
response.download('BTS Jungkook', 150)

清理数据

已下载的图像可能尺寸不同。最好具有统一维度的数据集中的所有图像。 fastai库有一个函数:

蟒蛇3

bts = bts.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())
dls = bts.dataloaders(path)

所有图像都调整为 224 x 224,这是训练数据集中图像的标准尺寸。

创建模型

现在是创建Learner 的时候了。学习者是将从提供的数据集中学习的模型。当提供不属于训练集的图像时,它将能够预测输出(自变量)。此处使用的学习器称为“Resnet18”。它已经过预训练,这意味着权重被调整,这样模型应该能够在没有进一步调整的情况下进行合理的预测。这个想法被称为迁移学习

蟒蛇3

learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(8)

Fine_tune(8)表示学习进行了 8 个epoch 。这个号码可以玩。准确性和计算能力/时间之间的权衡将是需要考虑的。

现在模型已经训练好了,结果可以通过查看混淆矩阵来可视化。

蟒蛇3

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

 


理想情况下,只有混淆矩阵的对角元素应该是非零的。可以看出,在模型的预测中,存在一些错误分类。



可以看到损失最高的图像。这些通常是模型非常确定地错误预测或不太确定地正确预测的图像。

蟒蛇3

interp.plot_top_losses(5, nrows=5)


部署模型

该模型将使用 Binder 进行部署。需要粘贴 notebook 的 GitHub URL。首先需要导出模型,生成一个扩展名为.pkl的文件。

蟒蛇3

learn.export()
path = Path()
path.ls(file_exts='.pkl')

访问 Binder 的网站。粘贴 GitHub 存储库的 URL,其中包含笔记本和.pkl文件。在“打开的 URL”空白处,输入笔记本的 (GitHub) URL。单击“启动”,几分钟后,网络应用程序就可以使用了。

注意:这个图像分类器最初是在 Fast.AI 深度学习课程的第 2 课中讲授的。