T1
首先考虑 60 pts,若们可以将 t 看做一个集合,里面的元素与 t 或为 t 并且是3的倍数,然后可以容斥,\(O(t^2\log n)\)
考虑到这一过程瓶颈在于枚举 t 的子集和计算 t 集合中3的倍数的个数,也就是我们不要具体是什么,而需要数量
由于 \(2^k \mod 3\)只有 1 或 2 两种取值,可以枚举 1 的个数和 2 的个数然后匹配
接着考虑枚举子集,由于模3为1的位和模3为2的位是分别等价的,所以我们计算时跑个背包然后乘上组合系数容斥就行了
T2
由期望的线性性,\(E(最短距离)=2* E(虚树大小)-E(虚树直径)\)
这时分别计算,对于虚树大小可以枚举每条边贡献多少次,一条边被计算的当且仅当两边都有饼干
然后计算直径的期望,可以先钦定某两点是直径,然后在合法的点中选k-2个
此时合法的点可以\(O(m)\)求出,设钦定u,v,此时正在判定w
若dis(u,v)<dis(u,w)或dis(u,v)<dis(v,w),显然不合法
若dis(u,v)=dis(u,w)且v>w,或dis(u,v)=dis(v,w)且u>w,规定其不合法
此时不会被算重,复杂度\(O(m^3)\)
代码
T1
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 111;
const int maxn = 60;
const int mod = 998244353;
int n, t;
int a[N], num;
int sum[3];
int f[N][3];
int c[N][N];
inline int read() {
int s = 0;
char ch = getchar();
while (ch > '9' || ch < '0') ch = getchar();
while (ch >= '0' && ch <= '9') {
s = (s << 1) + (s << 3) + (ch ^ 48);
ch = getchar();
}
return s;
}
inline void md(int& x) {
if (x >= mod)
x -= mod;
return;
}
inline int max_(int x, int y) { return x > y ? x : y; }
inline int min_(int x, int y) { return x > y ? y : x; }
int fm(int x, int y) {
int ans = 1;
while (y) {
if (y & 1)
ans = ans * x % mod;
x = x * x % mod, y >>= 1;
}
return ans;
}
int solve(int x) {
int ans = 0;
for (int i = max_(0ll, x - sum[2]); i <= min_(x, sum[1]); ++i) {
memset(f, 0, sizeof(f));
f[0][0] = 1;
for (int j = 1; j <= i; ++j)
for (int k = 0; k <= 2; ++k) md(f[j][k] += f[j - 1][(k - 1 + 3) % 3] + f[j - 1][k]);
for (int j = i + 1; j <= x; ++j)
for (int k = 0; k <= 2; ++k) md(f[j][k] += f[j - 1][(k - 2 + 3) % 3] + f[j - 1][k]);
md(ans += fm(f[x][0], n) * c[sum[1]][i] % mod * c[sum[2]][x - i] % mod);
}
return ans;
}
signed main() {
FILE* x = freopen("or.in", "r", stdin);
x = freopen("or.out", "w", stdout);
c[0][0] = 1;
for (int i = 1; i <= maxn; ++i) {
c[i][0] = 1;
for (int j = 1; j <= i; ++j) md(c[i][j] += c[i - 1][j - 1] + c[i - 1][j]);
}
n = read();
t = read();
for (int i = 0; i < maxn; ++i)
if ((1ll << i) & t)
++sum[(1ll << i) % 3];
int ans = 0;
for (int i = 0; i <= sum[1] + sum[2]; ++i)
if ((sum[1] + sum[2] - i) & 1)
ans = (ans + mod - solve(i)) % mod;
else
ans = (ans + solve(i)) % mod;
cout << ans << endl;
return 0;
}
T2
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 2e3 + 11;
const int mod = 998244353;
struct qxxx {
int v, next;
} cc[2 * N];
bool jud[N];
bool vis[N];
int key[N];
int n, m, k;
int sum[N];
int ans1, ans2;
int first[N], cnt;
int dep[N], pg[N], f[N][20];
int c[N][N];
int d[N][N];
inline int read() {
int s = 0;
char ch = getchar();
while (ch > '9' || ch < '0') ch = getchar();
while (ch >= '0' && ch <= '9') s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
return s;
}
inline int min_(int x, int y) { return x > y ? y : x; }
inline void md(int& x) {
if (x >= mod)
x -= mod;
return;
}
void qxx(int u, int v) {
cc[++cnt] = { v, first[u] };
first[u] = cnt;
return;
}
int fm(int x, int y) {
int ans = 1;
while (y) {
if (y & 1)
ans = ans * x % mod;
y >>= 1, x = x * x % mod;
}
return ans;
}
int get_lca(int x, int y) {
if (dep[x] < dep[y])
swap(x, y);
for (int i = pg[dep[x] - dep[y]]; i >= 0; --i)
if (dep[f[x][i]] >= dep[y])
x = f[x][i];
if (x == y)
return x;
for (int i = pg[dep[x]]; i >= 0; --i)
if (f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
if (x != y)
x = f[x][0];
return x;
}
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[get_lca(x, y)]; }
void dfs(int x, int fa) {
for (int i = 1; i <= pg[dep[x]]; ++i) f[x][i] = f[f[x][i - 1]][i - 1];
sum[x] = jud[x];
int ed = 0;
for (int i = first[x]; i; i = cc[i].next)
if (cc[i].v != fa) {
dep[cc[i].v] = dep[x] + 1;
f[cc[i].v][0] = x;
dfs(cc[i].v, x);
sum[x] += sum[cc[i].v];
ed = min_(k - 1, sum[cc[i].v]);
for (int j = 1; j <= ed; ++j) md(ans1 += c[sum[cc[i].v]][j] * c[m - sum[cc[i].v]][k - j] % mod);
}
return;
}
void pre() {
c[0][0] = 1;
for (int i = 1; i <= 2e3; ++i) {
c[i][0] = 1;
for (int j = 1; j <= i; ++j) md(c[i][j] += c[i - 1][j - 1] + c[i - 1][j]);
}
return;
}
signed main() {
FILE* p = freopen("tree.in", "r", stdin);
p = freopen("tree.out", "w", stdout);
pre();
n = read(), m = read(), k = read();
for (int i = 1; i <= m; ++i) jud[key[i] = read()] = 1;
for (int x, y, i = 1; i < n; ++i) x = read(), y = read(), qxx(x, y), qxx(y, x), pg[i] = log2(i);
dep[1] = 1;
dfs(1, 1);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j) d[i][j] = dis(i, j);
for (int i = 1; i <= m; ++i)
for (int ds, sl, x, y, j = i + 1; j <= m; ++j) {
sl = 0;
x = key[i], y = key[j];
ds = dis(x, y);
for (int h = 1; h <= m; ++h) {
if (h == j || h == i)
continue;
if (d[key[h]][x] > ds)
continue;
if (d[key[h]][y] > ds)
continue;
if ((d[key[h]][x] == ds) & (h < j))
continue;
if ((d[key[h]][y] == ds) & (h < i))
continue;
++sl;
}
md(ans2 += ds * c[sl][k - 2] % mod);
}
cout << (2 * ans1 - ans2 + mod) * fm(c[m][k], mod - 2) % mod << endl;
return 0;
}