题意:
给出一棵树,让你寻找一条路径,使得路径上的点相乘 \(mod\ 10^6+3\) 等于 \(k\),输出路径的两个端点,按照字典序最小输出。
分析:
树上路径问题,点分治。
按点分治的思路写即可。注意的是,这里不是边,而是点。并且,因为是乘积,不用每次都要遍历所有的点进行寻找,预处理出模数以内的数的逆元,直接判断是否存在即可。由于是多组输入,但不能每次都清空,应该会超时。同时记录下该值是在哪一次 \(solve\) 时出现的。
此外,此处的点分治采用遍历子树的方法,就不用容斥来减去不满足要求的部分。
还有就是处理当前子树时,子树的大小问题。一开始用:\(nt=sz[u]>sz[v]?tnt-sz[v]:sz[u]\),发现超时了,因为 \(tnt\) 写成了 \(nt\),即应该是点 \(v\) 时的 \(nt\),而不上一个 \(nt\)。或者直接用\(nt=sz[u]\) 也可以。
代码:
#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int mod=1e6+3;
typedef pair<ll,int> P;
const int N=1e5+5;
vector<int>pic[N];
int val[N],sz[N],has[mod+100],pt[mod+100];
bool vis[N];
ll inv[mod+100];
P dis[N];
int nt,minn,rt,k,a,b,tol;
void read(int &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=(x<<3)+(x<<1)+ch-'0';
ch=getchar();
}
x*=f;
}
void init()
{
inv[1]=1;
for(int i=2;i<mod;i++)
inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
}
void get(int x,int y)
{
if(x>y)
swap(x,y);
if(x<a)
a=x,b=y;
else if(x==a&&y<b)
a=x,b=y;
else if(a==-1||b==-1)
a=x,b=y;
}
void dfs(int v,int p)
{
sz[v]=1;
int res=0;
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
dfs(u,v);
sz[v]+=sz[u];
res=max(res,sz[u]);
}
res=max(res,nt-sz[v]);
if(res<minn)
{
minn=res;
rt=v;
}
}
void dfs2(int v,int p,int &cnt,ll d)
{
dis[++cnt]=make_pair(d*val[v]%mod,v);
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
dfs2(u,v,cnt,1LL*d*val[v]%mod);
}
}
void solve(int v,int p)
{
has[val[v]]=v;
pt[val[v]]=++tol;
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
int cnt=0;
dfs2(u,v,cnt,val[v]);
for(int j=1;j<=cnt;j++)
{
P t=dis[j];
ll tmp=1LL*k*val[v]%mod*inv[t.first%mod]%mod;
if(pt[tmp]==tol&&has[tmp])
{
int x=has[tmp],y=t.second;
get(x,y);
}
}
for(int j=1;j<=cnt;j++)
{
P t=dis[j];
if(has[t.first%mod]==0||pt[t.first%mod]!=tol)
has[t.first%mod]=t.second;
else
has[t.first%mod]=min(has[t.first%mod],t.second);
pt[t.first%mod]=tol;
}
}
}
void divide(int v,int p)
{
solve(v,p);
vis[v]=1;
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
nt=sz[u],minn=N;//!!!
dfs(u,v);
divide(rt,rt);
}
}
int main()
{
int n,x,y;
init();
while(scanf("%d%d",&n,&k)!=EOF)
{
a=-1,b=-1;
for(int i=1;i<=n;i++)
{
pic[i].clear();//没有清空,一直爆栈
read(val[i]);
vis[i]=0;
}
for(int i=1;i<n;i++)
{
read(x),read(y);
pic[x].pb(y);
pic[y].pb(x);
}
nt=n,minn=N;
dfs(1,1);
divide(rt,rt);
if(a==-1||b==-1)
printf("No solution\n");
else
printf("%d %d\n",a,b);
}
return 0;
}
另一种子树大小求法:
#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int mod=1e6+3;
typedef pair<ll,int> P;
const int N=1e5+5;
vector<int>pic[N];
int val[N],sz[N],has[mod+100],pt[mod+100];
bool vis[N];
ll inv[mod+100];
P dis[N];
int nt,minn,rt,k,a,b,tol;
void read(int &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=(x<<3)+(x<<1)+ch-'0';
ch=getchar();
}
x*=f;
}
void init()
{
inv[1]=1;
for(int i=2;i<mod;i++)
inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
}
void get(int x,int y)
{
if(x>y)
swap(x,y);
if(x<a)
a=x,b=y;
else if(x==a&&y<b)
a=x,b=y;
else if(a==-1||b==-1)
a=x,b=y;
}
void dfs(int v,int p)
{
sz[v]=1;
int res=0;
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
dfs(u,v);
sz[v]+=sz[u];
res=max(res,sz[u]);
}
res=max(res,nt-sz[v]);
if(res<minn)
{
minn=res;
rt=v;
}
}
void dfs2(int v,int p,int &cnt,ll d)
{
dis[++cnt]=make_pair(d*val[v]%mod,v);
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
dfs2(u,v,cnt,1LL*d*val[v]%mod);
}
}
void solve(int v,int p)
{
has[val[v]]=v;
pt[val[v]]=++tol;
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
int cnt=0;
dfs2(u,v,cnt,val[v]);
for(int j=1;j<=cnt;j++)
{
P t=dis[j];
ll tmp=1LL*k*val[v]%mod*inv[t.first%mod]%mod;
if(pt[tmp]==tol&&has[tmp])
{
int x=has[tmp],y=t.second;
get(x,y);
}
}
for(int j=1;j<=cnt;j++)
{
P t=dis[j];
if(has[t.first%mod]==0||pt[t.first%mod]!=tol)
has[t.first%mod]=t.second;
else
has[t.first%mod]=min(has[t.first%mod],t.second);
pt[t.first%mod]=tol;
}
}
}
void divide(int v,int p)
{
solve(v,p);
vis[v]=1;
int tnt=nt;
for(int i=0;i<pic[v].size();i++)
{
int u=pic[v][i];
if(u==p||vis[u])
continue;
nt=sz[u]>sz[v]?tnt-sz[v]:sz[u],minn=N;//!!!
dfs(u,v);
divide(rt,rt);
}
}
int main()
{
int n,x,y;
init();
while(scanf("%d%d",&n,&k)!=EOF)
{
a=-1,b=-1;
for(int i=1;i<=n;i++)
{
pic[i].clear();//没有清空,一直爆栈
read(val[i]);
vis[i]=0;
}
for(int i=1;i<n;i++)
{
read(x),read(y);
pic[x].pb(y);
pic[y].pb(x);
}
nt=n,minn=N;
dfs(1,1);
divide(rt,rt);
if(a==-1||b==-1)
printf("No solution\n");
else
printf("%d %d\n",a,b);
}
return 0;
}
/*
https://blog.csdn.net/jtjy568805874/article/details/51332768
*/