📜  使用 Web App 进行图像分类(1)

📅  最后修改于: 2023-12-03 15:06:51.738000             🧑  作者: Mango

使用 Web App 进行图像分类

随着深度学习的发展,图像分类已经成为了计算机视觉领域中最受关注的任务之一,它广泛应用于人脸识别、物体识别、自动驾驶等领域。在实际应用中,我们需要一些简单易用的工具来帮助我们进行图像分类。

本文将介绍使用 Web App 进行图像分类的方法,Web App 是一种通过浏览器来访问的应用程序,使用它可以方便地进行图像分类操作。我们将使用 Python 和 Flask 框架来开发这个应用程序。

开发环境

在进行开发前,我们需要准备以下开发环境:

  • Python 3.6 或以上版本
  • Flask 1.1.2 或以上版本
  • PyTorch 1.4.0 或以上版本

可以通过以下方式安装 Flask 和 PyTorch:

pip3 install flask
pip3 install torch
数据集

在进行图像分类时,我们需要有一些训练数据来训练模型。这里我们选择使用 CIFAR-10 数据集,该数据集包含了 10 个类别的图片,每个类别有 6,000 张图片,其中 5,000 张用作训练,1,000 张用作测试。

可以通过以下方式下载和解压数据集:

mkdir data
cd data
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar -xzvf cifar-10-python.tar.gz
模型训练

我们选择使用卷积神经网络来进行图像分类,具体可以参考 PyTorch 官方文档。我们在训练时使用 Adam 优化器和交叉熵损失函数,训练 100 次。代码如下:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

net = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(100):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

我们使用 trainset 进行训练,每次使用 4 张图片进行训练,总共训练 100 次。训练结果保存在 net 中。

Web App

完成模型训练后,我们可以开始开发 Web App 来进行图像分类了。在开发过程中,我们需要解决以下两个问题:

  1. 如何将图片上传到服务器。
  2. 如何将图片传入模型进行分类。

我们可以通过以下方式来实现这两个问题:

  1. 使用 form 表单将图片上传到服务器。
  2. 使用 Pillow 库来读取上传的图片,并使用模型对其进行分类。

代码实现如下:

from flask import Flask, request, render_template
from PIL import Image
import numpy as np
import torch

net = torch.load('net.pt')
app = Flask(__name__)

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/upload', methods=['POST'])
def upload():
    file = request.files['file']
    img = Image.open(file.stream).resize((32, 32))
    img_array = np.array(img) / 255.0
    img_tensor = torch.FloatTensor(img_array.transpose(2, 0, 1).reshape(1, 3, 32, 32))
    outputs = net(img_tensor)
    _, predicted = torch.max(outputs.data, 1)
    return str(predicted.item())

if __name__ == '__main__':
    app.run()

我们使用 Flask 来开发 Web App,其中的 net.pt 是之前训练得到的模型。上传的图片使用 file 参数来获取,读取图片使用 Pillow 库,将图片转换成 32x32x3 的张量,并使用之前训练得到的模型对图片进行分类。

结论

通过以上步骤,我们已经成功地搭建了一个图像分类的 Web App,可以通过浏览器来上传图片,并且得到分类结果。这一过程涉及到了 Python 编程、深度学习和 Web 开发等多个领域。希望这篇文章能够对大家有所帮助。