ML实战：手动实现逻辑回归
代码实现
LogisticRegression 类（保存为 Logistic_class.py，主函数通过 `from Logistic_class import LogisticRegression` 导入）
import numpy as np
from sklearn.preprocessing import StandardScaler
np.set_printoptions(suppress=True)


class LogisticRegression:
    """Binary logistic regression trained with batch gradient descent.

    Features are standardized (zero mean, unit variance) with statistics
    computed on the training data in ``__init__``.  ``predict`` reuses those
    same statistics, so new samples are mapped into the feature space the
    parameters were learned in.
    """

    def __init__(self, x, y):
        """Standardize and store the training data and randomly initialize theta.

        :param x: training samples, 2-D array-like of shape (m, n)
        :param y: training labels (expected to be 0/1), array-like with m elements
        :raises ValueError: if ``x`` is not a non-empty 2-D array
        """
        x = self._as_2d_float(x)
        # Per-feature training statistics, kept so predict() can apply the
        # exact same transformation to unseen samples.
        self.mean_ = x.mean(axis=0)
        scale = x.std(axis=0)
        # Constant columns would divide by ~0; leave them unscaled instead,
        # as sklearn's StandardScaler also does.
        scale[scale <= 10 * np.finfo(float).eps * np.abs(self.mean_)] = 1.0
        self.scale_ = scale
        self.y = np.array(y).reshape(-1, 1)
        # One weight per feature plus the intercept theta[0].
        self.theta = np.random.rand(x.shape[1] + 1, 1)
        self.x = self._add_bias(self._standardize(x))

    @staticmethod
    def _as_2d_float(x):
        """Convert x to a float array, rejecting anything but a non-empty 2-D matrix."""
        x = np.asarray(x, dtype=float)
        if x.ndim != 2 or x.size == 0:
            raise ValueError(
                "expected a non-empty 2-D array of shape (n_samples, n_features)")
        return x

    def _standardize(self, x):
        """Map raw features into the training-standardized space."""
        return (x - self.mean_) / self.scale_

    @staticmethod
    def _add_bias(x):
        """Prepend a column of ones so that theta[0] acts as the intercept."""
        return np.c_[np.ones(len(x)), x]

    @staticmethod
    def _sigmoid(z):
        """Numerically stable logistic function 1 / (1 + exp(-z))."""
        # exp is only evaluated at -|z| <= 0, so it can never overflow.
        e = np.exp(-np.abs(z))
        return np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

    def h(self):
        """Hypothesis: sigmoid(x @ theta) for every training sample, shape (m, 1)."""
        return self._sigmoid(self.x @ self.theta)

    def single_iter(self, alpha):
        """Perform one batch gradient-descent step on theta.

        :param alpha: learning rate
        """
        # Gradient of the mean cross-entropy loss.  Every component of theta
        # is updated simultaneously from the same residual.
        residual = self.h() - self.y
        self.theta -= alpha / len(self.x) * (self.x.T @ residual)

    def fit(self, alpha=1, iter_count=500):
        """Fit theta by running ``iter_count`` gradient-descent steps.

        :param alpha: learning rate
        :param iter_count: number of iterations
        """
        for _ in range(iter_count):
            self.single_iter(alpha)

    def predict(self, x):
        """Predict 0/1 labels for new samples.

        :param x: 2-D array-like of shape (k, n) with the training features
        :return: float array of shape (k, 1) containing 0.0 or 1.0
        :raises ValueError: if ``x`` is not a non-empty 2-D array
        """
        x = self._as_2d_float(x)
        # Reuse the TRAINING statistics: refitting a scaler on x would make the
        # result depend on the test batch, and a single sample would always
        # standardize to all zeros.
        z = self._add_bias(self._standardize(x)) @ self.theta
        return (self._sigmoid(z) > 0.5).astype(float)
主函数
from sklearn import datasets
import numpy as np
from Logistic_class import LogisticRegression
import sys
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
np.set_printoptions(suppress=True)


def main():
    """Train the hand-written logistic regression on the breast-cancer data and plot test predictions."""
    # Load features and labels in a single call.
    x, y = datasets.load_breast_cancer(return_X_y=True)
    # Hold out 10% for testing.  The split seed is itself drawn at random,
    # so every run uses a different train/test split.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.1, random_state=np.random.randint(0, 30))
    index = np.arange(1, len(y_test) + 1)  # 1-based test-sample index for the x axis

    # Build and train the model: learning rate 10, 4000 iterations.
    logistic = LogisticRegression(x_train, y_train)
    logistic.fit(10, 4000)
    y_predict = logistic.predict(x_test)

    # Plot real vs. predicted labels for every test sample.
    plt.figure(figsize=(15, 4), dpi=80)
    plt.scatter(index, y_test, label='real', marker='s', color='blue')
    plt.scatter(index, y_predict, label='predict', marker='x', color='red')
    plt.legend(loc=[1, 0])
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.yticks([0, 1])  # binary target: one tick per class
    plt.xticks(index)
    plt.xlabel('index of tests')
    plt.ylabel('target')
    # Raw string so the Windows path's backslashes are not parsed as escape sequences.
    plt.savefig(r'E:\python\ml\ml by myself\Logistic_Regression\Logistic_Regression.png')


if __name__ == "__main__":
    main()
    sys.exit(0)
结果