题意:
传送门
给你\(A,B,C\),要求你给出有多少对\((x, y)\)满足\(x\in [1,A],y\in [1,B]\),且满足以下任意一个条件:\(x \& y > C\)或者\(x \oplus y < C\)。
思路:
数位\(DP\),以前做的数位\(DP\)只是和一个数相关,今天是和两个数相关,有点神奇。这里我开了九维,第\(i\)位\(x\)是\(j\),\(y\)是\(k\),对\(第一种\)情况,对\(第二种\)情况,\(x\)到达上界,\(y\)到达上界,\(x\)前导零,\(y\)前导零。一开始只开了前五维,但是\(T\)了。因为在二进制中,其中一个数达到上界的情况其实非常多,那么如果我每次都要求\(!limita\ \&\&\ !limitb\)时才返回\(dp\)那么势必造成很多情况都要\(dfs\)很多次求解。前导零同理。
代码:
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<ctime>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<cstring>
#include<sstream>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 100000 + 5;
const int INF = 0x3f3f3f3f;
const ull seed = 131;
const ll MOD = 1e9 + 7;
using namespace std;
ll dp[40][3][3][3][3][3][3][3][3]; //第i位x是j,y是k,对第一种情况,对第二种情况,x到达上界,y到达上界,x前导零,y前导零
//0不知 1不满足 2满足
int bit1[40], bit2[40], C;
//x and y > C
//x xor y < C
ll dfs(int pos, int x, int y, int oxor, int oand, int stx, int sta, bool limita, bool limitb, bool leadx, bool leady){
if(pos == -1){
if((stx == 2 || sta == 2) && !leadx && !leady) return 1;
return 0;
}
if(dp[pos][x][y][stx][sta][limita][limitb][leadx][leady] != -1) return dp[pos][x][y][stx][sta][limita][limitb][leadx][leady];
int top1 = limita? bit1[pos] : 1;
int top2 = limitb? bit2[pos] : 1;
ll ret = 0;
for(int i = 0; i <= top1; i++){
for(int j = 0; j <= top2; j++){
int nxor = (oxor << 1) + (i ^ j), nand = (oand << 1) + (i & j);
int nstx = stx, nsta = sta;
if(stx == 0 && nxor > (C >> pos)){
nstx = 1;
}
else if(stx == 0 && nxor < (C >> pos)){
nstx = 2;
}
if(sta == 0 && nand > (C >> pos)){
nsta = 2;
}
else if(sta == 0 && nand < (C >> pos)){
nsta = 1;
};
ret += dfs(pos - 1, i, j, nxor, nand, nstx, nsta, limita && i == top1, limitb && j == top2, leadx && !i, leady && !j);
}
}
dp[pos][x][y][stx][sta][limita][limitb][leadx][leady] = ret;
return ret;
}
ll solve(int A, int B){
int pos = 0;
if(A < B) swap(A, B);
while(A){
bit1[pos] = A & 1;
A >>= 1;
bit2[pos++] = B & 1;
B >>= 1;
}
ll ans = dfs(pos - 1, 0, 0, 0, 0, 0, 0, true, true, true, true);
return ans;
}
int main(){
int T;
scanf("%d", &T);
while(T--){
memset(dp, -1, sizeof(dp));
int a, b;
scanf("%d%d%d", &a, &b, &C);
ll ans = solve(a, b);
printf("%lld\n", ans);
}
return 0;
}