题目
题解
首先,题目中要求我们求的是哪些路径的点积是一个立方数,而每个点的点值都可以被他给定的 \(k\) 个质数组成。
考虑将每个点的点值分解为三进制,第 \(i\) 位表示这个点的点值可以被 \(prime[i]\) 的 \(t\bmod 3\) 次方组成。
至于为什么 \(\bmod 3\) 其实比较显然。
然后,对于每个点,我们定义 map<Ternary,int>MP[i]
,键值 Ternary
即为我们重载的三进制,这个表示以 \(i\) 为根,在它的子树中路径点积为 Ternary
的数量。
但是,由于启发式合并,我们的根有可能被更改,即 MP[i]
有时候不一定对应点 \(i\),所以我们用一个 rt[i]
表示 \(i\) 真正对应的 MP
到底是哪一个。
启发式合并时,由于我们可能会更改根,涉及到对于 MP
的整个修改,所以我们用一个 tag
作为懒标记,表示整体增加或者减少了多少。
那么,合并之前我们统计答案时,只需要从小的 MP
中找大的 MP
相对应的数值是多少,用组合即数量乘积表示答案。
统计完答案之后,考虑合并 MP
,在这之前首先需要修改子树对应的 MP
的 tag
,因为合并之后,子树的 MP
应该以当前点为根,而在这之前,子树的 MP
一直以那个子节点为根,所以 tag[rt[v]]+=val[u]
。
代码
#include<bits/stdc++.h>
using namespace std;
#define rep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i<=i##_end_;++i)
#define fep(i,__l,__r) for(signed i=(__l),i##_end_=(__r);i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define uint unsigned int
#define pii pair< int,int >
#define Endl putchar('\n')
// #define int long long
// #define int unsigned
// #define int unsigned long long
#ifdef _GLIBCXX_CSTDIO
#define cg (c=getchar())
template<class T>inline void qread(T& x){
char c;bool f=0;
while(cg<'0'||'9'<c)f|=(c=='-');
for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
if(f)x=-x;
}
template<class T>inline T qread(const T sample){
T x=0;char c;bool f=0;
while(cg<'0'||'9'<c)f|=(c=='-');
for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
return f?-x:x;
}
#undef cg
template<class T>void fwrit(const T x){//just short,int and long long
if(x<0)return (void)(putchar('-'),fwrit(-x));
if(x>9)fwrit(x/10);
putchar(x%10^48);
}
#endif
// template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}
const int MAXN=50000;
const int MAXK=30;
struct Ternary{
LL num;
Ternary():num(0ll){}
bool operator <(const Ternary rhs)const{
return num<rhs.num;
}
Ternary operator +(const Ternary rhs)const{
LL t=1,f,x1=num,x2=rhs.num;
Ternary ret;
while(x1>0 || x2>0){
f=((x1%3)+(x2%3))%3;
ret.num+=f*t;
x1/=3,x2/=3;
t*=3;
}return ret;
}
Ternary operator -(const Ternary rhs)const{
LL t=1,f,x1=num,x2=rhs.num;
Ternary ret;
while(x1>0 || x2>0){
f=((x1%3)-(x2%3)+3)%3;
ret.num+=f*t;
x1/=3,x2/=3;
t*=3;
}return ret;
}
Ternary operator +=(const Ternary rhs){
return (*this)=(*this)+rhs;
}
Ternary operator -=(const Ternary rhs){
return (*this)=(*this)-rhs;
}
}val[MAXN+5],tag[MAXN+5];//表示单个点的 val 和懒标记
map<Ternary,int>MP[MAXN+5];
LL prime[MAXK+5],ans;
int rt[MAXN+5],n,k;
struct edge{int to,nxt;}e[(MAXN<<1)|2];
int tail[MAXN+5],ecnt;
inline void add_edge(const int u,const int v){
e[++ecnt]=edge{v,tail[u]};tail[u]=ecnt;
e[++ecnt]=edge{u,tail[v]};tail[v]=ecnt;
}
int Merge(int x,int y,const int u){
bool haveswap=false;
if(MP[x].size()<MP[y].size())swap(x,y),haveswap=true;
Ternary tmp;
for(auto it=MP[y].begin();it!=MP[y].end();++it){
tmp=Ternary()-(it->first+tag[y]+tag[x]);
if(MP[x].count(tmp))
ans=ans+1ll*MP[x][tmp]*it->second;
}
if(haveswap)tag[x]+=val[u];
else tag[y]+=val[u];
for(auto it=MP[y].begin();it!=MP[y].end();++it){
tmp=it->first+tag[y]-tag[x];
MP[x][tmp]+=it->second;
}
MP[y].clear();
return x;
}
void dfs(const int u,const int fa){
rt[u]=u;
MP[u][val[u]]=1;
for(int i=tail[u],v;i;i=e[i].nxt)if((v=e[i].to)^fa){
dfs(v,u);
rt[u]=Merge(rt[u],rt[v],u);
}
}
inline void Init(){
while(~scanf("%d %d",&n,&k)){
ecnt=0;ans=0;
rep(i,1,k)prime[i]=qread(1ll);
LL x,cnt,f;
rep(t,1,n){
val[t].num=0,tail[t]=0,tag[t].num=0;
x=qread(1ll),f=1;
rep(i,1,k){cnt=0;
while(x%prime[i]==0){
x/=prime[i];
(++cnt)%=3;
}val[t].num+=cnt*f;
f*=3;
}
if(!val[t].num)++ans;//可以从自己到自己
}
int u,v;
rep(i,1,n-1){
u=qread(1),v=qread(1);
add_edge(u,v);
}
dfs(1,0);
printf("%lld\n",ans);
MP[rt[1]].clear();
}
}
signed main(){
// ios::sync_with_stdio(false);
Init();
return 0;
}