📅  最后修改于: 2023-12-03 14:55:27.406000             🧑  作者: Mango
线性回归是机器学习中最基础也是最常用的方法之一。简单线性回归是指只有一个自变量的回归模型。其原理是基于自变量和因变量之间的线性关系,在给定的训练数据中找到一条最佳拟合直线,用于预测新数据的结果。
以下是使用Python实现简单线性回归的代码:
import numpy as np
import matplotlib.pyplot as plt
# 准备数据
data = np.array([[1, 5], [2, 7], [3, 9], [4, 11], [5, 13], [6, 15], [7, 17], [8, 19], [9, 21], [10, 23]])
X_train = data[:, 0].reshape(-1, 1)
y_train = data[:, 1].reshape(-1, 1)
X_test = np.array([[11], [12], [13], [14], [15]])
y_test = np.array([[25], [27], [29], [31], [33]])
# 训练模型(使用梯度下降法)
alpha = 0.01
theta = np.array([[0], [0]])
m = X_train.shape[0]
for i in range(1000):
h = np.dot(X_train, theta)
loss = h - y_train
gradient = np.dot(X_train.T, loss) / m
theta = theta - alpha * gradient
# 测试模型
h_test = np.dot(X_test, theta)
rmse = np.sqrt(np.average((h_test - y_test) ** 2))
# 可视化结果
plt.plot(X_train, y_train, 'bo')
plt.plot(X_test, y_test, 'gx')
plt.plot(X_test, h_test, 'r-')
plt.xlabel('X')
plt.ylabel('y')
plt.legend(['Train data', 'Test data', 'Prediction'])
plt.title('Simple Linear Regression')
plt.show()
代码注释:
简单线性回归是机器学习中的基础模型,实现简单,但是应用广泛,是其他高级模型的基础。在实际问题中,线性回归可以用于数据分析、金融预测、自然语言处理等领域。