ML实战:手动实现朴素贝叶斯分类器
代码实现
NB类
import numpy as np
from math import pi,e,sqrt,exp
np.set_printoptions(suppress=True)
class NB:
def __init__(self,x,y,tag):
'''
:param x: 数据集
:param y: 标签
:param tag: 标志对应属性是离散(置0)还是连续(置1)
:param Pc: 第c类的出现的频率
:param d: 特征数目
:param N: 类的个数
:param Ni: 每个属性的所有可能取值数
'''
self.x=x
self.y=y
self.Pc=None
self.d=None
self.N=None
self.Ni = np.zeros(len(x[0]))
self.tag=tag
def fit(self):
#初始化参数
self.d=len(self.x[0])
self.N=len(np.unique(self.y))
self.Pc=np.zeros(self.N)
for i in range(len(self.Ni)):
self.Ni[i]=len(np.unique(self.x[:,i]))
for i in range(self.N):
self.Pc[i]=(sum(self.y==i)+1)/(len(self.y)+self.N)
def normal(self,x,u,theta):
#正态分布概率密度函数
return exp(-(x-u)*(x-u)/(2*theta))/sqrt(2*pi*theta)
def predict_c(self,x,c:int):
#计算P(c|x)
Dc=sum(self.y==c)
res=self.Pc[c]
for i in range(self.d):
if self.tag[i]==0:
#离散属性加上拉普拉斯修正
Dci=len(self.x[(self.y==c) & (self.x[:,i]==x[i])])
res*=(Dci+1)/(Dc+self.Ni[i])
if self.tag[i]==1:
#连续属性默认服从正态分布
x_i=self.x[self.y==c,i]
u=np.mean(x_i)
theta=np.var(x_i)
res*=self.normal(x[i],u,theta)
return res
def predict(self,x):
#选取最大的c作为输出
maxc = 0
max = 0
for c in range(self.N):
temp = self.predict_c(x, c)
if max < temp:
max = temp
maxc = c
return maxc
主函数
from sklearn.datasets import load_iris
import numpy as np
from NB_class import NB
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
#获取数据
x=load_iris().data
y=load_iris().target
#分割数据集
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=np.random.randint(0,30))
X = np.arange(1, len(y_test) + 1)
#调用NB类,预测测试集
nb=NB(x_train,y_train,[1,1,1,1])
nb.fit()
y_predict=[]
for i in range(len(y_test)):
y_predict.append(nb.predict(x_test[i]))
predict=np.array(y_predict)
#画图
plt.figure(figsize=(15,4),dpi=80)
plt.scatter(X,y_test,label='real',marker='s',color='blue')
plt.scatter(X,y_predict,label='predict',marker='x',color='red')
plt.legend(loc=[1,0])
plt.grid(True,linestyle='--',alpha=0.5)
plt.yticks(y_test[::1])
plt.xticks(X[::1])
plt.xlabel('index of tests')
plt.ylabel('target')
plt.savefig('E:\python\ml\ml by myself\Bayes\\Navie_Bayes\\Iris_NB.png')
结果