[ICPC]2020沈阳L.Bit Sequence

题意:

给定一个长度为\(m\)的序列\(a\),问\([0,L]\)中有多少个数\(x\)满足\(popcount(x+i)\mod 2 = a_i\)。
\(L\le 10^{18}\)
\(m\le 100\)

题解:

想一会儿之后会注意到\(m\le 100\)。
显然如果涉及到加法的popcount很难计算,可以想办法转化为二进制下的构造。
我们把最后7位单独拎出来,这样前面的位就跟最后一坨东西关系变小了不少。

我们把最后尾巴上的一点特例给暴力算了之后,\(l\)的最后7位就是0了,然后我们只需要考虑每个0到128的变化过程。

这时候,由于\(128>100\),所以在一段中,每个数往后\(m\)个数,最多只会跨过一段,这一段跨过造成的变化只有:对前面进位,造成1的个数的改变。
这个改变只有4种:0变成0,0变成1,1变成0,1变成1.
可以算出在每种改变下,这一段中的满足的数的个数。

于是我们只需要计算每种改变有多少个即可。
这玩意就可以通过一些数位dp的技巧求了。

首先我们预处理出\([0,2^k)\)(这是去掉后7位以后)中每种变化有多少个。
这个很好dp:
\(f[i][0/1][0/1]\)表示\([0, 2^i)\)中每种变化有多少个。

\[f[i][j][k] = f[i-1][j][k] + f[i-1][j\oplus 1][k\oplus 1] \]

还有一个额外的转移:

\[f[i][(i - 1) \& 1][1] += 1 \]

然后计算\([0,L)\)的变化个数,就从高位开始,对于每个\(1\),取0的时候,就加上对应位的变化种数,此时还要注意还要异或上前面1的数量的奇偶性。

由于上面dp是不包含\(2^i\)次的,还需要再额外转移一下区间末尾的情况。

代码:

#include <bits/stdc++.h>
#define get(x) (popcount(x) & 1)
#define int long long
#define pt(x) cout << x << endl;
#define Mid ((l + r) / 2)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
const int N = 109;
int m, l, a[N], ans;
int f[60][2][2];
int popcount(int x) {
	return __builtin_popcount(x >> 32) + __builtin_popcount(x & ((1ll << 32) - 1));
}
int cal(int l, int r) {
	int ans = 0;
	for(int i = l; i <= r; i++) {
		int f = 1;
		for(int j = 0; j < m; j++) {
			if(a[j] != get(i + j)) {
				f = 0;
				break;
			}
		}
		ans += f;
	}
	return ans;
}
void work() {
	int cnt[2][2] = {0};
	ans = 0;
	scanf("%lld%lld", &m, &l);
	for(int i = 0; i < m; i++) scanf("%d", &a[i]);
	if(l < 200) {
		ans = cal(0, l);
		printf("%lld\n", ans);
		return ;
	}
	for(int i = 0; i <= 1; i++) {
		for(int k = 0; k <= 1; k++) {
			for(int p = 0; p < 128; p++) {
				int f = 1;
				for(int j = 0; j < m; j++) {
					int t = get(((1ll << 7) - 1) & (p + j));
					if(p + j >= 128) t ^= k;
					else t ^= i;
					if(t != a[j]) {
						f = 0;
						break;
					}
				}
				cnt[i][k] += f;
			}
		}
	}
	int t = ((1ll << 7) - 1) & l;
	ans = cal(l - t, l);
	l >>= 7;
	int ac[2][2] = {0}, now = 0;
	for(int i = 60; i >= 0; i--) if(l >> i & 1) {
		int t = get(now);
		for(int qwq = 0; qwq <= 1; qwq++)
			for(int qaq = 0; qaq <= 1; qaq++)
				ac[qwq ^ t][qaq ^ t] += f[i][qwq][qaq];
		now += 1ll << i;
		ac[get(now - 1)][get(now)]++;
	}
	for(int qwq = 0; qwq <= 1; qwq++)
		for(int qaq = 0; qaq <= 1; qaq++)
			ans += ac[qwq][qaq] * cnt[qwq][qaq];
	printf("%lld\n", ans);
	return ;
}
signed main()
{
	for(int i = 1; i <= 60; i++) {
		for(int j = 0; j <= 1; j++) {
			for(int k = 0; k <= 1; k++) {
				f[i][j][k] = f[i - 1][j][k] + f[i - 1][j ^ 1][k ^ 1];
			}
		}
		f[i][(i - 1) & 1][1] += 1;
	}
	int Case;
	scanf("%lld", &Case);
	while(Case--) work();
	return 0;
}
上一篇:SQL Server 2008安装过程中的问题


下一篇:面试必问:JVM类加载机制详细解析