模板匹配(Template matching, TM)是一种解码端推导方法,用来细化当前CU的运动信息,使得当前CU的MV更准确。
TM 的基本思路是寻找一个 MV,使当前图片中的模板(当前 CU 上方和/或左侧相邻的重建样本)与参考图片中对应模板之间的匹配误差最小。如原文示意图所示(图未随本文给出),TM 以当前 CU 的初始 MV 为中心,在 [−8, +8] 像素的搜索范围内搜索更优的 MV。TM 的搜索步长由 AMVR 模式决定;在 Merge 模式下,TM 还可以与双边匹配(bilateral matching, BM)过程级联使用。
在 AMVP 模式下,TM 只细化一个 MVP 候选项:先计算各 MVP 候选项对应的模板匹配误差,选取当前块模板与参考模板差异最小的那个候选项进行 TM 细化。TM 首先在 [−8, +8] 像素搜索范围内,以整像素 MVD 精度(若 AMVR 为 4 像素模式,则以 4 像素精度)进行迭代菱形搜索;随后以相同精度进行交叉搜索以进一步细化该 AMVP 候选项,再根据 AMVR 模式(见表 1;该表未随本文给出,可参阅 ECM 算法描述文档中的模板匹配一节)依次进行半像素和四分之一像素的搜索。该搜索过程保证 MVP 候选项经 TM 细化后,仍保持与 AMVR 模式所指示的相同的 MV 精度。
在 Merge 模式下,对 Merge 索引所指示的 Merge 候选项应用类似的搜索方法。如表 1 所示,TM 要么一直细化到 1/8 像素 MVD 精度,要么跳过比半像素更精细的搜索。具体取决于 Merge 运动信息中是否使用了替代插值滤波器(alternative interpolation filter, AltIF,AMVR 为半像素模式时使用)。此外,启用 TM 模式时,TM 的角色取决于 BM 能否通过其启用条件检查:若 BM 不能启用,模板匹配作为一个独立的细化过程执行;若 BM 可以启用,模板匹配则作为基于块的 BM 与基于子块的 BM 之间的一个额外 MV 细化步骤。
相关代码
在 ECM 中,TM 细化 MV 的入口函数是 deriveTMMv。需要注意的是,Merge 模式与 AMVP 模式调用的是该函数的不同重载。Merge 模式下对全部的 MV 候选项都会进行 TM 细化,对每个候选项,其使用的每个参考列表的 MV 都会被细化。AMVP 模式下则仅对候选列表中模板匹配误差最小的 MV 进行 TM 细化。两个重载的声明如下所示:
#if TM_MRG
// Merge-mode entry point: refines, in place, the MV of every reference list used by `pu`
// (does nothing unless pu.tmMergeFlag is set). For bi-prediction it may fall back to uni-prediction.
void deriveTMMv (PredictionUnit& pu);
#endif
// Refines a single MV `mv` (in place) for reference list eRefList / refIdx and returns its template-matching cost.
// When otherMvf is non-null, the other list's motion is kept fixed and used for bi-prediction refinement.
Distortion deriveTMMv (const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf = nullptr);
这两个函数的代码及注释如下所示:
void InterPrediction::deriveTMMv(PredictionUnit& pu)
{
  // Merge-mode template matching: refine every MV the PU uses, then, for bi-prediction, re-refine the worse
  // list against the better one and possibly fall back to uni-prediction.
  if (!pu.tmMergeFlag)
  {
    return;
  }

  const Distortion invalidCost = std::numeric_limits<Distortion>::max();
  const bool       isBSlice    = pu.cu->slice->isInterB();
  const int        numLists    = isBSlice ? NUM_REF_PIC_LIST_01 : 1;

  // Pass 1: independent uni-directional refinement of each list the PU actually uses.
  Distortion uniCost[NUM_REF_PIC_LIST_01] = { invalidCost, invalidCost };
  for (int listIdx = 0; listIdx < numLists; listIdx++)
  {
    const bool listUsed = (pu.interDir & (listIdx + 1)) != 0;
    if (!listUsed)
    {
      continue;
    }
    uniCost[listIdx] = deriveTMMv(pu, true, invalidCost, (RefPicList)listIdx, pu.refIdx[listIdx], TM_MAX_NUM_OF_ITERATIONS, pu.mv[listIdx]);
  }

  // Pass 2 is only for bi-predicted PUs (and, with MULTI_PASS_DMVR, only when the BDMVR condition does not hold).
  bool tryBiRefine = isBSlice && pu.interDir == 3;
#if MULTI_PASS_DMVR
  tryBiRefine = tryBiRefine && !PU::checkBDMVRCondition(pu);
#endif
  if (!tryBiRefine)
  {
    return;
  }

  // Both uni refinements must have produced a cost (i.e. a template was available for each).
  if (uniCost[0] == invalidCost || uniCost[1] == invalidCost)
  {
    return;
  }

  // Keep the list with the lower (or equal L0) uni cost fixed, and re-refine the other list against it.
  const RefPicList fixedList  = (uniCost[0] <= uniCost[1]) ? REF_PIC_LIST_0 : REF_PIC_LIST_1;
  const RefPicList refineList = (fixedList == REF_PIC_LIST_0) ? REF_PIC_LIST_1 : REF_PIC_LIST_0;

  MvField fixedMvf(pu.mv[fixedList], pu.refIdx[fixedList]);
  const Distortion biCost = deriveTMMv(pu, true, invalidCost, refineList, pu.refIdx[refineList], TM_MAX_NUM_OF_ITERATIONS, pu.mv[refineList], &fixedMvf);

  // Fall back to uni-prediction from the fixed list when the bi cost exceeds roughly 9/8 of that list's uni cost.
  const Distortion uniThreshold = uniCost[fixedList] + (uniCost[fixedList] >> 3);
  if (biCost > uniThreshold)
  {
    pu.interDir           = 1 + fixedList;
    pu.mv    [refineList] = Mv();
    pu.refIdx[refineList] = NOT_VALID;
  }
}
#if TM_AMVP || TM_MRG
// Refines `mv` for reference list eRefList / refIdx by template matching and returns the resulting TM cost.
//   fillCurTpl      - when true, copy the current-block templates from the reconstruction into the TM buffers;
//                     when false, only template availability is determined (buffers are reused as-is).
//   curBestCost     - starting best cost; std::numeric_limits<Distortion>::max() makes the search evaluate the
//                     cost of the initial MV first.
//   maxSearchRounds - maximum number of search rounds; 0 means no search, only the cost of the initial MV.
//   mv              - in: start MV; out: refined MV (left untouched when no template is available).
//   otherMvf        - when non-null, bi-prediction refinement: the other list's motion is kept fixed,
//                     removeHighFreq() is applied with it before the search, and the returned cost is
//                     scaled by this list's BCW weight.
// Returns std::numeric_limits<Distortion>::max() when neither the above nor the left template is available.
Distortion InterPrediction::deriveTMMv(const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf)
{
  CHECK(refIdx < 0, "Invalid reference index for TM");

  const CodingUnit& cu     = *pu.cu;
  const Picture&    refPic = *cu.slice->getRefPic(eRefList, refIdx)->unscaledPic;

  // The other list's MV is handed to the similarity check only when both lists reference the same POC.
  bool doSimilarityCheck = otherMvf == nullptr ? false : cu.slice->getRefPOC((RefPicList)eRefList, refIdx) == cu.slice->getRefPOC((RefPicList)(1 - eRefList), otherMvf->refIdx);

  InterPredResources interRes(m_pcReshape, m_pcRdCost, m_if, m_filteredBlockTmp[0][COMPONENT_Y]
                              , m_filteredBlock[3][1][0], m_filteredBlock[3][0][0]
                              );

  // Constructing the controller fills the current templates and pre-interpolates the reference search area.
  TplMatchingCtrl tplCtrl(pu, interRes, refPic, fillCurTpl, COMPONENT_Y, true, maxSearchRounds, m_pcCurTplAbove, m_pcCurTplLeft, m_pcRefTplAbove, m_pcRefTplLeft, mv, (doSimilarityCheck ? &(otherMvf->mv) : nullptr), curBestCost);

  if (!tplCtrl.getTemplatePresentFlag())
  {
    // Neither the above nor the left template is available: TM cannot be applied.
    return std::numeric_limits<Distortion>::max();
  }

  if (otherMvf == nullptr)
  {
    // Uni-prediction: plain search, cost returned unscaled.
    tplCtrl.deriveMvUni<TM_TPL_SIZE>();
    mv = tplCtrl.getFinalMv();
    return tplCtrl.getMinCost();
  }

  // Bi-prediction: remove the fixed other-list contribution from the current template, then search this list.
  const Picture& otherRefPic = *cu.slice->getRefPic((RefPicList)(1 - eRefList), otherMvf->refIdx)->unscaledPic;
  const int8_t   bcwWeight   = getBcwWeight(cu.BcwIdx, eRefList);
  tplCtrl.removeHighFreq<TM_TPL_SIZE>(otherRefPic, otherMvf->mv, bcwWeight);
  tplCtrl.deriveMvUni<TM_TPL_SIZE>();
  mv = tplCtrl.getFinalMv();

  // Scale the cost by this list's BCW weight and normalise by the BCW weight base, rounding to nearest.
  // g_BcwWeightBase is the linear base (its half is the rounding offset), so the normalising shift must be its
  // log2, g_BcwLog2WeightBase. Shifting by g_BcwWeightBase itself would over-divide the cost.
  return (tplCtrl.getMinCost() * bcwWeight + (g_BcwWeightBase >> 1)) >> g_BcwLog2WeightBase;
}
TM过程中模板的获取以及搜索过程都是通过TplMatchingCtrl类控制的,代码如下所示:
// Controls template matching (TM) for one PU against one reference picture: owns views of the current and
// reference templates, the pre-interpolated search area, and the running best MV / cost.
class TplMatchingCtrl
{
  // Search patterns used by xRefineMvSearch().
  enum TMSearchMethod
  {
    TMSEARCH_DIAMOND,
    TMSEARCH_CROSS,
    TMSEARCH_NUMBER_OF_METHODS
  };

  // NOTE: declaration order matters - the constructor's member-initializer list follows it.
  const CodingUnit&     m_cu;
  const PredictionUnit& m_pu;
  InterPredResources&   m_interRes;        // resources borrowed from InterPrediction (reshaper, filters, buffers)
  const Picture&        m_refPic;          // reference picture being searched
  const Mv              m_mvStart;         // initial MV (centre of the search)
  Mv                    m_mvFinal;         // refined MV, initialised to m_mvStart
  const Mv*             m_otherRefListMv;  // other list's MV for the similarity check, or nullptr
  Distortion            m_minCost;         // minimum TM cost, initialised to the caller's curBestCost
  bool                  m_useWeight;
  int                   m_maxSearchRounds; // maximum diamond search rounds; <= 0 disables the search
  ComponentID           m_compID;
  PelBuf                m_curTplAbove;     // current-block above template (empty PelBuf if unavailable)
  PelBuf                m_curTplLeft;      // current-block left template (empty PelBuf if unavailable)
  PelBuf                m_refTplAbove;     // caller-provided buffer for the above reference template
  PelBuf                m_refTplLeft;      // caller-provided buffer for the left reference template
  PelBuf                m_refSrAbove;      // pre-filled samples on search area
  PelBuf                m_refSrLeft;       // pre-filled samples on search area
#if JVET_X0056_DMVD_EARLY_TERMINATION
  Distortion            m_earlyTerminateTh; // set to the number of available template samples
#endif
#if MULTI_PASS_DMVR
  Distortion            m_tmCostArrayDiamond[9]; // presumably per-position costs of the diamond pattern (cost-based MV derivation)
  Distortion            m_tmCostArrayCross[5];   // presumably per-position costs of the cross pattern
#endif

public:
  // Fills the current templates (optionally), wraps the template buffers and pre-interpolates the search area.
  TplMatchingCtrl(const PredictionUnit& pu,
                  InterPredResources& interRes, // Bridge required resource from InterPrediction
                  const Picture& refPic,
                  const bool fillCurTpl,
                  const ComponentID compID,
                  const bool useWeight,
                  const int maxSearchRounds,
                  Pel* curTplAbove,
                  Pel* curTplLeft,
                  Pel* refTplAbove,
                  Pel* refTplLeft,
                  const Mv& mvStart,
                  const Mv* otherRefListMv,
                  const Distortion curBestCost
  );

  // True when at least one of the above / left templates is available.
  bool       getTemplatePresentFlag() const { return m_curTplAbove.buf != nullptr || m_curTplLeft.buf != nullptr; }
  // Minimum TM cost found.
  Distortion getMinCost            () const { return m_minCost; }
  // Refined MV.
  Mv         getFinalMv            () const { return m_mvFinal; }
  static int getDeltaMean          (const PelBuf& bufCur, const PelBuf& bufRef, const int rowSubShift, const int bd);
  // Uni-directional MV refinement (also used for bi-prediction after removeHighFreq()).
  template <int tplSize> void deriveMvUni   ();
  template <int tplSize> void removeHighFreq(const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight);

private:
  template <int tplSize, bool TrueA_FalseL>          bool       xFillCurTemplate  (Pel* tpl);
  template <int tplSize, bool TrueA_FalseL, int sr>  PelBuf     xGetRefTemplate   (const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf);
  template <int tplSize, bool TrueA_FalseL>          void       xRemoveHighFreq   (const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight);
  template <int tplSize, int searchPattern>          void       xRefineMvSearch   (int maxSearchRounds, int searchStepShift);
#if MULTI_PASS_DMVR
  template <int searchPattern>                       void       xNextTmCostAarray (int bestDirect);
  template <int searchPattern>                       void       xDeriveCostBasedMv();
  template <bool TrueX_FalseY>                       void       xDeriveCostBasedOffset(Distortion costLorA, Distortion costCenter, Distortion costRorB, int log2StepSize);
                                                     int        xBinaryDivision   (int64_t numerator, int64_t denominator, int fracBits);
#endif
  template <int tplSize>                             Distortion xGetTempMatchError(const Mv& mv);
  template <int tplSize, bool TrueA_FalseL>          Distortion xGetTempMatchError(const Mv& mv);
};
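TM 模式中,TplMatchingCtrl 类的构造函数调用 xFillCurTemplate 函数获取当前模板。它还调用 xGetRefTemplate 函数,对初始 MV 周围整个搜索窗口(模板向四周各扩展 TM_SEARCH_RANGE 个像素)的参考样本进行预插值。搜索过程中,各候选 MV 对应的参考模板再由 xGetRefTemplate 获得:位于预插值范围内的整像素偏移直接从预插值缓冲区读取,其余情况则现场插值。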
#if TM_AMVP || TM_MRG
// Sets up template matching of `pu` against `refPic`, starting from mvStart:
//   1. determines which current templates (above / left) are available and, if fillCurTpl is set, copies their
//      reconstructed samples into curTplAbove / curTplLeft;
//   2. wraps the caller-provided current and reference template buffers (empty PelBuf for an unavailable side);
//   3. when a search will run (maxSearchRounds > 0), pre-interpolates the reference samples of the whole search
//      window (template grown by TM_SEARCH_RANGE on each side) into interRes.m_preFillBufA / m_preFillBufL and
//      keeps a view positioned at the template location of mvStart.
TplMatchingCtrl::TplMatchingCtrl(const PredictionUnit& pu,
                                 InterPredResources&   interRes,
                                 const Picture&        refPic,
                                 const bool            fillCurTpl,
                                 const ComponentID     compID,
                                 const bool            useWeight,
                                 const int             maxSearchRounds,
                                 Pel*                  curTplAbove,
                                 Pel*                  curTplLeft,
                                 Pel*                  refTplAbove,
                                 Pel*                  refTplLeft,
                                 const Mv&             mvStart,
                                 const Mv*             otherRefListMv,
                                 const Distortion      curBestCost)
  : m_cu             (*pu.cu)
  , m_pu             (pu)
  , m_interRes       (interRes)
  , m_refPic         (refPic)
  , m_mvStart        (mvStart)
  , m_mvFinal        (mvStart)
  , m_otherRefListMv (otherRefListMv)
  , m_minCost        (curBestCost)
  , m_useWeight      (useWeight)
  , m_maxSearchRounds(maxSearchRounds)
  , m_compID         (compID)
{
  // 1. Current templates: availability is always determined; samples are copied only when fillCurTpl is set.
  const bool hasAbove = xFillCurTemplate<TM_TPL_SIZE, true >(fillCurTpl ? curTplAbove : nullptr);
  const bool hasLeft  = xFillCurTemplate<TM_TPL_SIZE, false>(fillCurTpl ? curTplLeft  : nullptr);

  // 2. Template buffers; each reference template buffer takes the geometry of its current template.
  if (hasAbove)
  {
    m_curTplAbove = PelBuf(curTplAbove, pu.lwidth(), TM_TPL_SIZE);
    m_refTplAbove = PelBuf(refTplAbove, m_curTplAbove);
  }
  else
  {
    m_curTplAbove = PelBuf();
    m_refTplAbove = PelBuf();
  }

  if (hasLeft)
  {
    m_curTplLeft = PelBuf(curTplLeft, TM_TPL_SIZE, pu.lheight());
    m_refTplLeft = PelBuf(refTplLeft, m_curTplLeft);
  }
  else
  {
    m_curTplLeft = PelBuf();
    m_refTplLeft = PelBuf();
  }

#if JVET_X0056_DMVD_EARLY_TERMINATION
  // Threshold equals the total number of available template samples.
  const auto tplExtent = (hasAbove ? m_pu.lwidth() : 0) + (hasLeft ? m_pu.lheight() : 0);
  m_earlyTerminateTh   = TM_TPL_SIZE * tplExtent;
#endif

  // 3. Pre-interpolated search windows, only needed when a search will actually run.
  const bool willSearch = maxSearchRounds > 0;

  if (hasAbove && willSearch)
  {
    m_refSrAbove = PelBuf(interRes.m_preFillBufA, m_curTplAbove.width + 2 * TM_SEARCH_RANGE, m_curTplAbove.height + 2 * TM_SEARCH_RANGE);
    if (m_refSrAbove.buf != nullptr)
    {
      m_refSrAbove = xGetRefTemplate<TM_TPL_SIZE, true, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrAbove);
      // Move the view to the zero-displacement position (the template location of mvStart).
      m_refSrAbove = m_refSrAbove.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplAbove);
    }
  }
  else
  {
    m_refSrAbove = PelBuf();
  }

  if (hasLeft && willSearch)
  {
    m_refSrLeft = PelBuf(interRes.m_preFillBufL, m_curTplLeft.width + 2 * TM_SEARCH_RANGE, m_curTplLeft.height + 2 * TM_SEARCH_RANGE);
    if (m_refSrLeft.buf != nullptr)
    {
      m_refSrLeft = xGetRefTemplate<TM_TPL_SIZE, false, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrLeft);
      m_refSrLeft = m_refSrLeft.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplLeft);
    }
  }
  else
  {
    m_refSrLeft = PelBuf();
  }
}
xFillCurTemplate函数获取当前模板:
// Determines whether the above (TrueA_FalseL == true) or left (false) template of the current block is available
// and, when `tpl` is non-null, copies its reconstructed samples into `tpl` (row-major, stride = template width),
// mapped through the reshaper's inverse LUT when LMCS is enabled for this picture and CTU.
// A side is unavailable when getCU() finds no CU at the neighbouring position, or, for GEO TM, when the GEO
// template shape excludes that side.
template <int tplSize, bool TrueA_FalseL>
bool TplMatchingCtrl::xFillCurTemplate(Pel* tpl)
{
  const Position          posOffset = TrueA_FalseL ? Position(0, -tplSize) : Position(-tplSize, 0);
  const CodingUnit* const cuNeigh   = m_cu.cs->getCU(m_pu.blocks[m_compID].pos().offset(posOffset), toChannelType(m_compID));
  if (cuNeigh == nullptr)
  {
    return false;
  }

#if JVET_W0097_GPM_MMVD_TM && TM_MRG
  // GEO TM may restrict matching to a single side: GEO_TM_SHAPE_A keeps only the above template and
  // GEO_TM_SHAPE_L only the left one. This is evaluated before the `tpl == nullptr` early-out so that the
  // reported availability does not depend on whether the caller asked for the samples to be (re)filled.
  if (m_cu.geoFlag)
  {
    CHECK(m_pu.geoTmType == GEO_TM_OFF, "invalid geo template type value");
    const bool sideDisabled = (m_pu.geoTmType == GEO_TM_SHAPE_A && !TrueA_FalseL)
                           || (m_pu.geoTmType == GEO_TM_SHAPE_L &&  TrueA_FalseL);
    if (sideDisabled)
    {
      return false;
    }
  }
#endif

  if (tpl == nullptr)
  {
    // Availability query only; the caller reuses previously filled samples.
    return true;
  }

  const Picture&          currPic = *m_cu.cs->picture;
  const CPelBuf           recBuf  = currPic.getRecoBuf(currPic.blocks[m_compID]);
  const std::vector<Pel>& invLUT  = m_interRes.m_pcReshape->getInvLUT();
  const bool              useLUT  = isLuma(m_compID) && m_cu.cs->picHeader->getLmcsEnabledFlag() && m_interRes.m_pcReshape->getCTUFlag();

  const Size dstSize = TrueA_FalseL ? Size(m_pu.lwidth(), tplSize) : Size(tplSize, m_pu.lheight());
  for (int h = 0; h < (int)dstSize.height; h++)
  {
    const Position recPos = TrueA_FalseL ? Position(0, -tplSize + h) : Position(-tplSize, h);
    const Pel*     rec    = recBuf.bufAt(m_pu.blocks[m_compID].pos().offset(recPos));
    Pel*           dst    = tpl + h * dstSize.width;
    for (int w = 0; w < (int)dstSize.width; w++)
    {
      const int recVal = rec[w];
      dst[w] = useLUT ? invLUT[recVal] : recVal;
    }
  }
  return true;
}
xGetRefTemplate函数获取参考模板:
// Returns the reference template (above if TrueA_FalseL, else left) of the current PU displaced by `_mv`.
// sr: margin, in integer samples, around the template. The interpolated region starts sr samples above/left of the
//     displaced template and has dstBuf's size (the constructor passes template size + 2 * sr with sr = TM_SEARCH_RANGE
//     to pre-interpolate the whole search window).
// Fast path (sr == 0): if refPic has the same POC as this context's reference picture, the pre-interpolated buffer
// exists, and _mv differs from m_mvStart by an integer-sample offset within +-TM_SEARCH_RANGE, a sub-buffer of the
// pre-interpolated search area is returned (dstBuf only supplies the size; nothing is written to it).
// Otherwise the samples are interpolated into dstBuf and dstBuf is returned.
template <int tplSize, bool TrueA_FalseL, int sr>
PelBuf TplMatchingCtrl::xGetRefTemplate(const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf)
{
// Fast path: read from the pre-interpolated search-area buffer.
PelBuf& refSrBuf = TrueA_FalseL ? m_refSrAbove : m_refSrLeft;
// Only possible for sr == 0 and the same reference picture (compared by POC).
if (sr == 0 && refPic.getPOC() == m_refPic.getPOC() && refSrBuf.buf != nullptr)
{
Mv mvDiff = _mv - m_mvStart;
// The offset from the start MV must have no fractional part (MV_FRACTIONAL_BITS_INTERNAL fractional bits) ...
if ((mvDiff.getAbsHor() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0 && (mvDiff.getAbsVer() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0)
{
mvDiff >>= MV_FRACTIONAL_BITS_INTERNAL;
// ... and must lie inside the pre-interpolated window.
if (mvDiff.getAbsHor() <= TM_SEARCH_RANGE && mvDiff.getAbsVer() <= TM_SEARCH_RANGE)
{
// refSrBuf already points at the zero-offset position, so the integer offset indexes it directly.
return refSrBuf.subBuf(Position(mvDiff.getHor(), mvDiff.getVer()), dstBuf);
}
}
}
// Slow path: interpolate on the fly.
Position blkPos = ( TrueA_FalseL ? Position(curPu.lx(), curPu.ly() - tplSize) : Position(curPu.lx() - tplSize, curPu.ly()) );
Size blkSize = Size(dstBuf.width, dstBuf.height);
// Shift the MV up/left by sr integer samples so that dstBuf also covers the margin around the template.
Mv mv = _mv - Mv(sr << MV_FRACTIONAL_BITS_INTERNAL, sr << MV_FRACTIONAL_BITS_INTERNAL);
clipMv( mv, blkPos, blkSize, *m_cu.cs->sps, *m_cu.cs->pps );
// Split the MV into integer and fractional parts for this component (chroma additionally scaled by the format).
const int lumaShift = 2 + MV_FRACTIONAL_BITS_DIFF;
const int horShift = (lumaShift + ::getComponentScaleX(m_compID, m_cu.chromaFormat));
const int verShift = (lumaShift + ::getComponentScaleY(m_compID, m_cu.chromaFormat));
const int xInt = mv.getHor() >> horShift;
const int yInt = mv.getVer() >> verShift;
const int xFrac = mv.getHor() & ((1 << horShift) - 1);
const int yFrac = mv.getVer() & ((1 << verShift) - 1);
const CPelBuf refBuf = refPic.getRecoBuf(refPic.blocks[m_compID]);
const Pel* ref = refBuf.bufAt(blkPos.offset(xInt, yInt));
Pel* dst = dstBuf.buf;
int refStride = refBuf.stride;
int dstStride = dstBuf.stride;
int bw = (int)blkSize.width;
int bh = (int)blkSize.height;
// Filter index 1 (presumably the bilinear filter, cf. NTAPS_BILINEAR below); no alt half-pel filter, no DMVR mode.
const int nFilterIdx = 1;
const bool useAltHpelIf = false;
const bool biMCForDMVR = false;
if ( yFrac == 0 )
{
// Horizontal-only interpolation.
m_interRes.m_if.filterHor( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, xFrac, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
}
else if ( xFrac == 0 )
{
// Vertical-only interpolation.
m_interRes.m_if.filterVer( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, yFrac, true, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
}
else
{
// Separable 2-D interpolation: horizontal pass into a temporary buffer with (vFilterSize - 1) extra rows,
// then vertical pass from that buffer into dst.
const int vFilterSize = isLuma(m_compID) ? NTAPS_BILINEAR : NTAPS_CHROMA;
PelBuf tmpBuf = PelBuf(m_interRes.m_ifBuf, Size(bw, bh+vFilterSize-1));
m_interRes.m_if.filterHor( m_compID, (Pel*)ref - ((vFilterSize>>1) -1)*refStride, refStride, tmpBuf.buf, tmpBuf.stride, bw, bh+vFilterSize-1, xFrac, false, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
JVET_J0090_SET_CACHE_ENABLE( false );
m_interRes.m_if.filterVer( m_compID, tmpBuf.buf + ((vFilterSize>>1) -1)*tmpBuf.stride, tmpBuf.stride, dst, dstStride, bw, bh, yFrac, false, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
JVET_J0090_SET_CACHE_ENABLE( true );
}
return dstBuf;
}
在deriveMvUni函数中进行单向MV的细化:
template <int tplSize>
void TplMatchingCtrl::deriveMvUni()
{
  // Uni-directional TM refinement of m_mvFinal / m_minCost, starting from m_mvStart.

  // Without a caller-supplied cost, evaluate the starting MV first.
  const bool costUnknown = (m_minCost == std::numeric_limits<Distortion>::max());
  if (costUnknown)
  {
    m_minCost = xGetTempMatchError<tplSize>(m_mvStart);
  }

  // m_maxSearchRounds <= 0 means "cost evaluation only".
  if (m_maxSearchRounds <= 0)
  {
    return;
  }

  // Coarsest step shift: MV_FRACTIONAL_BITS_INTERNAL (presumably 1 integer sample) plus 2 more for 4-pel AMVR.
  const int coarseShift = MV_FRACTIONAL_BITS_INTERNAL + (m_cu.imv == IMV_4PEL ? 2 : 0);

  // Iterative diamond search, then single-round cross refinements at the coarse step and at half of it.
  // NOTE(review): the finer cross rounds are issued regardless of m_cu.imv; whether xRefineMvSearch() keeps the
  // result at the AMVR precision is not visible here - confirm.
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_DIAMOND>(m_maxSearchRounds, coarseShift);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(1, coarseShift);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(1, coarseShift - 1);

#if MULTI_PASS_DMVR
  if (m_pu.bdmvrRefine)
  {
    // With BDMVR refinement, the two finest cross rounds are replaced by a cost-based MV derivation.
    xDeriveCostBasedMv<TplMatchingCtrl::TMSEARCH_CROSS>();
    return;
  }
#endif

  // Two further single-round cross searches, each halving the step again.
  for (int extraHalving = 2; extraHalving <= 3; extraHalving++)
  {
    xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS>(1, coarseShift - extraHalving);
  }
}