AcWing 164 可达性统计
题目链接:
https://www.acwing.com/problem/content/166/
题意:
给定一张(N, M)的有向无环图,求出从每个点出发能够到达的点的数量,其中N、M <= 30000。
输入样例
10 10
3 8
2 3
2 5
5 9
5 9
2 3
3 9
4 8
2 10
4 9
输出样例
1
6
3
3
2
1
1
1
1
1
题目分析:
设从点x出发能到达的点的集合为f(x),显然:f(x) = {x} ∪ f(y)
y表示所有有向边(x, y)的弧头。
也就是说,所有从x出发能够到达的点,其实就是从"x的各个后继结点y"出发能够到达的点的并集,再交上{x},所以,我们需要在计算f(x)前,先计算出所有以x为首的后继结点的f(y)。
如何确定这个顺序,注意,对每个边来说,都是先计算后继结点,再计算前驱,很容易想到是拓扑序的逆序。
但是,我们在计算x时,不能简单粗暴地把f(y)相加,因为f(y1)∪f(y2) != |f(y1)| + |f(y2)|,比如2可达3和5,3和5分别可达9,那么计算出来2应该是4,而不是5。
我们可以使用状态压缩的方法,用二进制来表示可达状态,对于第i个位置,如果为1,则表示x可达i,否则不可达,显然,并集也很容易表示成二进制 | 运算。
bitset
c++STL标准容器bitset,可以帮助我们进行状态的压缩。
std :: bitset 是标准库中的一个固定大小序列,其储存的数据只包含 0/1。众所周知,由于内存地址是按字节即 byte 寻址,而非比特 bit , 我们一个 bool 类型的变量,虽然只能表示 0/1 , 但是也占了 1byte 的内存。bitset 就是通过固定的优化,使得一个字节的八个比特能分别储存 8 位的 0/1
对于一个 4 字节的 int 变量,在只存 0/1 的意义下, bitset 占用空间只是其1/32。在某些情况下通过 bitset 可以使你的复杂度除以 32。
下面是bitset的几个常用操作:
operator []
: 下标运算,快速访问容器中某一位的值
operator == / !=
: 比较两个bitset容器内值是否完全相同
operator &= |= ^= ~ << >>
:位运算,和普通变量运算相同
count()
:返回容器内true的个数
set() / reset()
:将容器内的值全部设置为1 / 0
to_string() to_ulong() to_ullong()
:字面意思,转换成相应的字符串/unsigned long /unsigned long long
flip()
:将某一位的值反转
代码如下:
#include <iostream>
#include <bitset>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
const int N = 30010;
vector<int> e[N]; // 邻接表,用数组模拟会爆内存 e[x]存储以x为弧尾的所有有向边
vector<int> topo; // 存储拓扑序列
bitset<N> bit[N]; // bit[i][k]存储第i个结点对于k结点的可达性,1表示可达,0表示不可达。
int n, m; // n个点,m条边
int in[N]; // 每个点的入度
void topo_sort () {
queue<int> q;
for (int i = 1; i <= n; i++) if (!in[i]) q.push(i); // 把所有入度为0的点入队
while (q.size()) {
int x = q.front(); q.pop();
topo.push_back(x);
for (int k = 0; k < e[x].size(); k++) {
int ver = e[x][k]; // ver为x的弧头
if (-- in[ver] == 0) q.push(ver);
}
}
}
int main () {
cin >> n >> m;
for (int i = 1; i <= m; i++) {
int x, y; cin >> x >> y;
in[y] ++;
e[x].push_back(y);
}
topo_sort();
// 从拓扑序最后开始,因为计算前面需要用到后面
for (int i = topo.size() - 1; i >= 0; i--) {
int x = topo[i];
bit[x].reset(); bit[x][x] = 1; // 对自己可达
for (int k = 0; k < e[x].size(); k++) {
int ver = e[x][k];
bit[x] |= bit[ver];
}
}
for (int i = 1; i <= n; i++) {
cout << bit[i].count() << endl;
}
return 0;
}