一: 输入输出
输入:
- ROIs: RPN to ROI 流程的输出, shape 为 [300, 4]
- P_cls: Classifier网络的输出, shape为 [1, 32, 2]
- P_regr: Classifier网络的输出, shape为 [1, 32, 4]
输出:
- candidate_bboxes: 备选的boxes, shape为 [N, 4] N 表示未知
- candidate_probs: 备选的boxes的概率, shape为 [N, ] N 表示未知
二: 流程
- 过滤掉背景和概率小于0.8的bbox
- 使用P_regr修正ROI
三: code by code
将classifier 网络的输出P_cls, P_regr 转换为Nd4j INDArray
INDArray P_cls = TypeConvertor.tensorToNDArray_s3(classifier_output.getCls());
INDArray P_regr = TypeConvertor.tensorToNDArray_s3(classifier_output.getRegression());
遍历每个ROIs,并切过滤掉背景和概率小于0.8 的ROI
for (int ii = 0; ii < P_cls.shape()[1]; ii++)
{
INDArray theCls = P_cls.get(NDArrayIndex.point(0), NDArrayIndex.point(ii), NDArrayIndex.all());
// 是否背景.
boolean isBackground = theCls.argMax(0).getInt(0) == (P_cls.shape()[2] - 1);
// 小于0.8概率 || 是背景
if (theCls.maxNumber().floatValue() < bbox_threshold || isBackground)
{
continue;
}
提取对应的ROI坐标和回归值
int x = ROIs.getInt(0, ii, 0);
int y = ROIs.getInt(0, ii, 1);
int w = ROIs.getInt(0, ii, 2);
int h = ROIs.getInt(0, ii, 3);
float tx = P_regr.getFloat(new int[]{0, ii, 0});
float ty = P_regr.getFloat(new int[]{0, ii, 1});
float tw = P_regr.getFloat(new int[]{0, ii, 2});
float th = P_regr.getFloat(new int[]{0, ii, 3});
使用Classifier的输出P_regr 来修正ROI, 这个算法逻辑与训练时生成Classifier 网络标注算法相反.
tx /= classifier_regr_std[0];
ty /= classifier_regr_std[1];
tw /= classifier_regr_std[2];
th /= classifier_regr_std[3];
int[] coor_out;
try {
// [x, y, w, h] 格式.
// coor_out[0]: x
// coor_out[1]: y
// coor_out[2]: w
// coor_out[3]: h
coor_out = apply_regr(x, y, w, h, tx, ty, tw, th);
}
catch (Exception e)
{
continue;
}
将修正后的ROI坐标转换到VGG16的feature map 维度上,rpn_stride = 16
并将坐标从 [x1, y1, x2, y2] 转为 [x, y, w, h]
float x1 = coor_out[0] * rpn_stride;
float y1 = coor_out[1] * rpn_stride;
float x2 = (coor_out[0] + coor_out[2]) * rpn_stride;
float y2 = (coor_out[1] + coor_out[3]) * rpn_stride;
float[] bbox = new float[] {
coor_out[0] * rpn_stride,
coor_out[1] * rpn_stride,
(coor_out[0] + coor_out[2]) * rpn_stride,
(coor_out[1] + coor_out[3]) * rpn_stride
};
纵向排列一下,将N个bboxes的数组构建成INDArray shape = [N, 4]
修正为NMS(非最大值抑制)的数据输入格式. 下一个流程就需要执行NMS了。
INDArray candidate_bboxes = Nd4j.vstack(bboxes);
INDArray candidate_probs = Nd4j.create(probs);