思路:先二分出\(k\)大值,在计算比\(k\)大值大的和。
\(part 1:\)二分求\(k\)大值
考虑建一棵\(01trie\),每次二分值\(mid\),枚举每个数,记异或值大于等于\(mid\)的数量。
二分一个\(log\),枚举每个数是\(\Theta(n)\),查询异或值大于等于\(mid\)的数量是一个\(log\),故此部分复杂度\(\Theta(n\log^2n)\)。
inline long long check(int x){
long long tot=0;
for(int i=1;i<=n;i++){
int u=0;
for(int j=30;j>=0;j--){
int t1=((1<<j)&a[i])!=0;
int t2=((1<<j)&x)!=0;
if(!t2)tot+=val[ch[u][t1^1]],u=ch[u][t1];
else u=ch[u][t1^1];
if(!u)break;
}
tot+=val[u];
}
return tot/2;
}
int l=0,r=1<<30,kth=0;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid)>=k)kth=mid,l=mid+1;
else r=mid-1;
}
\(part 2:\)计算异或值中大于等于\(k\)大值的和
预处理\(tr\)数组,\(tr[x][y]\)表示在\(x\)子树中的叶子节点上有多少个数的\(2^y\)上是\(1\)。
inline void pre(int x,int dep,int z){
if(x==0)return;
if(dep==0){
for(int i=0;i<=30;i++)
if((1<<i)&z)tr[x][i]=val[x];
return;
}
pre(ch[x][0],dep-1,z);
pre(ch[x][1],dep-1,z|(1<<dep-1));
for(int i=0;i<=30;i++)
tr[x][i]=tr[ch[x][0]][i]+tr[ch[x][1]][i];
}
\(tr\)数组使我们可以计算整棵子树异或一个数的和。
接着就可以在\(01trie\)上计算了。
枚举每一个数,走一遍\(01trie\),遇见可以直接加的就枚举每一位,计算贡献。
这一部分的复杂度也是\(\Theta(n\log^2n)\)。
inline void solve(int kth){
for(int i=1;i<=n;i++){
int u=0;
for(int j=30;j>=0;j--){
int t1=((1<<j)&a[i])!=0;
int t2=((1<<j)&kth)!=0;
if(!t2){
int t=ch[u][t1^1];
for(int k=0;k<=30;k++){
int t3=((1<<k)&a[i])!=0;
if(t3)ans=(ans+1ll*(val[t]-tr[t][k])*(1ll<<k))%mod;
else ans=(ans+1ll*tr[t][k]*(1ll<<k))%mod;
}u=ch[u][t1];
}else u=ch[u][t1^1];
if(u==0)break;
}
ans=(ans+1ll*val[u]*kth)%mod;
}
}
注意:
1.每一个值被计算了两次,故答案要除以\(2\)。
ans=ans*inv2%mod;
2.可能\(k\)大值与\(k+1,k+2,k+3……\)大值相等,注意最后要减掉这些“凑合的”。
ans=((ans-1ll*(check(kth)-k)*kth%mod)%mod+mod)%mod;
放上完整代码,以供参考:
#include<bits/stdc++.h>
using namespace std;
const int maxn=5e4+10;
const int mod=1e9+7;
const int inv2=5e8+4;
int n,a[maxn];
long long ans,k;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
int ch[maxn*20][2],val[maxn*20],cnt;
inline void insert(int x){
int u=0;
for(int i=30;i>=0;i--){
int t=((1<<i)&x)!=0;
if(!ch[u][t])ch[u][t]=++cnt;
u=ch[u][t];val[u]++;
}
}
inline long long check(int x){
long long tot=0;
for(int i=1;i<=n;i++){
int u=0;
for(int j=30;j>=0;j--){
int t1=((1<<j)&a[i])!=0;
int t2=((1<<j)&x)!=0;
if(!t2)tot+=val[ch[u][t1^1]],u=ch[u][t1];
else u=ch[u][t1^1];
if(!u)break;
}
tot+=val[u];
}
return tot/2;
}
int tr[maxn*20][35];
inline void pre(int x,int dep,int z){
if(x==0)return;
if(dep==0){
for(int i=0;i<=30;i++)
if((1<<i)&z)tr[x][i]=val[x];
return;
}
pre(ch[x][0],dep-1,z);
pre(ch[x][1],dep-1,z|(1<<dep-1));
for(int i=0;i<=30;i++)
tr[x][i]=tr[ch[x][0]][i]+tr[ch[x][1]][i];
}
inline void solve(int kth){
for(int i=1;i<=n;i++){
int u=0;
for(int j=30;j>=0;j--){
int t1=((1<<j)&a[i])!=0;
int t2=((1<<j)&kth)!=0;
if(!t2){
int t=ch[u][t1^1];
for(int k=0;k<=30;k++){
int t3=((1<<k)&a[i])!=0;
if(t3)ans=(ans+1ll*(val[t]-tr[t][k])*(1ll<<k))%mod;
else ans=(ans+1ll*tr[t][k]*(1ll<<k))%mod;
}u=ch[u][t1];
//printf("%lld %lld %lld\n",i,j,ans);
}else u=ch[u][t1^1];
if(u==0)break;
}
ans=(ans+1ll*val[u]*kth)%mod;
//printf("%lld %lld\n",i,ans);
}
}
int main(){
n=read();scanf("%lld",&k);
if(!k)return puts("0"),0;
for(int i=1;i<=n;i++)
a[i]=read();
for(int i=1;i<=n;i++)
insert(a[i]);
int l=0,r=1<<30,kth=0;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid)>=k)kth=mid,l=mid+1;
else r=mid-1;
}
pre(ch[0][0],30,0);
pre(ch[0][1],30,1<<30);
solve(kth);//printf("%lld\n",ans);
ans=ans*inv2%mod;
ans=((ans-1ll*(check(kth)-k)*kth%mod)%mod+mod)%mod;
printf("%lld\n",ans);
return 0;
}
深深地感到自己的弱小。