题目链接:https://hihocoder.com/problemset/problem/1305
描述
给定两个区间集合 A 和 B,其中集合 A 包含 N 个区间[ A1, A2 ], [ A3, A4 ], ..., [ A2N-1, A2N ],集合 B 包含 M 个区间[ B1, B2 ], [ B3, B4 ], ..., [ B2M-1, B2M ]。求 A - B 的长度。
例如对于 A = {[2, 5], [4, 10], [14, 18]}, B = {[1, 3], [8, 15]}, A - B = {(3, 8), (15, 18]},长度为8。
输入
第一行:包含两个整数 N 和 M (1 ≤ N, M ≤ 100000)。
第二行:包含 2N 个整数 A1, A2, ..., A2N (1 ≤ Ai ≤ 100000000)。
第三行:包含 2M 个整数 B1, B2, ..., B2M (1 ≤= Bi ≤ 100000000)。
输出
一个整数,代表 A - B 的长度。
- 样例输入
-
3 2
2 5 4 10 14 18
1 3 8 15 - 样例输出
-
8
题解:
一开始我的想法很简单,离散化后用区间修改单点查询的BIT,得到标记了所有还在的点,统计一下之后直接求和即可;
于是乎我我就有了一个和标程对拍了10086组数据,但依然WA的程序:
#include<cstdio>
#include<algorithm>
#define MAXN 400010
using namespace std;
int n,m;
struct Inte{
int l,r;
}inte[MAXN/]; struct Discr{//离散化模板
int _size;
int idx[MAXN];//离散化索引
void init(){_size=;}
void push(int val){idx[_size++]=val;}
void discretize()
{
sort(idx,idx+_size);
_size=unique(idx,idx+_size)-idx;
}
int id(int val){return lower_bound(idx,idx+_size,val)-idx+;}//若有"+1"则离散化后的值从1开始,否则从0开始.
int val(int id){return idx[id-];}//若离散化后的值从1开始,则需要"-1".
}discr; //BIT - 区间增加,单点查询 - st
struct _BIT{
int N;
long long C[MAXN];
int lowbit(int x){return x&(-x);}
void init(int n)//初始化共有n个点
{
N=n;
for(int i=;i<=N;i++) C[i]=;
}
long long query(int pos)//查询点pos的值
{
long long ret=;
while(pos<=N)
{
ret+=C[pos];
pos+=lowbit(pos);
}
return ret;
}
void add(int pos,long long val)//区间1~pos加上val
{
while(pos>)
{
C[pos]+=val;
pos-=lowbit(pos);
}
}
}BIT;
//BIT - 区间增加,单点查询 - ed int endpoint[MAXN];
int main()
{
freopen("input.txt","r",stdin);
freopen("output.txt","w",stdout); scanf("%d%d",&n,&m);
discr.init();
for(int i=,l,r;i<=n;i++)
{
scanf("%d%d",&l,&r);
inte[i].l=l, inte[i].r=r;
discr.push(inte[i].l);
discr.push(inte[i].r);
}
for(int i=n+,l,r;i<=n+m;i++)
{
scanf("%d%d",&l,&r);
inte[i].l=l, inte[i].r=r;
discr.push(inte[i].l);
discr.push(inte[i].r);
}
discr.discretize(); BIT.init(discr._size+);
for(int i=;i<=n;i++)
{
int l=discr.id(inte[i].l);
int r=discr.id(inte[i].r);
BIT.add(r,);
BIT.add(l-,-);
}
for(int i=n+;i<=n+m;i++)
{
int l=discr.id(inte[i].l)+;
int r=discr.id(inte[i].r)-;
if(r<l) continue;
BIT.add(r,-);
BIT.add(l-,);
} int cnt=;
long long now=,nex=BIT.query();
if(now<= && nex>) endpoint[cnt++]=; for(int i=;i<=discr._size;i++)
{
now=nex;
nex=BIT.query(i+);
if(i+>discr._size) nex=;
//printf("%d -> %d = %lld\n",discr.val(i),i,now); if(now<= && nex>) endpoint[cnt++]=i+;
else if(now> && nex<=) endpoint[cnt++]=i;
} int ans=;
for(int i=;i<cnt;i+=)
{
int l=discr.val(endpoint[i]);
int r=discr.val(endpoint[i+]);
//printf("%d %d\n",l,r);
ans+=r-l;
}
printf("%d\n",ans);
}
//你问我服不服,当然是不服的
那就老老实实按题解说的做呗;
官方题解https://hihocoder.com/discuss/question/4554:
这道题是一类区间问题的变体,我们先来看一道最基础的区间问题: 给定N个区间[S1, E1], [S2, E2], ... [SN, EN],求这些区间并集的长度。 这道题通常的解法是,我们把这N个区间的2N个端点从左到右排列在数轴上P1, P2, ... P2N。并且如果一个点Pi是原区间的左端点,我们就把它标记成绿色;如果是右端点,就标记成蓝色。 值得注意的是这2N个点中可能存在重合的点。比如假设有两个区间[1, 3]和[3, 5],那么在3这个位置上就同时存在一个绿点(左端点)和蓝点(右端点)。某些情况下我们在排序时需要特别处理重合的点,例如要保证蓝点都排在绿点之前。不过本题我们无需特殊处理,重合的点无论谁在前谁在后都不影响结果。 这2N个点把数轴分成了2N+1段,(-INF, P1), (P1, P2), (P2, P3) ... (P2N-1, P2N), (P2N, +INF)。每一段内部被原来区间集合覆盖的情况都是相同的。换句话说,不会出现(Pi, Pi+1)的左半部分被第1、3、5号区间覆盖,而右半部分只被第1、3号区间覆盖这种情况。 所以我们可以从左到右扫描每一段,令cnt计数器初始值=0。当扫过一个绿点时,cnt++;扫过一个蓝点时cnt--。我们可以发点对于(Pi, Pi+1)这一段,处理完Pi时的cnt值恰好代表了这一段被几个原来的区间同时覆盖。 有了每一段的cnt值,我们可以做很多事情。例如要求区间并集的长度,我们可以找出所有cnt值大于0的段(Pi, Pi+1),并把这些段的长度(Pi+1 - Pi)求和。 我们还可以知道哪段被覆盖了最多次:自然是cnt值最大的段。 对于给定的坐标X,我们可以在O(logN)的时间内求出X这个点被覆盖多少次:我们只需要在P1, P2, ... P2N中二分查找出X的位置,即Pi < X < Pi+1,那么(Pi, Pi+1)这一段的cnt值就是答案。(当X恰好是端点时需要特判,取决于给出的区间是开区间还是闭区间) 好了,我们回到《区间求差》这道题目。我们可以把A和B集合中2N+2M个端点都从左到右排列在数轴上。并且用4种颜色标记出每个点是A的左端点、A的右端点、B的左端点、B的右端点。 然后我们用两个计数器cntA和cntB来分别维护每一段被A集合中的区间覆盖多少次、以及被B集合的区间覆盖多少次。那么如果某一段(Pi, Pi+1)满足cntA>0且cntB=0,那么它一定是A-B的一部分。我们对于这些段的长度求和即可。 整个算法对于端点排序的部分复杂度是O(NlogN)的,对于从左到右扫描复杂度是O(N)的。总体复杂度是O(NlogN)。 |
有了这么详细的题解后,敲成代码就不难了;
AC代码:
#include<bits/stdc++.h>
#define MAX 100010
#define INF 0x3f3f3f3f
using namespace std; int n,m; struct Point{
int pos;
int type;//1-Al,2-Ar;3-Bl,4-Br.
}point[*MAX];
bool cmp(Point a,Point b){return a.pos<b.pos;} struct Interval{
int cntA,cntB;
}interval[*MAX];
int main()
{
scanf("%d%d",&n,&m);
int _size=;
for(int i=,l,r;i<=n;i++)
{
scanf("%d%d",&l,&r);
point[_size++]=(Point){l,};
point[_size++]=(Point){r,};
}
for(int i=,l,r;i<=m;i++)
{
scanf("%d%d",&l,&r);
point[_size++]=(Point){l,};
point[_size++]=(Point){r,};
}
sort(point,point+_size,cmp);
//for(int i=0;i<_size;i++) printf("pos=%d type=%d\n",point[i].pos,point[i].type); int ans=;
interval[]=(Interval){,};
for(int i=;i<_size;i++)
{
Point& lp=point[i-];
Point& rp=point[i]; if(lp.type==)//遇到一个A集合中的左端点
{
interval[i].cntA=interval[i-].cntA+;
interval[i].cntB=interval[i-].cntB;
}
if(lp.type==)//遇到一个A集合中的右端点
{
interval[i].cntA=interval[i-].cntA-;
interval[i].cntB=interval[i-].cntB;
}
if(lp.type==)//遇到一个B集合中的左端点
{
interval[i].cntA=interval[i-].cntA;
interval[i].cntB=interval[i-].cntB+;
}
if(lp.type==)//遇到一个B集合中的右端点
{
interval[i].cntA=interval[i-].cntA;
interval[i].cntB=interval[i-].cntB-;
} //printf("%d - %d : cntA=%d cntB=%d\n",lp.pos,rp.pos,interval[i].cntA,interval[i].cntB);
if(interval[i].cntA> && interval[i].cntB==) ans+=rp.pos-lp.pos;
} printf("%d\n",ans);
}