题解 P5282 【模板】快速阶乘算法

传送门

总算是学会了这个算法......


【前置芝士】

  1. 多项式乘法
  2. 任意模数多项式乘法
  3. 多项式连续点值平移

前两个用于处理任意模数意义下的多项式乘法;

第三个用于在未知一个不超过 \((r-l)\) 次的多项式具体形式,但已知其在某连续区间 \([l,r]\) 的 \((r-l+1)\) 个点值时,求出任一与之不交的、长度相同的区间的 \((r-l+1)\) 个点的值。复杂度为 \(O(n\log n)\)。


【分析】

考虑令 \(s=\lfloor\sqrt n\rfloor\) ,则 \(\displaystyle n!=\prod_{i=1}^n i=\prod_{i=0}^{s-1}\prod_{t=1}^s(is+t)\cdot \prod_{i=s^2+1}^n i\) 。

若已求出 \(\displaystyle \prod_{i=0}^{s-1}\prod_{t=1}^s(is+t)\),则后面的那个式子直接暴力点乘,复杂度为 \(o(\sqrt n)\) 的。

有个比较简单的方法是由分治 FFT 或多项式启发式合并求出 \(s\) 次多项式 \(\displaystyle g_s(x)=\prod_{t=1}^s (x+t)\) ;再由多项式多点求值求出 \(g_s(0\cdot s),g_s(1\cdot s),g_s(2\cdot s),\cdots, g_s((s-1)\cdot s)\) 。这些乘起来即为待求项。

这样一来,复杂度为 \(O(s\log^2 s)+O(s\log^2 s)=O(\sqrt n\log^2 n)\) 。


但考虑到最后实际对答案的贡献仅在若干个点处取到。若这些点所处的区间连续,则可能可用多项式连续点值平移求出答案。

很幸运的是,我们令 \(h_s(x)=g_s(sx)\),则待求的点值变换为了 \(h_s(0),h_s(1),h_s(2),\cdots, h_s(s-1)\) ,是长度为 \(s\) 的连续区间。这就可以用到多项式连续点值平移了。

现在,我们可以考虑如何使用多项式连续点值平移凑出这些东西了。

而初始我们也并不知道 \(h_s(x)\) 的任何一个点值,但是我们知道 \(h_1(x)=x+1\) 的所有点值。

因此,我们考虑如何倍增得到 \(h_s(x)\) 的那些点值。

当我们已知 \(h_d(0),h_d(1),\cdots,h_d(d)\) 时,若我们想要得到 \(h_{2d}(0),h_{2d}(1),\cdots,h_{2d}(2d)\)。

对于 \(\displaystyle \forall i\leq d\to h_{2d}(i)=g_{2d}(is)=\prod_{t=1}^{2d}(is+t)=g_d(is)\cdot g_d(is+d)\) 。

由于 \(g_d(is+d)\) 在模 \(P\) 意义下,与 \(h_d(d\cdot s^{-1}+i)\) 等价,因此 \(h_{2d}(i)=h_d(i)\cdot h_d(i+d\cdot s^{-1})\) 。

求解这些点值时,我们只要将 \(h_d(0),h_d(1),\cdots,h_d(d)\) 这些连续点值平移出 \(h_{2d}(0+d\cdot s^{-1}),h_d(1+d\cdot s^{-1}),\cdots, h_d(d+d\cdot s^{-1})\) ,直接两两点乘即可倍增。

而对于后面的那几个点,如果我们已知 \(h_d(d+1),h_d(d+2),\cdots, h_d(2d)\) ,就也能通过这个方法求解。

那怎么知道呢?问就是多项式连续点值平移,用 \(h_d(0),h_d(1),\cdots,h_d(d)\) 直接平移出 \(h_d(d+1),h_d(d+2),\cdots,h_d(2d+1)\) 就可以了(这里因为要求区间不交,会额外平移出 \(h_d(2d+1)\))。

既然这样,那也别客气了,分那么清楚, 第一次先平移出 \(h_d(d+1),h_d(d+2),\cdots, h_d(2d+1)\) ;拼接到原点值后面之后,第二次再直接用 \(h_d(0),h_d(1),\cdots,h_d(2d+1)\) 直接平移出 \(h_d(0+d\cdot s^{-1}),h_d(1+d\cdot s^{-1}),\cdots,h_d(2d+1+d\cdot s^{-1})\) 。点乘则可以得到我们要求的 \(h_{2d}(0),h_{2d}(1),\cdots,h_{2d}(2d)\) 。复杂度为 \(O(d\log d)+O(2d\log 2d)+O(2d)=O(d\log d)\) 。


由于我们可以从 \(h_d\) 的状态推出 \(h_{2d}\) 的状态,我们又已知 \(h_1\) 的状态,待求 \(h_s\) 的状态;于是我们去考虑用类似递归快速幂的迭代方法求出结果。

那还差的步骤就是由 \(h_d\) 的状态推出 \(h_{d+1}\) 的状态。

这个简单, 由于 \(\displaystyle h_{d+1}(i)=g_{d+1}(is)=\prod_{t=1}^{d+1}(is+t)=g_d(is)\cdot (is+d+1)=h_d(i)\cdot (is+d+1)\) ,我们直接 \(O(d)\) 暴力跑过去就行了。

但是由 \(h_d\) 推 \(h_{d+1}\) 时,需要额外知道 \(h_{d+1}(d+1)\) 的信息,它需要用到 \(h_d(d+1)\) 的信息。

但这个无伤大雅,因为 \(h_d\) 推出 \(h_{d+1}\) 的状态一定发生在 \(h_{d\over 2}\) 推出 \(h_d\) 之后。而像上面说的,正好由于区间不交,恰好在 \(h_d\) 推出 \(h_{2d}\) 时,把 \(h_{2d}(2d+1)\) 的给顺便算了。

那就好办了,如果接着要 \(h_{2d}\) 推 \(h_{2d+1}\) ,就把这个正好算了;如果不要,就顺手丢了。于是这一步的复杂度是严格 \(O(d)\) 的。


综上,我们由已知的 \(h_1\) 状态,倍增地推出 \(h_s\) 的状态。最后累乘 \(h_s(0),\cdots, h_s(s-1)\) ,再乘上那些尾巴,就可以算出总答案了。

复杂度 \(\displaystyle T(s)=T({s\over 2})+O(s\log s)+O(s)\) 得到 \(T(s)=O(s\log s)=O(\sqrt n\log n)\) 。(原本还要加上一个 \(o(\sqrt n)\) ,但是作为低阶项舍去了。)


【核心代码】

poly c, d;
inline int get_fac(int n) {
	int pos=curStk, s=sqrt(n)+1e-6;
	vir invs=Inv[s];
	for(int i=s;i>1;i>>=1) Stk[++curStk]=i;//手动压栈

	c.resize(2); c[0]=1; c[1]=s+1;//初始化 h_1 的状态
	for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
		ValueTrans(c, d, l>>1, (l>>1)+1);//点值平移出额外状态
		c.resize(2*sz(c));
		for(int i=0; i<sz(d); ++i) c[sz(d)+i]=d[i];
		ValueTrans(c, d, sz(c)-1, invs*vir(l>>1));//点值平移出点乘状态
		for(int i=0; i<sz(c); ++i) c[i]=c[i]*d[i];
		if(l&1) {
			for(int i=0; i<=l; ++i) c[i]=c[i]*vir(i*s+l);//h_l 转为 h_{l+1}
		}
		else c.resize(l+1);
	}

	vir res=1;
	for(int i=0; i<s; ++i) res=res*c[i];//累乘
	for(int i=s*s+1; i<=n; ++i) res=res*vir(i);//处理尾巴
	return res;
}

【拓展】

你以为这就完了?

我发现我的 代码 跑得飞快,看到连时限 4s 的十分之一都没跑到,我不禁“恶向胆边生”。用这个板子改了改,测试了一下阶乘求和的结果。

做法是求出 \(\displaystyle \sum_{i=1}^n i!\) 和 \(\displaystyle \sum_{i=1}^{n-1} i!\),做差算出答案。

那这个问题怎么用点值平移处理呢?

我们记 \(\displaystyle S_n=\sum_{i=1}^n i!\) ,考虑它的转移矩阵:

\(\begin{aligned} \begin{pmatrix} (n+1)! \\S_n \end{pmatrix}&= \begin{pmatrix} n+1&0 \\1&1 \end{pmatrix}\cdot \begin{pmatrix} n! \\S_{n-1} \end{pmatrix} \\&=\prod_{i=1}^n \begin{pmatrix} i+1&0 \\1&1 \end{pmatrix}\cdot \begin{pmatrix} 1 \\0 \end{pmatrix} \end{aligned}\)

同理,我们考虑 \(\displaystyle g_d(x)=\prod_{i=1}^d\begin{pmatrix} x+i+1&0 \\1&1 \end{pmatrix},h_d(i)=g_d(is)\) 。

手玩一下会发现,这个矩阵这个矩阵的乘法也有特点:

\(\begin{pmatrix} a_1&\ \\b_1&c_1 \end{pmatrix}\cdot\begin{pmatrix} a_2&\ \\b_2&c_2 \end{pmatrix}=\begin{pmatrix} a_1a_2&\ \\b_1a_2+c_1b_2&c_1c_2 \end{pmatrix}\)

也就是说,这个矩阵的乘积永远只有三个地方有值,并且右下角恒为 \(1\) 。于是我们维护 \(h_d(x)=\begin{pmatrix}mc_{d,0}(x)&\ \\mc_{d,1}(x)&1\end{pmatrix}\) 。

由原来维护一个点值变成了维护两个点值 而已

由于此时为矩阵乘法 \(h_{2d}(i)=h_d(i+d\cdot s^{-1})\cdot h_d(i)\) 形式不会发生变化,但一定要注意矩阵的左右乘不等价。

发生变化的是 \(h_d\) 转移到 \(h_{d+1}\) 时,此时有:

\(\begin{aligned} &h_{d+1}(i) \\=&\begin{pmatrix}mc_{d+1,0}(i)&\ \\mc_{d+1,1}(i)&1\end{pmatrix} \\=&\begin{pmatrix}i+d+2&0\\1&1\end{pmatrix}h_d(i) \\=&\begin{pmatrix}i+d+2&0\\1&1\end{pmatrix}\begin{pmatrix}mc_{d,0}(x)&\ \\mc_{d,1}(x)&1\end{pmatrix} \\=&\begin{pmatrix}(i+d+2)mc_{d,0}(x)&\ \\mc_{d,0}+mc_{d,1}(x)&1\end{pmatrix} \end{aligned}\)

修改此处即可。

poly mc[2], md[2];
inline int get_sumfac(int n) {
	int pos=curStk, s=sqrt(n)+1e-6;
	vir invs=Inv[s];
	for(int i=s;i>1;i>>=1) Stk[++curStk]=i;

	mc[0].resize(2); mc[0][0]=2; mc[0][1]=s+2;
	mc[1].resize(2); mc[1][0]=mc[1][1]=1;
	for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
		for(int t=0; t<2; ++t){
			poly &c = mc[t], &d = md[t];
			ValueTrans(c, d, l>>1, (l>>1)+1);
			c.resize(2*sz(c));
			for(int i=0; i<sz(d); ++i) c[sz(d)+i]=d[i];
			ValueTrans(c, d, sz(c)-1, invs*vir(l>>1));
		}
		for(int i=0; i<sz(mc[0]); ++i) {
			mc[1][i]=mc[0][i]*md[1][i]+mc[1][i];
			mc[0][i]=mc[0][i]*md[0][i];
		}
		if(l&1) {
			for(int i=0; i<=l; ++i) {
				mc[1][i]=mc[0][i]+mc[1][i];
				mc[0][i]=mc[0][i]*vir(i*s+l+1);
			}
		}
		else {
			mc[0].resize(l+1);
			mc[1].resize(l+1);
		}
	}

	vir res0=1, res1=0;
	for(int i=0; i<s; ++i){
		res1=res0*mc[1][i]+res1;
		res0=res0*mc[0][i];
	}
	for(int i=s*s+1; i<=n; ++i){
		res1=res1+res0;
		res0=res0*vir(i+1);
	}
	return res1;
}

结果 还是没跑到一半的时间。

题解 P5282 【模板】快速阶乘算法

上一篇:1626B


下一篇:Mybatis中常用的SQL