白板推导系列Pytorch-PCA降维

白板推导系列Pytorch-PCA降维

前面在看花书的时候就遇到过PCA,但是花书上的推导和白板推导中的推导不太一样,花书上的推导我至今还没弄清楚,但是这个我懂了,接下来我将以mnist数据集为例实现PCA降维并利用sklearn朴素贝叶斯分类器分类

导入相关包

import torch
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

定义PCA类

class PCA:
    def __init__(self,output_dim) -> None:
        self.output_dim = output_dim
    
    def fit(self,X_data):
        N = len(X_data)
        H = torch.eye(n=N)-1/N*(torch.matmul(torch.ones(size=(N,1)),torch.ones(size=(1,N))))
        X_data = torch.matmul(H,X_data)
        _,_,v = torch.svd(X_data)
        self.base = v[:,:self.output_dim]

    def fit_transform(self,X_data):
        self.fit(X_data)
        return self.transform(X_data)

    def transform(self,X_data):
        return torch.matmul(X_data,self.base)

    def inverse_transform(self,X_data):
        return torch.matmul(X_data,self.base.T)

加载mnist数据集

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data,digits.target, test_size=0.2, random_state=42)

X_train = torch.tensor(X_train,dtype=torch.float)
X_test = torch.tensor(X_test,dtype=torch.float)
y_train = torch.tensor(y_train,dtype=torch.float)
y_test = torch.tensor(y_test,dtype=torch.float)

数据集二维可视化

pca = PCA(2)
X_train_pca = pca.fit_transform(X_train)
plt.scatter(X_train_pca[:,0],X_train_pca[:,1],c=y_train)

白板推导系列Pytorch-PCA降维

恢复原数据

我们可以看看降到不同的维度后还原成的数据损失了多少信息

plt.figure()
plt.subplot(331)

for i,dim in enumerate([2,10,20,30,40,50,60]):
    pca = PCA(dim)
    X_train_pca = pca.fit_transform(X_train)
    X_data = pca.inverse_transform(X_data=X_train_pca)
    plt.subplot(2,4,i+1)
    plt.imshow(X_data[0].view(8,8))
plt.subplot(2,4,8)
plt.imshow(X_train[0].view(8,8))
plt.show()

白板推导系列Pytorch-PCA降维

朴素贝叶斯分类

# 将图片降到20维用于训练
pca = PCA(20)
X_train_pca = pca.fit_transform(X_train)
model = GaussianNB()
model.fit(X_train_pca,y_train)
X_test_pca = pca.transform(X_test)
model.score(X_test_pca,y_test)

得到的准确率为0.9361111111111111

上一篇:主成分分析(PCA)及其可视化——python


下一篇:PCA主成分分析