Description
程序员 ZS 有一棵树,它可以表示为 \(n\) 个顶点的无向连通图,顶点编号从 \(0\) 到 \(n-1\),它们之间有 \(n-1\) 条边。每条边上都有一个非零的数字。
一天,程序员 ZS 无聊,他决定研究一下这棵树的一些特性。他选择了一个十进制正整数 \(M\),\(\gcd(M,10)=1\)。
对于一对有序的不同的顶点 \((u, v)\),他沿着从顶点 \(u\) 到顶点 \(v\)的最短路径,按经过顺序写下他在路径上遇到的所有数字(从左往右写),如果得到一个可以被 \(M\) 整除的十进制整数,那么就认为 \((u,v)\) 是有趣的点对。
帮助程序员 ZS 得到有趣的对的数量。
Hint
- \(1\le n\le 10^5\)
- \(1\le m\le 10^9,\gcd(m, 10) = 1\)
- \(1\le \text{边权} < 10\)
Solution
这种树上路径的统计问题基本都是 点分治,而点分治的重点和难点就是如何 统计经过分治中心的满足条件的路径的个数。
这里采用 容斥法:即现分治中心为 \(s\),当前答案等于整个子树 \(s\) 的答案减去以 \(s\) 各个子结点为根的子树的答案。
考虑如何统计。
我们设有一条路径是 \(x\rightarrow y\),分治中心为 \(s\),路径 \(x\rightarrow s\) 对应的数字为 \(pd\),\(s\rightarrow y\) 对应 \(nd\),\(s\) 到 \(y\) 的距离为 \(l\)。
那么只有 \(pd \times 10^l + nd \equiv 0 \pmod m\) 成立时满足要求。
变形一下:\(pd \equiv -nd \times 10^{-l}\pmod m\)。
于是我们可以这样搞:把所有的 \(pd\) 用 map
存起来,记录一下个数,用 pair
数组把 \((nd, l)\) 记录下来。
导入所有了路径信息后,枚举 pair
数组,查找 map
中的元素配对即可。
预处理一下 \(10\) 的幂及其逆元的话,时间复杂度 \(O(n\log^2 n)\)。如果用 Hash Table 可以优化到理论 \(O(n\log n)\),但没什么必要。
Code
#include <cstdio>
#include <map>
#include <utility>
#include <vector>
using namespace std;
const int N = 1e5 + 5;
namespace Inv {
void extgcd(long long a, long long b, long long& x, long long& y) {
if (!b) x = 1, y = 0;
else extgcd(b, a % b, y, x), y -= a / b * x;
}
inline long long get(long long b, long long p) {
long long x, y;
extgcd(b, p, x, y);
x = (x % p + p) % p;
return x;
}
}
int n, m;
long long p10[N], invp[N];
long long ans;
struct edge { int to, len; };
vector<edge> G[N];
int root;
int maxp[N], size[N];
bool centr[N];
int getSize(int x, int f) {
size[x] = 1;
for (auto y : G[x])
if (!centr[y.to] && y.to != f)
size[x] += getSize(y.to, x);
return size[x];
}
void getCentr(int x, int f, int t) {
maxp[x] = 0;
for (auto y : G[x])
if (!centr[y.to] && y.to != f) {
getCentr(y.to, x, t);
maxp[x] = max(maxp[x], size[y.to]);
}
maxp[x] = max(maxp[x], t - size[x]);
if (maxp[x] < maxp[root]) root = x;
}
vector<pair<long long, int> > dat;
map<long long, int> cnt;
void getData(int x, int f, long long pd, long long nd, int dep) {
if (dep >= 0) cnt[pd]++, dat.push_back(make_pair(nd, dep));
for (auto y : G[x]) {
if(centr[y.to] || y.to == f) continue;
long long tpd = (pd + y.len * p10[dep + 1] % m) % m;
long long tnd = (nd * 10 % m + y.len) % m;
getData(y.to, x, tpd, tnd, dep + 1);
}
}
inline long long count(int x, int d) {
long long ret = 0;
cnt.clear(), dat.clear();
if (d == 0) getData(x, 0, 0, 0, -1);
else getData(x, 0, d % m, d % m, 0);
for (auto p : dat) {
long long t = ((-p.first * invp[p.second + 1] % m) + m) % m;
if (cnt.count(t)) ret += cnt[t];
if (d == 0 && p.first == 0) ++ret;
}
return ret + (d == 0 ? cnt[0] : 0);
}
void solve(int x) {
maxp[root = 0] = N;
getCentr(x, 0, getSize(x, 0));
int s = root; centr[s] = true;
for (auto y : G[s])
if (!centr[y.to])
solve(y.to);
ans += count(s, 0);
for (auto y : G[s])
if (!centr[y.to])
ans -= count(y.to, y.len);
centr[s] = false;
}
signed main() {
scanf("%d%d", &n, &m);
for (register int i = 1; i < n; i++) {
int u, v, l;
scanf("%d%d%d", &u, &v, &l);
++u, ++v;
G[u].push_back(edge{v, l});
G[v].push_back(edge{u, l});
}
p10[0] = 1 % m;
for (register int i = 1; i <= n; i++)
p10[i] = p10[i - 1] * 10 % m;
invp[n] = Inv::get(p10[n], m);
for (register int i = n - 1; i; i--)
invp[i] = invp[i + 1] * 10 % m;
ans = 0, solve(1);
printf("%lld\n", ans);
return 0;
}