JZOJ 5608 Subset(cdq 分治)

刚看到这道题的时候觉得特别难,里面的限制条件太恶心了,不好做

题面

给出三个 1 1 1 到 n n n 的排列 a , b , c a,b,c a,b,c 。
称三元组 ( x , y , z ) (x,y,z) (x,y,z) 是合法的当且仅当存在一个下标集合 S ⊆ [ n ] S⊆[n] S⊆[n] 满足
( x , y , z ) = ( max ⁡ i ⊆ S a i , max ⁡ i ⊆ S b i , max ⁡ i ⊆ S c i ) (x,y,z)=(\max_{i⊆S}a_i,\max_{i⊆S}b_i,\max_{i⊆S}c_i) (x,y,z)=(i⊆Smax​ai​,i⊆Smax​bi​,i⊆Smax​ci​)求合法三元组的个数。

题解

最关键能打消你安心骗分念头的是这部推导:若 ( x , y , z ) (x,y,z) (x,y,z) 合法,则一定存在 ∣ S ∣ ≤ 3 |S|\leq 3 ∣S∣≤3 的集合 S S S 可以满足上述条件。因为三个下标就足以提供三个最大值了,多者无益。

接下来,相当于要找三个数 i , j , k i,j,k i,j,k 满足 a [ i ] ≥ max ⁡ ( a [ j ] , a [ k ] ) , b [ j ] ≥ max ⁡ ( b [ i ] , b [ k ] ) , c [ k ] ≥ max ⁡ ( c [ i ] , c [ j ] ) a[i]≥\max(a[j],a[k]),b[j]≥\max(b[i],b[k]),c[k]≥\max(c[i],c[j]) a[i]≥max(a[j],a[k]),b[j]≥max(b[i],b[k]),c[k]≥max(c[i],c[j]),这个限制条件很多,直接做颇费脑筋,做不出来,不妨用用容斥

我们根据 S 的大小和三个最大值点的位置考虑,

对于 ∣ S ∣ = 1 |S|=1 ∣S∣=1 的情况,有 n 个三元组。

对于 ∣ S ∣ = 2 |S|=2 ∣S∣=2 的情况,可以总数减去一组完全大于另一组的方案数。

对于 ∣ S ∣ = 3 |S|=3 ∣S∣=3 的情况,分类讨论容斥,答案等于总数减去一组完全大于另两组、一组 2/3 大于另两组的方案:

  • 一组完全大于另两组:跑个简单的 cdq 分治。
  • 一组 2/3 大于另两组:再做个小容斥,算出只考虑某两个参量的方案数,再分别减去一组完全大于另两组的方案数,具体地说,是(一组 a,b 大于另两组的方案 - 一组完全大于另两组方案)+(一组 b,c 大于另两组的方案 - 一组完全大于另两组方案)+(一组 a,c 大于另两组的方案 - 一组完全大于另两组方案),跑三个简单的 cdq 分治。

注意对于计算比某列更小的二元组时,表达式是 c n t ⋅ ( c n t − 1 ) / 2 cnt\cdot(cnt-1)/2 cnt⋅(cnt−1)/2 ,这里的 cnt 是最终求出的比它小的列数,因此不能在 cdq 分治中直接加贡献。

把这三种情况加上就完了,感觉也不难推,但是考场上却想不到,无乃余是笑欤!

CODE

#include<set>
#include<map>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 100005
#define DB double
#define LL long long
#define ENDL putchar('\n')
#define SI set<cp>::iterator
#define lowbit(x) (-(x) & (x))
LL read() {
	LL f = 1,x = 0;char s = getchar();
	while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
	while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
	return f * x;
}
int n,m,i,j,s,o,k;
struct it{
	int a,b,c,cnt;it(){a=b=c=cnt=0;}
	it(int A,int B,int C){a=A;b=B;c=C;cnt = 0;}
}ss[MAXN];
bool cmpa(it x,it y) {return x.a < y.a;}
bool cmpb(it x,it y) {return x.b < y.b;}
bool cmpc(it x,it y) {return x.c < y.c;}
int c[MAXN];
void addt(int x,int y) {while(x<=n)c[x]+=y,x+=lowbit(x);}
int sum(int x) {int as=0;while(x>0)as+=c[x],x-=lowbit(x);return as;}
LL ans;
void cdq(int l,int r,int op) {
	if(l >= r) return ;
	if(l == 1 && r == n) sort(ss + 1,ss + 1 + n,cmpa);
	int mid = (l + r) >> 1;
	cdq(l,mid,op);cdq(mid+1,r,op);
	sort(ss + l,ss + r + 1,cmpb);
	int sm2 = 0;
	for(int i = l;i <= r;i ++) {
		if(ss[i].a <= mid) addt(ss[i].c,1),sm2 ++;
		else {
			int sm = sum(ss[i].c);
			if(op == 1) ans += sm;
			else if(op == 2) ss[i].cnt += sm;
			else if(op == 3) ss[i].cnt += sm2;
		}
	}
	for(int i = l;i <= r;i ++) {
		if(ss[i].a <= mid) addt(ss[i].c,-1);
	}
	return ;
}
int main() {
	freopen("subset.in","r",stdin);
	freopen("subset.out","w",stdout);
	n = read();
	for(int i = 1;i <= n;i ++) {
		ss[i].a = read();
	}
	for(int i = 1;i <= n;i ++) {
		ss[i].b = read();
	}
	for(int i = 1;i <= n;i ++) {
		ss[i].c = read();
	}
	LL asA = n,asB = 0,asC = n*1ll*(n-1)/2ll*(n-2)/3ll;
	LL ans1 = 0,ans2 = 0;
	ans = 0; cdq(1,n,1); asB = n*1ll*(n-1)/2ll - ans;
	ans = 0; cdq(1,n,2);
	for(int i = 1;i <= n;i ++) {
		ans1 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
		ss[i].cnt = 0;
	}
	ans = 0; cdq(1,n,3);
	for(int i = 1;i <= n;i ++) {
		ans2 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
		ss[i].cnt = 0;
	} ans2 -= ans1;
	for(int i = 1;i <= n;i ++) swap(ss[i].b,ss[i].c);
	ans = 0; cdq(1,n,3);
	for(int i = 1;i <= n;i ++) {
		ans2 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
		ss[i].cnt = 0;
	} ans2 -= ans1;
	for(int i = 1;i <= n;i ++) {swap(ss[i].b,ss[i].c);swap(ss[i].a,ss[i].c);}
	ans = 0; cdq(1,n,3); 
	for(int i = 1;i <= n;i ++) {
		ans2 += ss[i].cnt *1ll* (ss[i].cnt-1) / 2ll;
		ss[i].cnt = 0;
	} ans2 -= ans1;
	for(int i = 1;i <= n;i ++) swap(ss[i].a,ss[i].c);
	asC -= ans1; asC -= ans2;
	printf("%lld\n",asA+asB+asC);
	return 0;
}
上一篇:[R]如何篩選出特定子集數據? subset()


下一篇:Leetcode算法刷题笔记-递归与回溯