Codechef Prime Distance On Tree

【传送门】

FFT第四题!

暑假的时候只会点分,然后合并是暴力合并的...水过去了...

其实两条路径长度的合并就是卷积的过程:以当前重心为根统计每个深度上的结点个数,再对这个计数数组做一次自卷积,就得到了所有经过重心的路径长度分布。

刚开始卷积时固定了值域(FFT 长度),结果超时(TLE)了。后来就不偷懒了,每次取当前最大深度的两倍来确定值域。

#include <bits/stdc++.h>

// pi, used to build the FFT twiddle factors.
const double pi = std::acos(-1.0);

// Minimal complex number type used by FFT below.
struct Complex {
    double r, i;  // real and imaginary parts
    void clear() { r = i = 0.0; }
    Complex(double r = 0.0, double i = 0.0) : r(r), i(i) {}
    Complex operator+(const Complex &p) const { return Complex(r + p.r, i + p.i); }
    Complex operator-(const Complex &p) const { return Complex(r - p.r, i - p.i); }
    Complex operator*(const Complex &p) const {
        return Complex(r * p.r - i * p.i, r * p.i + i * p.r);
    }
};

/// In-place iterative radix-2 FFT.
/// @param a   array of n coefficients; overwritten with the transform
/// @param n   transform length; must be a power of two
/// @param pd  +1 for the forward transform, -1 for the inverse
///            (the inverse additionally divides every entry by n)
/// @param r   bit-reversal permutation of [0, n), r[i] = i with its bits reversed
void FFT(Complex *a, int n, int pd, int *r) {
    // Reorder into bit-reversed order; swap each pair only once.
    for (int i = 0; i < n; i++)
        if (i < r[i])
            std::swap(a[i], a[r[i]]);
    // Butterflies: merge blocks of length mid into blocks of length 2 * mid.
    for (int mid = 1; mid < n; mid <<= 1) {
        // Primitive (2 * mid)-th root of unity; its sign is chosen by pd.
        Complex wn(std::cos(pi / mid), pd * std::sin(pi / mid));
        for (int len = mid << 1, j = 0; j < n; j += len) {
            Complex w(1.0, 0.0);
            for (int k = 0; k < mid; k++, w = w * wn) {
                Complex u = a[j + k], v = w * a[j + k + mid];
                a[j + k] = u + v;
                a[j + k + mid] = u - v;
            }
        }
    }
    if (pd < 0)
        for (int i = 0; i < n; i++)
            a[i] = Complex(a[i].r / n, a[i].i / n);
}

// Kept on their own lines: when merged, the macro body swallowed N's declaration.
#define ll long long
// Capacity of the global arrays. The stripped original digit is unknown.
// NOTE(review): this must exceed both the vertex count and every FFT length
// (the next power of two above 2 * max depth) — confirm against the problem limits.
const int N = 2e5 + 10;
// ---- Global state shared by the centroid decomposition and the FFT ----
int n;                    // number of vertices read in main()
int sz[N];                // subtree sizes computed by getroot()
int maxsz[N];             // largest remaining piece if the vertex were removed
int root;                 // best centroid candidate found so far
int totsz;                // size of the component getroot() is scanning
std::vector<int> vec[N];  // adjacency lists
int prime[N];             // primes below N, stored from index 1 by init()
int prin;                 // number of primes stored in prime[]
bool vis[N];              // set once a vertex has been used as a centroid
bool is[N];               // set by init() for composite numbers
ll cnt[N];                // cnt[d]: accumulated pair count for distance d
ll ccnt[N];               // scratch depth histogram used inside cal()
int dis[N];               // depth of a vertex below the current start vertex
int r[N];                 // FFT bit-reversal table
Complex A[N];             // FFT work buffer
int limit, l; void init() {
for (int i = ; i < N; i++) {
if (!is[i]) prime[++prin] = i;
for (int j = ; j <= prin && i * prime[j] < N; j++) {
is[i * prime[j]] = ;
if (i % prime[j] == ) break;
}
}
} inline bool chkmax(int &a, int b) { return a < b ? a = b, : ; } void getroot(int u, int fa) {
sz[u] = ; maxsz[u] = ;
for (int v : vec[u]) {
if (v == fa || vis[v]) continue;
getroot(v, u);
sz[u] += sz[v];
chkmax(maxsz[u], sz[v]);
}
chkmax(maxsz[u], totsz - sz[u]);
if (maxsz[u] < maxsz[root]) root = u;
} int f[N], tto, val; void getdis(int u, int fa) {
f[++tto] = dis[u];
val = std::max(val, f[tto]);
for (int v : vec[u]) {
if (vis[v] || v == fa) continue;
dis[v] = dis[u] + ;
getdis(v, u);
}
} void cal(int u, int d, int opt) {
tto = ;
dis[u] = d;
val = ;
getdis(u, );
for (int i = ; i <= tto; i++)
ccnt[f[i]]++;
limit = , l = ;
while (limit <= * val)
limit <<= , l++;
for (int i = ; i < limit; i++)
r[i] = r[i >> ] >> | ((i & ) << (l - ));
for (int i = ; i < limit; i++)
A[i] = Complex((double)ccnt[i], 0.0);
FFT(A, limit, , r);
for (int i = ; i < limit; i++)
A[i] = A[i] * A[i];
FFT(A, limit, -, r);
for (int i = ; i < limit; i++)
cnt[i] += opt * (ll)(A[i].r + 0.5);
for (int i = ; i <= tto; i++)
ccnt[f[i]]--;
} void solve(int u) {
vis[u] = ;
cal(u, , );
for (int v : vec[u]) {
if (vis[v]) continue;
cal(v, , -);
totsz = sz[v];
root = ;
getroot(v, );
solve(root);
}
} int main() {
init();
scanf("%d", &n);
for (int i = ; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
vec[u].push_back(v);
vec[v].push_back(u);
}
maxsz[root = ] = n;
totsz = n;
getroot(, );
solve(root);
ll ans = ;
for (int i = ; i <= prin; i++) {
ans += cnt[prime[i]];
}
ll sum = 1LL * n * (n - );
printf("%.7f\n", 1.0 * ans / sum);
return ;
}
上一篇:effective java 笔记1--序言


下一篇:Linux下查看Nginx安装目录、版本号信息?