This post is a follow-up analysis to the previous post on implementing simple image clustering with Python K-means.
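In the previous post we mentioned that ResNet can be used to extract image features, which in turn help us cluster the images. There is a problem with this, however: the features extracted by ResNet have as many as 114688 dimensions, so once the number of samples grows, the computation becomes very time-consuming.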
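Where does 114688 come from? A hand calculation suggests it follows from the `Resize([224, 244])` step in the code later in this post. The sketch below assumes torchvision's ResNet-50 layer settings, and the helper names `conv_out` and `resnet50_feature_dim` are hypothetical, introduced only for illustration. Under those assumptions, layer4 outputs 2048 channels on a 7 × 8 grid, and 2048 × 7 × 8 = 114688; a square 224 × 224 input would give 2048 × 7 × 7 = 100352 instead.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution/pooling layer along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet50_feature_dim(height: int, width: int, channels: int = 2048):
    """Flattened size of ResNet-50's layer4 output for a height x width input.

    Assumes torchvision's settings: conv1 (7x7, stride 2, pad 3), maxpool
    (3x3, stride 2, pad 1), layer1 keeping the resolution, and a stride-2
    3x3 convolution (pad 1) at the start of layer2, layer3 and layer4.
    """
    spatial = []
    for n in (height, width):
        n = conv_out(n, 7, 2, 3)   # conv1
        n = conv_out(n, 3, 2, 1)   # maxpool
        for _ in range(3):         # layer2, layer3, layer4 each halve the size
            n = conv_out(n, 3, 2, 1)
        spatial.append(n)
    h, w = spatial
    return channels * h * w, (h, w)


print(resnet50_feature_dim(224, 244))  # (114688, (7, 8))
print(resnet50_feature_dim(224, 224))  # (100352, (7, 7))
```

Storing and processing vectors of this size for every image is what makes the pipeline slow as the dataset grows.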
Intuitively, not every one of so many feature dimensions is "useful", so we can reduce the dimensionality of the image features. Here we use PCA:
```python
pca = PCA(n_components=10)
all_images = pca.fit_transform(all_images)
```
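To make explicit what `pca.fit_transform` computes, here is a minimal NumPy sketch: centre the data, take an SVD, and project onto the top right-singular vectors. It is an illustration, not scikit-learn's actual implementation; `pca_fit_transform` is a hypothetical helper, and random numbers stand in for the real ResNet features. It also shows a consequence relevant to the 10-image example: after centring, 10 samples span at most 9 directions, so the 10th component carries essentially no variance.

```python
import numpy as np


def pca_fit_transform(X, n_components):
    """Minimal PCA via SVD, mirroring what sklearn's PCA.fit_transform computes.

    Returns the projected data and each kept component's explained-variance
    ratio. Component signs may differ from sklearn's (it applies its own sign
    convention); flipping an axis does not change Euclidean distances, so
    K-means sees the same geometry.
    """
    X = np.asarray(X, dtype=np.float64)
    X_centered = X - X.mean(axis=0)               # PCA works on mean-centred data
    _, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    if n_components > len(S):                     # len(S) == min(n_samples, n_features)
        raise ValueError("n_components must be <= min(n_samples, n_features)")
    projected = X_centered @ Vt[:n_components].T  # coordinates on the top axes
    variance = S ** 2 / (X.shape[0] - 1)          # variance along each axis
    ratio = variance[:n_components] / variance.sum()
    return projected, ratio


rng = np.random.default_rng(0)
features = rng.normal(size=(10, 2000))  # stand-in for 10 flattened ResNet feature vectors
reduced, ratio = pca_fit_transform(features, n_components=10)
print(reduced.shape)      # (10, 10)
print(ratio[-1] < 1e-12)  # True: after centring, 10 samples span only 9 directions
```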
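Since the example in this post clusters ten images into two groups, there are only 10 samples in total, so the features can be reduced to at most 10 dimensions (scikit-learn's PCA requires `n_components` to be no larger than min(n_samples, n_features)). Even so, 10 dimensions turn out to work about as well as the original 114688: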
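the clustering still reaches 100% accuracy (footballs vs. other balls). In our experiments, performance remained highly usable whenever more than 6 dimensions were kept, which indirectly suggests that visual images contain a great deal of redundant information. The code is as follows: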
```python
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from imutils import build_montages
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torchvision import transforms


class Net(nn.Module):
    """ResNet-50 without its final average pooling and fully connected layers."""

    def __init__(self):
        super(Net, self).__init__()
        # pretrained=True loads ImageNet weights; newer torchvision versions
        # prefer weights=models.ResNet50_Weights.DEFAULT instead.
        resnet50 = models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(resnet50.conv1,
                                    resnet50.bn1,
                                    resnet50.relu,
                                    resnet50.maxpool,
                                    resnet50.layer1,
                                    resnet50.layer2,
                                    resnet50.layer3,
                                    resnet50.layer4)

    def forward(self, x):
        x = self.resnet(x)
        return x


net = Net().eval()

# Collect the paths of all images in ./images
image_path = []
all_images = []
images = os.listdir('./images')
for image_name in images:
    image_path.append(os.path.join('./images', image_name))

# Extract one flattened feature vector per image (no gradients needed)
with torch.no_grad():
    for path in image_path:
        image = Image.open(path).convert('RGB')
        # Resize to 224 x 244 (H x W); layer4 then outputs 2048 x 7 x 8 = 114688 values
        image = transforms.Resize([224, 244])(image)
        image = transforms.ToTensor()(image)
        image = image.unsqueeze(0)  # add a batch dimension
        image = net(image)
        image = image.reshape(-1)   # flatten to a 1-D feature vector
        print(image.shape)
        all_images.append(image.numpy())

# Reduce the features to 10 dimensions with PCA
print("starting pca")
pca = PCA(n_components=10)
all_images = pca.fit_transform(np.stack(all_images))
print(pca.explained_variance_ratio_)
print("finish pca")
print(all_images)

# Cluster the reduced features into two groups with K-means
clt = KMeans(n_clusters=2, random_state=1234)
clt.fit(all_images)

# Show up to 25 images from each cluster as a 5 x 5 montage
labelIDs = np.unique(clt.labels_)
for labelID in labelIDs:
    idxs = np.where(clt.labels_ == labelID)[0]
    idxs = np.random.choice(idxs, size=min(25, len(idxs)),
                            replace=False)
    show_box = []
    for i in idxs:
        image = cv2.imread(image_path[i])
        image = cv2.resize(image, (96, 96))
        show_box.append(image)
    montage = build_montages(show_box, (96, 96), (5, 5))[0]
    title = "Type {}".format(labelID)
    cv2.imshow(title, montage)
    cv2.waitKey(0)
```
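Scoring a claim like "100% accuracy" needs some care, because K-means cluster IDs are arbitrary: "Type 0" might be the footballs in one run and the other balls in another. One common way to score a clustering is to try every one-to-one mapping from cluster IDs to true classes and keep the best. The sketch below assumes ground-truth labels are available; `clustering_accuracy` is a hypothetical helper and the labels in the example are made up for illustration.

```python
from itertools import permutations

import numpy as np


def clustering_accuracy(true_labels, cluster_ids) -> float:
    """Accuracy of a clustering against ground truth, up to renaming of clusters.

    Tries every one-to-one mapping from cluster IDs to true classes and keeps
    the best score; brute force is fine for a handful of clusters (k! mappings).
    Assumes the number of clusters does not exceed the number of classes.
    """
    true_labels = np.asarray(true_labels)
    cluster_ids = np.asarray(cluster_ids)
    classes = np.unique(true_labels)
    clusters = np.unique(cluster_ids)
    best = 0.0
    for perm in permutations(classes, len(clusters)):
        mapping = dict(zip(clusters, perm))
        predicted = np.array([mapping[c] for c in cluster_ids])
        best = max(best, float(np.mean(predicted == true_labels)))
    return best


# Hypothetical labels: 1 = football, 0 = other ball
truth = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
ids = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1]  # e.g. the values in clt.labels_
print(clustering_accuracy(truth, ids))  # 1.0
```

Repeating this for several values of `n_components` is one way to check at which dimensionality the clustering stops being reliable.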