R - Weak Pair
这个题目的初步想法,首先用dfs序建一颗树,然后判断对于每一个节点进行遍历,判断他的子节点和他相乘是不是小于等于k,
这么暴力的算法很自然的超时了。
然后上网搜了一下题解,感觉想的很巧妙。
就是我们要搜 子节点和父节点的乘积小于一个定值的对数。
一般求对数,有逆序对,都是把满足的放进去,到时候直接求答案就可以了。这个题目也很类似,但是还是有很大的区别的。
这个题目就是先把所有a[i] 和 k/a[i] 都放进一个数组,离散化,这一步是因为要直接求值,就是要把这个值放进线段树的这个离散化后的位置,权值为1 .
这个满足了a[i]*a[j]<=k 的要求,然后就是他们的关系必须是子节点和父节点。
这一点可以用dfs序来实现,先把父节点放进去,然后之后的子节点都可以查找这个节点,最后这个父节点的所有子节点都查找完之后就是把这个父节点弹出。
以上做法都是上网看题解的,我觉得还是没有那么难想了,这种差不多就是树上要满足是父节点子节点的关系都是可以用dfs来满足的。
其次就是弹出操作没有那么好想,最后就是放入线段树直接查找的这种逆序对思想。
这个知道怎么写之后就很好写了,注意细节
#include <cstdio> #include <cstring> #include <cstdlib> #include <algorithm> #include <iostream> #include <queue> #include <string> #include <cmath> #include <vector> #include <map> #define inf 0x3f3f3f3f #define inf64 0x3f3f3f3f3f3f3f3f using namespace std; const int maxn = 2e5 + 10; typedef long long ll; ll sum[maxn * 8], a[maxn], b[maxn], ans; vector<int>G[maxn]; int len, f[maxn]; bool vis[maxn]; ll n, k; void update(int id, int l, int r, int pos, int val) { if (l == r) { sum[id] += val; return; } int mid = (l + r) >> 1; if (pos <= mid) update(id << 1, l, mid, pos, val); else update(id << 1 | 1, mid + 1, r, pos, val); sum[id] = sum[id << 1] + sum[id << 1 | 1]; } ll query(int id, int l, int r, int x, int y) { if (x <= l && y >= r) return sum[id]; int mid = (l + r) >> 1; ll ans = 0; if (x <= mid) ans += query(id << 1, l, mid, x, y); if (y > mid) ans += query(id << 1 | 1, mid + 1, r, x, y); return ans; } void dfs(int u) { vis[u] = 1; int t2 = lower_bound(b + 1, b + 1 + len, a[u]) - b; update(1, 1, len, t2, 1); for (int i = 0; i < G[u].size(); i++) { int v = G[u][i]; if (vis[v]) continue; int t1 = lower_bound(b + 1, b + 1 + len, k / a[v]) - b; ans += query(1, 1, len, 1, t1); dfs(v); } update(1, 1, len, t2, -1); } int main() { int t; scanf("%d", &t); while (t--) { ans = 0; scanf("%lld%lld", &n, &k); memset(vis, 0, sizeof(vis)); for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i], G[i].clear(); for (int i = 1; i <= n; i++) b[i + n] = k / a[i]; sort(b + 1, b + 1 + 2 * n); len = unique(b + 1, b + 1 + 2 * n) - b - 1; memset(sum, 0, sizeof(sum)); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); vis[v] = 1; } int root = 1; for (int i = 1; i <= n; i++) { if (vis[i] == 0) { root = i; break; } } memset(vis, 0, sizeof(vis)); dfs(root); printf("%lld\n", ans); } return 0; }View Code
#include <cstdio> #include <cstring> #include <cstdlib> #include <algorithm> #include <iostream> #include <queue> #include <string> #include <cmath> #include <vector> #include <map> #define inf 0x3f3f3f3f #define inf64 0x3f3f3f3f3f3f3f3f using namespace std; const int maxn = 2e5 + 10; typedef long long ll; ll sum[maxn*8], a[maxn], b[maxn], ans; vector<int>G[maxn]; int len, f[maxn]; bool vis[maxn]; ll n, k; void build(int id,int l,int r) { sum[id] = 0; if (l == r) return; int mid = (l + r) >> 1; build(id << 1, l, mid); build(id << 1 | 1, mid + 1, r); } void update(int id,int l,int r,int pos,int val) { if(l==r) { sum[id] += val; return; } int mid = (l + r) >> 1; if(pos<=mid) update(id << 1, l, mid, pos, val); else update(id << 1 | 1, mid + 1, r, pos, val); sum[id] = sum[id << 1] + sum[id << 1 | 1]; } ll query(int id,int l,int r,int x,int y) { if (x <= l && y >= r) return sum[id]; int mid = (l + r) >> 1; ll ans = 0; if (x <= mid) ans += query(id << 1, l, mid, x, y); if (y > mid) ans += query(id << 1 | 1, mid + 1, r, x, y); return ans; } void dfs(int u) { vis[u] = 1; int t1 = lower_bound(b + 1, b + 1 + len, k / a[u]) - b; int t2 = lower_bound(b + 1, b + 1 + len, a[u]) - b; ans += query(1, 1, len, 1, t1); update(1, 1, len, t2, 1); for(int i=0;i<G[u].size();i++) { int v = G[u][i]; if (vis[v]) continue; dfs(v); } update(1, 1, len, t2, -1); } int main() { int t; scanf("%d", &t); while(t--) { ans = 0; scanf("%lld%lld", &n, &k); memset(vis, 0, sizeof(vis)); for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i], G[i].clear(); for (int i = 1; i <= n; i++) b[i + n] = k / a[i]; sort(b + 1, b + 1 + 2 * n); len = unique(b + 1, b + 1 + 2 * n) - b - 1; build(1, 1, len); for(int i=1;i<n;i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); vis[v] = 1; } int root = 1; for (int i = 1; i <= n; i++) { if (vis[i] == 0) { root = i; break; } } memset(vis, 0, sizeof(vis)); dfs(root); printf("%lld\n", ans); } return 0; }View Code