文章目录
前言
本篇是MMdet逐行解读第四篇,代码地址:mmdet/core/bbox/samplers/random_sampler.py。随机采样正负样本主要针对在训练过程中,经过MAXIOUAssigner后,确定出每个anchor和哪个gt匹配后,从这些正负样本中采样来进行loss计算。本文以RPN的config进行讲解,因为该部分用到了随机采样来克服正负样本不平衡;而在RetinaNet中则使用focal loss来克服正负样本不平衡问题,即没有随机采样的过程。
历史文章如下:
AnchorGenerator解读
MaxIOUAssigner解读
DeltaXYWHBBoxCoder解读
1、构造一个简单的sampler
from mmdet.core.bbox import build_sampler
# 构造一个sampler
sampler = dict(
type='RandomSampler',# 构造一个随机采样器
num=256, # 正负样本总数量
pos_fraction=0.5, # 正样本比例
neg_pos_ub=-1, # 负样本上限
add_gt_as_proposals=False) # 是否添加gt作为正样本,默认不添加。
sp = build_sampler(sampler)
# 这块不必详细理解,知道意思即可。
# 就是随机生成一个assigner、bboxes和gt_bboxes,
from mmdet.core.bbox import AssignResult
from mmdet.core.bbox.demodata import ensure_rng, random_boxes
rng = ensure_rng(None)
assign_result = AssignResult.random(rng=rng)
bboxes = random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
gt_labels = None
# 调用sample方法进行正负样本采样
self = sp.sample(assign_result, bboxes, gt_bboxes, gt_labels)
2、BaseSampler类
class BaseSampler(metaclass=ABCMeta):
"""Base class of samplers"""
def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self
self.neg_sampler = self
@abstractmethod
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Sample positive samples"""
pass
@abstractmethod
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Sample negative samples"""
pass
def sample(self,
assign_result,
bboxes,
gt_bboxes,
gt_labels=None,
**kwargs):
pass
基类比较容易理解,核心是sample方法,内部调用_sample_pos方法和_sample_neg方法。后续继承该类的子类只需实现_sample_pos方法和_sample_neg方法即可。
3、RandomSampler类
3.1 sample方法
以RandomSampler类来讲解代码。首先看下sample方法:
# 确定正样本个数: 256*0.5 = 128
num_expected_pos = int(self.num * self.pos_fraction)
# 调用_sample_pos方法返回采样后正样本的id。
pos_inds = self.pos_sampler._sample_pos(
assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
pos_inds = pos_inds.unique() # 挑选出tensor独立不重复元素
num_sampled_pos = pos_inds.numel() # 确定出正样本个数
num_expected_neg = self.num - num_sampled_pos #确定负样本个数
# 由于该参数为-1,故不执行if语句,即实打实的采样254个负样本
if self.neg_pos_ub >= 0:
_pos = max(1, num_sampled_pos)
#确定负样本的上限是正样本个数的neg_pos_ub倍
neg_upper_bound = int(self.neg_pos_ub * _pos)
# 负样本的个数不能超过上限
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
# 调用_sample_neg方法返回采样后负样本的id
neg_inds = self.neg_sampler._sample_neg(
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique() # 同理,将id取集合操作。
# 用SamplingResult进行封装
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
代码还是比较容易理解:首先确定正样本个数,然后完成采样;采样后确定负样本的个数,若指定了负样本上限:neg_upper_bound,则负样本个数最多采样不能超过正样本个数的neg_upper_bound倍;若无指定,则负样本个数就是总的数量-正样本个数。
3.2 _sample_pos方法
再来看下采样正样本的方法:
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Randomly sample some positive samples."""
# 找出非0的正样本的id
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
# 若正样本个数<期望的128个,则直接返回。
if pos_inds.numel() <= num_expected:
return pos_inds
# 否则就从pos_inds里面随机采够128个。
else:
return self.random_choice(pos_inds, num_expected)
3.2 _sample_neg方法
和采样正样本方法大同小异,这里看下即可。
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Randomly sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected: # 若负样本数量比期望的还小则直接返回
return neg_inds
else:
return self.random_choice(neg_inds, num_expected)
总结
下篇将开启model模块介绍,敬请期待。