题意:
给定一个长度为\(m\)的序列\(a\),问\([0,L]\)中有多少个数\(x\)满足\(popcount(x+i)\mod 2 = a_i\)。
\(L\le 10^{18}\)
\(m\le 100\)
题解:
想一会儿之后会注意到\(m\le 100\)。
显然如果涉及到加法的popcount很难计算,可以想办法转化为二进制下的构造。
我们把最后7位单独拎出来,这样前面的位就跟最后一坨东西关系变小了不少。
我们把最后尾巴上的一点特例给暴力算了之后,\(l\)的最后7位就是0了,然后我们只需要考虑每个0到128的变化过程。
这时候,由于\(128>100\),所以在一段中,每个数往后\(m\)个数,最多只会跨过一段,这一段跨过造成的变化只有:对前面进位,造成1的个数的改变。
这个改变只有4种:0变成0,0变成1,1变成0,1变成1.
可以算出在每种改变下,这一段中的满足的数的个数。
于是我们只需要计算每种改变有多少个即可。
这玩意就可以通过一些数位dp的技巧求了。
首先我们预处理出\([0,2^k)\)(这是去掉后7位以后)中每种变化有多少个。
这个很好dp:
\(f[i][0/1][0/1]\)表示\([0, 2^i)\)中每种变化有多少个。
还有一个额外的转移:
\[f[i][(i - 1) \& 1][1] += 1 \]然后计算\([0,L)\)的变化个数,就从高位开始,对于每个\(1\),取0的时候,就加上对应位的变化种数,此时还要注意还要异或上前面1的数量的奇偶性。
由于上面dp是不包含\(2^i\)次的,还需要再额外转移一下区间末尾的情况。
代码:
#include <bits/stdc++.h>
#define get(x) (popcount(x) & 1)
#define int long long
#define pt(x) cout << x << endl;
#define Mid ((l + r) / 2)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
const int N = 109;
int m, l, a[N], ans;
int f[60][2][2];
int popcount(int x) {
return __builtin_popcount(x >> 32) + __builtin_popcount(x & ((1ll << 32) - 1));
}
int cal(int l, int r) {
int ans = 0;
for(int i = l; i <= r; i++) {
int f = 1;
for(int j = 0; j < m; j++) {
if(a[j] != get(i + j)) {
f = 0;
break;
}
}
ans += f;
}
return ans;
}
void work() {
int cnt[2][2] = {0};
ans = 0;
scanf("%lld%lld", &m, &l);
for(int i = 0; i < m; i++) scanf("%d", &a[i]);
if(l < 200) {
ans = cal(0, l);
printf("%lld\n", ans);
return ;
}
for(int i = 0; i <= 1; i++) {
for(int k = 0; k <= 1; k++) {
for(int p = 0; p < 128; p++) {
int f = 1;
for(int j = 0; j < m; j++) {
int t = get(((1ll << 7) - 1) & (p + j));
if(p + j >= 128) t ^= k;
else t ^= i;
if(t != a[j]) {
f = 0;
break;
}
}
cnt[i][k] += f;
}
}
}
int t = ((1ll << 7) - 1) & l;
ans = cal(l - t, l);
l >>= 7;
int ac[2][2] = {0}, now = 0;
for(int i = 60; i >= 0; i--) if(l >> i & 1) {
int t = get(now);
for(int qwq = 0; qwq <= 1; qwq++)
for(int qaq = 0; qaq <= 1; qaq++)
ac[qwq ^ t][qaq ^ t] += f[i][qwq][qaq];
now += 1ll << i;
ac[get(now - 1)][get(now)]++;
}
for(int qwq = 0; qwq <= 1; qwq++)
for(int qaq = 0; qaq <= 1; qaq++)
ans += ac[qwq][qaq] * cnt[qwq][qaq];
printf("%lld\n", ans);
return ;
}
signed main()
{
for(int i = 1; i <= 60; i++) {
for(int j = 0; j <= 1; j++) {
for(int k = 0; k <= 1; k++) {
f[i][j][k] = f[i - 1][j][k] + f[i - 1][j ^ 1][k ^ 1];
}
}
f[i][(i - 1) & 1][1] += 1;
}
int Case;
scanf("%lld", &Case);
while(Case--) work();
return 0;
}