📜  PyTorch预测和线性分类(1)

📅  最后修改于: 2023-12-03 15:34:33.227000             🧑  作者: Mango

PyTorch预测和线性分类

PyTorch是一个由Facebook开源的机器学习框架,提供了高效的张量操作和自动求导功能。本文将介绍如何使用PyTorch来进行预测和线性分类。

数据准备

首先,我们需要准备数据。这里使用sklearn中的一个自带数据集load_boston,该数据集包含了波士顿房价的相关信息。

from sklearn.datasets import load_boston
import pandas as pd

boston_dataset = load_boston()
boston = pd.DataFrame(boston_dataset.data, columns=boston_dataset.feature_names)
boston['MEDV'] = boston_dataset.target

我们将数据集分为训练集和测试集,并将其转化为PyTorch的Tensor数据类型并进行标准化处理:

import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import numpy as np

X = boston.drop('MEDV', axis=1).values
y = boston['MEDV'].values

# 标准化处理
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 转化为Tensor并分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train = torch.from_numpy(X_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))

X_test = torch.from_numpy(X_test.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))
线性回归模型

我们将使用一个简单的线性回归模型来预测房价。模型的表达式为:

y = w * x + b

其中,w为权重,b为偏置。

我们可以使用PyTorch中的nn.Linear层来实现该模型:

import torch.nn as nn

input_size = X.shape[1]

# 定义模型结构
model = nn.Linear(input_size, 1)

print(model)

在定义完模型后,我们需要指定模型的损失函数和优化器。这里使用均方误差作为损失函数,使用随机梯度下降作为优化器:

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
模型训练

现在,我们可以开始训练模型了。在每个epoch中,我们都会根据模型的输出和标签计算损失并进行反向传播:

num_epochs = 100

for epoch in range(num_epochs):
    # 前向传播
    y_pred = model(X_train)
    
    # 计算损失
    loss = criterion(y_pred, y_train)
    
    # 反向传播
    loss.backward()
    
    # 更新权重和偏置
    optimizer.step()
    
    # 清空梯度
    optimizer.zero_grad()
    
    # 打印损失
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
模型评估

训练完成后,我们需要评估模型的性能。我们可以计算模型在测试集上的均方误差来评估模型的表现:

# 测试模型
with torch.no_grad():
    y_pred = model(X_test)
    mse = criterion(y_pred, y_test)
    print(f'MSE: {mse.item():.4f}')
结论

本文介绍了如何使用PyTorch来进行房价预测和线性分类。我们首先准备数据,然后定义模型结构、损失函数和优化器,最后进行模型训练和评估。通过本文,希望读者能够更加深入地了解PyTorch的使用方法。