Declare
- 下文中,记弹 \(1\to i\) 的期望弹奏次数为 \(dp_i\),且 \(dp_1=n\)。
- 记音符串为 \(a_i\)。
Analysis
首先读懂题意,我们发现,对于当前正在弹奏的一段乐曲,有以下两种情况:
- 当前段的某一后缀与前缀相同。
- 当前段的任何后缀都与前缀不同。(当前串是“独立”的串)
第二种情况比较简单,对于第二种情况,那么有:
\[dp_i=dp_{i-1}+\frac{1}{n}+\frac{1}{n}\times(dp_i-dp_1+1)+\frac{n-2}{n}\times(dp_i+1) \] \[\Rightarrow dp_i=n\times dp_{i-1} \]其中,第一个 \(\frac{1}{n}\) 表示第 \(i\) 个音符直接弹奏正确的期望,\(\frac{1}{n}\times(dp_i-dp_1+1)\) 表示这个音符弹奏错误,但与第一个音符相同的期望,\(\frac{n-2}{n}\times(dp_i+1)\) 表示这个音符弹奏错误,同时与第一个音符不同的期望。
样例三对应的就是这种情况。
3
3
1 2 3
3
9
27
接下来讨论第一种情况。
前后缀显然可以用 KMP 预处理处理。
设已经弹奏了 \(1\to i\),现在要弹奏第 \(i+1\) 个音符。枚举当前弹奏音符 \(j\)。弹错时,需要找到最长的已经弹对的前缀,记为 \(k_j\),于是有:
\[dp_{i+1}=dp_i+\frac{1}{n}+\frac{1}{n}\times\sum\limits_{j=1}^n(dp_{i+1}-dp_{k_j}+1)(j\neq a_{i+1}) \] \[\Rightarrow dp_{i+1}=dp_i+n+\sum\limits_{j=1}^n(dp_{i+1}-dp_{k_j})(j\neq a_{i+1}) \]其中 \(\frac{1}{n}\times\sum\limits_{j=1}^n(dp_{i+1}-dp_{k_j}+1)(j\neq a_{i+1})\) 表示弹错了,但已经弹对前缀 \(1\to k\),接着从 \(k+1\) 弹到 \(n+1\) 的期望。两层循环跑一下就完了。时间复杂度 \(\mathcal{O}(n\times m)\),实际可过。
dp[1] = n;
F(i, 1, m) {//弹第 i 个音符
long long s = 0;
F(j, 1, n){//枚举弹的音符
if (j != a[i + 1]) {
k = fail[i];
while (k && a[k + 1] != j) k = fail[k];
if (a[k + 1] == j) ++k;
s += dp[i] - dp[k];
}
}
dp[i + 1] = (s + n + dp[i]) % mod;
}
虽然可过,但 \(n\times m=10^8\) 还是蛮卡的。能不能更快呢?事实上,计算 \(dp\) 可以优化到 \(\mathcal{O}(n)\) 的复杂度。
//讲解有待更新
code
#include <bits/stdc++.h>
#define reg register
#define ll long long
#define _min(x, y) ((x) < (y) ? (x) : (y))
#define _max(x, y) ((x) > (y) ? (x) : (y))
#define Min(x, y) ((x) > (y) and ((x) = (y)))
#define Max(x, y) ((x) < (y) and ((x) = (y)))
#define F(i, a, b) for (reg int i = (a); i <= (b); ++i)
#define PF(i, a, b) for (reg int i = (a); i >= (b); --i)
#define For(i, x) for (reg int i = head[(x)]; i; i = net[(i)])
using namespace std;
bool beginning;
inline int read();
const int N = 1e6 + 5, mod = 1e9 + 7;
int n, m, a[N], fail[N];
inline void init() {
int j = 0;
F(i, 2, m) {
while (j and a[j + 1] != a[i]) j = fail[j];
j += (a[j + 1] == a[i]);
fail[i] = j;
}
}
bool ending;
int main() {
// printf("%.2lfMB\n",1.0*(&beginning-&ending)/1024/1024);
n = read(), m = read();
F(i, 1, m) {
a[i] = read();
}
init();
int t = n;
F(i, 1, m) {
a[i] = a[fail[i]] + t;
if (a[i] >= mod) a[i] -= mod;
t = 1ll * t * n % mod;
printf("%d\n", a[i]);
}
return 0;
}
inline int read() {
reg int x = 0;
reg bool f = 1;
reg char c = getchar();
while (!isdigit(c)) {
f = c ^ 45;
c = getchar();
}
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return f ? x : -x;
}