题目大概说给一棵有点权的树,输出字典序最小的点对,使这两点间路径上点权的乘积模1000003的结果为k。
树的点分治搞了。因为是点权过根的两条路径的LCA会被重复统计,而注意到1000003是质数,所以这个用乘法逆元搞一下就OK了。还有要注意“治”的各个实现,把时间复杂度“控制”在O(nlogn)。
WA了几次,WA在漏了点到子树根的路径,还有每次分治忘了清空数组。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define INF (1<<30)
#define MAXN 111111
struct Edge{
int v,next;
}edge[MAXN<<];
int NE,head[MAXN];
void addEdge(int u,int v){
edge[NE].v=v; edge[NE].next=head[u]; head[u]=NE++;
}
bool vis[MAXN];
int mini,cen,size[MAXN];
void getSize(int u,int fa){
size[u]=;
for(int i=head[u]; i!=-; i=edge[i].next){
int v=edge[i].v;
if(v==fa || vis[v]) continue;
getSize(v,u);
size[u]+=size[v];
}
}
void getCen(int u,int fa,int &tot){
int res=tot-size[u];
for(int i=head[u]; i!=-; i=edge[i].next){
int v=edge[i].v;
if(v==fa || vis[v]) continue;
getCen(v,u,tot);
res=max(res,size[v]);
}
if(res<mini) mini=res,cen=u;
}
int getCen(int u){
getSize(u,u);
mini=INF;
getCen(u,u,size[u]);
return cen;
}
long long ine(long long a){
long long res=,n=;
while(n){
if(n&) res*=a,res%=;
a*=a; a%=;
n>>=;
}
return res;
}
int n,k,val[MAXN];
int ansx,ansy;
int record[],tn,tmpx[MAXN],tmpy[MAXN],all[MAXN],an;
void dfs(int u,int fa,long long dist,int &top){
int v=record[ine(dist)*k%*top%];
if(v){
if(u<v){
if(u<ansx) ansx=u,ansy=v;
else if(u==ansx && v<ansy) ansy=u,ansy=v;
}else{
if(v<ansx) ansx=v,ansy=u;
else if(v==ansx && u<ansy) ansy=v,ansy=u;
}
}
tmpx[tn]=u; tmpy[tn]=dist; ++tn;
all[an++]=dist;
for(int i=head[u]; i!=-; i=edge[i].next){
int v=edge[i].v;
if(v==fa || vis[v]) continue;
dfs(v,u,dist*val[v]%,top);
}
}
void conquer(int u){
an=;
all[an++]=val[u];
record[val[u]]=u;
for(int i=head[u]; i!=-; i=edge[i].next){
int v=edge[i].v;
if(vis[v]) continue;
tn=;
dfs(v,v,(long long)val[u]*val[v]%,val[u]);
for(int j=; j<tn; ++j){
if(record[tmpy[j]]== || record[tmpy[j]]>tmpx[j]) record[tmpy[j]]=tmpx[j];
}
}
for(int i=; i<an; ++i) record[all[i]]=;
}
void divide(int u){
u=getCen(u);
vis[u]=;
conquer(u);
for(int i=head[u]; i!=-; i=edge[i].next){
int v=edge[i].v;
if(vis[v]) continue;
divide(v);
}
}
int main(){
int a,b;
while(~scanf("%d%d",&n,&k)){
for(int i=; i<=n; ++i){
scanf("%d",val+i);
}
NE=;
memset(head,-,sizeof(head));
for(int i=; i<n; ++i){
scanf("%d%d",&a,&b);
addEdge(a,b);
addEdge(b,a);
}
memset(vis,,sizeof(vis));
ansx=ansy=INF;
divide();
if(ansx==INF) puts("No solution");
else printf("%d %d\n",ansx,ansy);
}
return ;
}