vislee / leevis.com

Blog
87 stars 13 forks source link

使用sklearn学习线性回归模型训练 #198

Open vislee opened 9 months ago

vislee commented 9 months ago

环境准备

使用pyenv管理python版本,切换到3.8.5。 安装线性回归测试所需要的包:

pip install numpy
pip install matplotlib
pip install scikit-learn
pip install pandas
pip install scipy
pip install graphviz

训练测试模型

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

m = 200

# 构造训练数据
x = np.random.rand(m, 1)*6 - 3
y = x + 0.5*x**2 + np.random.randn(m, 1)

# 画出训练数据
# plt.plot(x, y, 'b.')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.axis([-3, 3, -5, 10])
# plt.show()

# 对数据做特征变换
poly_features = PolynomialFeatures(degree = 2, include_bias = False)
x_poly = poly_features.fit_transform(x)

print(x[0])
# [0.55472076]
print(x_poly[0])
# [0.55472076 0.30771512]

# 初始化线性回归
lin_reg = LinearRegression()
# 使用数据训练
lin_reg.fit(x_poly, y)

# 输出训练结果, lin_reg 就是训练好的线性回归方程。
print(lin_reg.coef_)
# [[1.04560117 0.47370004]]
print(lin_reg.intercept_)
# [0.05677187]

# 构造新的测试数据:100个-3~3之间的数
x_test = np.linspace(-3, 3, m).reshape(m, 1)
x_test_poly = poly_features.transform(x_test)
y_test = lin_reg.predict(x_test_poly)

plt.plot(x, y, 'b.')
plt.plot(x_test, y_test, 'r--', label='predict')
plt.xlabel('x')
plt.ylabel('y')
plt.axis([-3, 3, -5, 10])
plt.legend()
plt.show()
截屏2024-01-30 23 15 36

上图红色为训练好的函数曲线, 蓝色点为构造的训练数据。 同时根据linreg.coef 和 linreg.intercept 也可以看出训练好的函数符合: y = x + 0.5*x**2 + 一个1以内的常数。

使用Jupyter学习测试更方便。

# 安装
pip install jupyter
# 启动
jupyter notebook

相关文档

https://scikit-learn.org.cn https://matplotlib.org/ https://numpy.org/