继续巩固PointNet++代码的实现这篇博客,把代码逐行注释一遍!
pointnet++的所有代码和数据集都在github上,Pytorch代码:https://github.com/yanx27/Pointnet2_pytorch
data_utils中的modelnetdataloader部分的python代码注释如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# 导入第三方库
import numpy as np
import warnings
import os
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
# 将数据归一化
def pc_normalize(pc):
# 计算pc簇的中心点,新的中心点每一个特征的值,是该簇所有数据在该特征的平均值
centroid = np.mean(pc, axis=0)
# 3D数据簇减去中心得到到中心的绝对距离
pc = pc - centroid
# 取到最大距离
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
# 将数据归一化
pc = pc / m
return pc
# 最远点采样
def farthest_point_sample(point, npoint):
"""
Input:
xyz: pointcloud data, [N, D]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:, :3]
# 先随机初始化一个centroids矩阵,
# 后面用于存储npoint个采样点的索引位置
centroids = np.zeros((npoint,))
# 利用distance矩阵记录某个样本中所有点到某一个点的距离
distance = np.ones((N,)) * 1e10 # 初值给个比较大的值,后面会迭代更新
# 利用farthest表示当前最远的点,也是随机初始化,范围为0~N
farthest = np.random.randint(0, N)
# 直到采样点达到npoint,否则进行如下迭代
for i in range(npoint):
# 设当前的采样点centroids为当前的最远点farthest;
centroids[i] = farthest
# 取出这个中心点centroid的坐标
centroid = xyz[farthest, :]
# 求出所有点到这个farthest点的欧式距离,存在dist矩阵中
dist = np.sum((xyz - centroid) ** 2, -1)
# 建立一个mask,如果dist中的元素小于distance矩阵中保存的距离值,
# 则更新distance中的对应值,
# 即记录某个样本中每个点距离所有已出现的采样点的最小距离
mask = dist < distance
distance[mask] = dist[mask]
# 最后从distance矩阵取出最远的点为farthest,继续下一轮迭代
farthest = np.argmax(distance, -1)
point = point[centroids.astype(np.int32)]
# 返回结果是npoint个采样点在原始点云中的索引
return point
# 加载数据集
class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
self.root = root # 数据集根目录
self.npoints = npoint # 对原始数据集下采样至1024个点
self.uniform = uniform # 是否归一化
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
# 将数据集划分为训练集和测试集
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
# (shape_name, shape_txt_file_path) 元组列表
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
# 在内存中缓存数据点的大小
self.cache_size = cache_size # 核函数cache缓存大小,默认设置为15000
self.cache = {} # 从索引到(point_set, cls) 元组
def __len__(self):
return len(self.datapath)
def _get_item(self, index):
if index in self.cache:
point_set, cls = self.cache[index]
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints,:]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
if not self.normal_channel:
point_set = point_set[:, 0:3]
if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, cls)
return point_set, cls # 返回点云集及其分类
def __getitem__(self, index):
return self._get_item(index)
if __name__ == '__main__': # 测试模型导入是否无误
# 导入第三方库
import torch
# 导入数据
data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', split='train', uniform=False, normal_channel=True,)
DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
for point, label in DataLoader:
print(point.shape)
print(label.shape)
对比每一行的输出就会明白代码: