刚看到这道题的时候觉得特别难,里面的限制条件太恶心了,不好做
题面
给出三个
1
1
1 到
n
n
n 的排列
a
,
b
,
c
a,b,c
a,b,c 。
称三元组
(
x
,
y
,
z
)
(x,y,z)
(x,y,z) 是合法的当且仅当存在一个下标集合
S
⊆
[
n
]
S⊆[n]
S⊆[n] 满足
(
x
,
y
,
z
)
=
(
max
i
⊆
S
a
i
,
max
i
⊆
S
b
i
,
max
i
⊆
S
c
i
)
(x,y,z)=(\max_{i⊆S}a_i,\max_{i⊆S}b_i,\max_{i⊆S}c_i)
(x,y,z)=(i⊆Smaxai,i⊆Smaxbi,i⊆Smaxci)求合法三元组的个数。
题解
最关键能打消你安心骗分念头的是这部推导:若 ( x , y , z ) (x,y,z) (x,y,z) 合法,则一定存在 ∣ S ∣ ≤ 3 |S|\leq 3 ∣S∣≤3 的集合 S S S 可以满足上述条件。因为三个下标就足以提供三个最大值了,多者无益。
接下来,相当于要找三个数 i , j , k i,j,k i,j,k 满足 a [ i ] ≥ max ( a [ j ] , a [ k ] ) , b [ j ] ≥ max ( b [ i ] , b [ k ] ) , c [ k ] ≥ max ( c [ i ] , c [ j ] ) a[i]≥\max(a[j],a[k]),b[j]≥\max(b[i],b[k]),c[k]≥\max(c[i],c[j]) a[i]≥max(a[j],a[k]),b[j]≥max(b[i],b[k]),c[k]≥max(c[i],c[j]),这个限制条件很多,直接做颇费脑筋,做不出来,不妨用用容斥。
我们根据 S 的大小和三个最大值点的位置考虑,
对于 ∣ S ∣ = 1 |S|=1 ∣S∣=1 的情况,有 n 个三元组。
对于 ∣ S ∣ = 2 |S|=2 ∣S∣=2 的情况,可以总数减去一组完全大于另一组的方案数。
对于 ∣ S ∣ = 3 |S|=3 ∣S∣=3 的情况,分类讨论容斥,答案等于总数减去一组完全大于另两组、一组 2/3 大于另两组的方案:
- 一组完全大于另两组:跑个简单的 cdq 分治。
- 一组 2/3 大于另两组:再做个小容斥,算出只考虑某两个参量的方案数,再分别减去一组完全大于另两组的方案数,具体地说,是(一组 a,b 大于另两组的方案 - 一组完全大于另两组方案)+(一组 b,c 大于另两组的方案 - 一组完全大于另两组方案)+(一组 a,c 大于另两组的方案 - 一组完全大于另两组方案),跑三个简单的 cdq 分治。
注意对于计算比某列更小的二元组时,表达式是 c n t ⋅ ( c n t − 1 ) / 2 cnt\cdot(cnt-1)/2 cnt⋅(cnt−1)/2 ,这里的 cnt 是最终求出的比它小的列数,因此不能在 cdq 分治中直接加贡献。
把这三种情况加上就完了,感觉也不难推,但是考场上却想不到,无乃余是笑欤!
CODE
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 100005
#define DB double
#define LL long long
#define ENDL putchar('\n')
#define SI set<cp>::iterator
#define lowbit(x) (-(x) & (x))
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
int n,m,i,j,s,o,k;
struct it{
int a,b,c,cnt;it(){a=b=c=cnt=0;}
it(int A,int B,int C){a=A;b=B;c=C;cnt = 0;}
}ss[MAXN];
bool cmpa(it x,it y) {return x.a < y.a;}
bool cmpb(it x,it y) {return x.b < y.b;}
bool cmpc(it x,it y) {return x.c < y.c;}
int c[MAXN];
void addt(int x,int y) {while(x<=n)c[x]+=y,x+=lowbit(x);}
int sum(int x) {int as=0;while(x>0)as+=c[x],x-=lowbit(x);return as;}
LL ans;
void cdq(int l,int r,int op) {
if(l >= r) return ;
if(l == 1 && r == n) sort(ss + 1,ss + 1 + n,cmpa);
int mid = (l + r) >> 1;
cdq(l,mid,op);cdq(mid+1,r,op);
sort(ss + l,ss + r + 1,cmpb);
int sm2 = 0;
for(int i = l;i <= r;i ++) {
if(ss[i].a <= mid) addt(ss[i].c,1),sm2 ++;
else {
int sm = sum(ss[i].c);
if(op == 1) ans += sm;
else if(op == 2) ss[i].cnt += sm;
else if(op == 3) ss[i].cnt += sm2;
}
}
for(int i = l;i <= r;i ++) {
if(ss[i].a <= mid) addt(ss[i].c,-1);
}
return ;
}
int main() {
freopen("subset.in","r",stdin);
freopen("subset.out","w",stdout);
n = read();
for(int i = 1;i <= n;i ++) {
ss[i].a = read();
}
for(int i = 1;i <= n;i ++) {
ss[i].b = read();
}
for(int i = 1;i <= n;i ++) {
ss[i].c = read();
}
LL asA = n,asB = 0,asC = n*1ll*(n-1)/2ll*(n-2)/3ll;
LL ans1 = 0,ans2 = 0;
ans = 0; cdq(1,n,1); asB = n*1ll*(n-1)/2ll - ans;
ans = 0; cdq(1,n,2);
for(int i = 1;i <= n;i ++) {
ans1 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
ss[i].cnt = 0;
}
ans = 0; cdq(1,n,3);
for(int i = 1;i <= n;i ++) {
ans2 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
ss[i].cnt = 0;
} ans2 -= ans1;
for(int i = 1;i <= n;i ++) swap(ss[i].b,ss[i].c);
ans = 0; cdq(1,n,3);
for(int i = 1;i <= n;i ++) {
ans2 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
ss[i].cnt = 0;
} ans2 -= ans1;
for(int i = 1;i <= n;i ++) {swap(ss[i].b,ss[i].c);swap(ss[i].a,ss[i].c);}
ans = 0; cdq(1,n,3);
for(int i = 1;i <= n;i ++) {
ans2 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
ss[i].cnt = 0;
} ans2 -= ans1;
for(int i = 1;i <= n;i ++) swap(ss[i].a,ss[i].c);
asC -= ans1; asC -= ans2;
printf("%lld\n",asA+asB+asC);
return 0;
}