import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
# Load the required data
def get_data(file_name):
    """Load the training data from a CSV file.

    The file is read with ``pandas.read_csv`` and must contain the columns
    ``square_feet`` and ``price``.

    Args:
        file_name: Path (or any other source accepted by
            ``pandas.read_csv``) of the CSV file.

    Returns:
        A tuple ``(X_parameter, Y_parameter)``:
        ``X_parameter`` is a list of one-element lists
        (``[[sq_ft], [sq_ft], ...]``) so it is already two-dimensional,
        and ``Y_parameter`` is a flat list of prices. All values are
        converted to ``float``.

    Raises:
        KeyError: If either required column is missing.
        ValueError: If a cell cannot be converted to ``float``.
    """
    data = pd.read_csv(file_name)  # DataFrame holding the raw table
    # Each feature value is wrapped in its own list so the result is a
    # 2-D structure (one row per sample, one column per feature).
    X_parameter = [[float(square_feet)] for square_feet in data['square_feet']]
    Y_parameter = [float(price) for price in data['price']]
    return X_parameter, Y_parameter
# Fit the linear model
def linear_model_main(X_parameters, Y_parameters, predict_value):
    """Fit a linear regression and predict the target for new input.

    Args:
        X_parameters: Training features, a 2-D structure with one row per
            sample and one column per feature (e.g. ``[[150.0], [200.0]]``).
        Y_parameters: Training targets, one value per sample.
        predict_value: Feature value(s) to predict for. May be a scalar,
            a 1-D sequence, or an already 2-D ``(n_samples, n_features)``
            structure; anything below two dimensions is reshaped to match
            the number of features the model was fitted on.

    Returns:
        A dict with the keys ``'intercept'`` (fitted intercept),
        ``'coefficient'`` (fitted slope array) and ``'predicted_value'``
        (array of predictions for ``predict_value``).
    """
    # The targets are continuous, so this must be a regressor.
    # LogisticRegression is a classifier and would treat every distinct
    # target value as a class label.
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)

    # predict() only accepts 2-D input. Coerce scalars and 1-D sequences
    # into (n_samples, n_features); 2-D input is passed through untouched.
    query = predict_value
    if np.ndim(query) < 2:
        n_features = regr.coef_.shape[-1]
        query = np.asarray(query, dtype=float).reshape(-1, n_features)

    return {
        'intercept': regr.intercept_,
        'coefficient': regr.coef_,
        'predicted_value': regr.predict(query),
    }
# Display the fitted linear model
def show_linear_line(X_parameters, Y_parameters):
    """Plot the data points together with the fitted regression line.

    A fresh ``LinearRegression`` model is fitted on the given data; the
    samples are drawn as a blue scatter plot and the model's predictions
    for the same inputs as a red line. The figure is then shown with
    ``plt.show()``.

    Args:
        X_parameters: Training features, a 2-D structure with one row per
            sample.
        Y_parameters: Training targets, one value per sample.
    """
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)

    plt.scatter(X_parameters, Y_parameters, color='blue')
    plt.plot(X_parameters, regr.predict(X_parameters), color='red', linewidth=4)
    # To hide the axis ticks, call plt.xticks(()) and plt.yticks(()) here:
    # passing an empty sequence of tick positions removes all ticks.
    plt.show()
# Script body: load the data, fit the model, report the fit and plot it.
X, Y = get_data('input_data.csv')  # training features and targets
# linear_model_main forwards this value to the estimator's predict(),
# which requires a 2-D (n_samples, n_features) input, so the single query
# of 700 square feet is written as one row with one column.
predictvalue = [[700]]
result = linear_model_main(X, Y, predictvalue)  # dict of fit results
print("Intercept value:", result['intercept'])
print("Coefficient:", result['coefficient'])
print("Predicted value:", result['predicted_value'])
show_linear_line(X, Y)