📅  最后修改于: 2023-12-03 15:06:51.738000             🧑  作者: Mango
随着深度学习的发展,图像分类已经成为了计算机视觉领域中最受关注的任务之一,它广泛应用于人脸识别、物体识别、自动驾驶等领域。在实际应用中,我们需要一些简单易用的工具来帮助我们进行图像分类。
本文将介绍使用 Web App 进行图像分类的方法,Web App 是一种通过浏览器来访问的应用程序,使用它可以方便地进行图像分类操作。我们将使用 Python 和 Flask 框架来开发这个应用程序。
在进行开发前,我们需要准备以下开发环境:
可以通过以下方式安装 Flask 和 PyTorch:
pip3 install flask
pip3 install torch
在进行图像分类时,我们需要有一些训练数据来训练模型。这里我们选择使用 CIFAR-10 数据集,该数据集包含了 10 个类别的图片,每个类别有 6,000 张图片,其中 5,000 张用作训练,1,000 张用作测试。
可以通过以下方式下载和解压数据集:
mkdir data
cd data
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar -xzvf cifar-10-python.tar.gz
我们选择使用卷积神经网络来进行图像分类,具体可以参考 PyTorch 官方文档。我们在训练时使用 Adam 优化器和交叉熵损失函数,训练 100 次。代码如下:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
net = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
for epoch in range(100):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
我们使用 trainset
进行训练,每次使用 4 张图片进行训练,总共训练 100 次。训练结果保存在 net
中。
完成模型训练后,我们可以开始开发 Web App 来进行图像分类了。在开发过程中,我们需要解决以下两个问题:
我们可以通过以下方式来实现这两个问题:
form
表单将图片上传到服务器。Pillow
库来读取上传的图片,并使用模型对其进行分类。代码实现如下:
from flask import Flask, request, render_template
from PIL import Image
import numpy as np
import torch
net = torch.load('net.pt')
app = Flask(__name__)
@app.route('/')
def index():
return render_template('index.html')
@app.route('/upload', methods=['POST'])
def upload():
file = request.files['file']
img = Image.open(file.stream).resize((32, 32))
img_array = np.array(img) / 255.0
img_tensor = torch.FloatTensor(img_array.transpose(2, 0, 1).reshape(1, 3, 32, 32))
outputs = net(img_tensor)
_, predicted = torch.max(outputs.data, 1)
return str(predicted.item())
if __name__ == '__main__':
app.run()
我们使用 Flask
来开发 Web App,其中的 net.pt
是之前训练得到的模型。上传的图片使用 file
参数来获取,读取图片使用 Pillow
库,将图片转换成 32x32x3
的张量,并使用之前训练得到的模型对图片进行分类。
通过以上步骤,我们已经成功地搭建了一个图像分类的 Web App,可以通过浏览器来上传图片,并且得到分类结果。这一过程涉及到了 Python 编程、深度学习和 Web 开发等多个领域。希望这篇文章能够对大家有所帮助。