D Tree HDU - 4812

题意:

给出一棵树,让你寻找一条路径,使得路径上的点相乘 \(mod\ 10^6+3\) 等于 \(k\),输出路径的两个端点,按照字典序最小输出。

分析:

  树上路径问题,点分治。
  按点分治的思路写即可。注意的是,这里不是边,而是点。并且,因为是乘积,不用每次都要遍历所有的点进行寻找,预处理出模数以内的数的逆元,直接判断是否存在即可。由于是多组输入,但不能每次都清空,应该会超时。同时记录下该值是在哪一次 \(solve\) 时出现的。
  此外,此处的点分治采用遍历子树的方法,就不用容斥来减去不满足要求的部分。
  还有就是处理当前子树时,子树的大小问题。一开始用:\(nt=sz[u]>sz[v]?tnt-sz[v]:sz[u]\),发现超时了,因为 \(tnt\) 写成了 \(nt\),即应该是点 \(v\) 时的 \(nt\),而不上一个 \(nt\)。或者直接用\(nt=sz[u]\) 也可以。

代码:

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int mod=1e6+3;
typedef pair<ll,int> P;
const int N=1e5+5;
vector<int>pic[N];
int val[N],sz[N],has[mod+100],pt[mod+100];
bool vis[N];
ll inv[mod+100];
P dis[N];
int nt,minn,rt,k,a,b,tol;
void read(int &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')
            f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<3)+(x<<1)+ch-'0';
        ch=getchar();
    }
    x*=f;
}
void init()
{
    inv[1]=1;
    for(int i=2;i<mod;i++)
        inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
}
void get(int x,int y)
{
    if(x>y)
        swap(x,y);
    if(x<a)
        a=x,b=y;
    else if(x==a&&y<b)
        a=x,b=y;
    else if(a==-1||b==-1)
        a=x,b=y;
}
void dfs(int v,int p)
{
    sz[v]=1;
    int res=0;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        dfs(u,v);
        sz[v]+=sz[u];
        res=max(res,sz[u]);
    }
    res=max(res,nt-sz[v]);
    if(res<minn)
    {
        minn=res;
        rt=v;
    }
}
void dfs2(int v,int p,int &cnt,ll d)
{
    dis[++cnt]=make_pair(d*val[v]%mod,v);
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        dfs2(u,v,cnt,1LL*d*val[v]%mod);
    }
}
void solve(int v,int p)
{
    has[val[v]]=v;
    pt[val[v]]=++tol;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        int cnt=0;
        dfs2(u,v,cnt,val[v]);
        for(int j=1;j<=cnt;j++)
        {
            P t=dis[j];
            ll tmp=1LL*k*val[v]%mod*inv[t.first%mod]%mod;
            if(pt[tmp]==tol&&has[tmp])
            {
                int x=has[tmp],y=t.second;
                get(x,y);
            }
        }
        for(int j=1;j<=cnt;j++)
        {
            P t=dis[j];
            if(has[t.first%mod]==0||pt[t.first%mod]!=tol)
                has[t.first%mod]=t.second;
            else
                has[t.first%mod]=min(has[t.first%mod],t.second);
            pt[t.first%mod]=tol;
        }
    }
}
void divide(int v,int p)
{
    solve(v,p);
    vis[v]=1;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        nt=sz[u],minn=N;//!!!
        dfs(u,v);
        divide(rt,rt);
    }
}
int main()
{
    int n,x,y;
    init();
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        a=-1,b=-1;
        for(int i=1;i<=n;i++)
        {
            pic[i].clear();//没有清空,一直爆栈
            read(val[i]);
            vis[i]=0;
        }
        for(int i=1;i<n;i++)
        {
            read(x),read(y);
            pic[x].pb(y);
            pic[y].pb(x);
        }
        nt=n,minn=N;
        dfs(1,1);
        divide(rt,rt);
        if(a==-1||b==-1)
            printf("No solution\n");
        else
            printf("%d %d\n",a,b);
    }
    return 0;
}


另一种子树大小求法:

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int mod=1e6+3;
typedef pair<ll,int> P;
const int N=1e5+5;
vector<int>pic[N];
int val[N],sz[N],has[mod+100],pt[mod+100];
bool vis[N];
ll inv[mod+100];
P dis[N];
int nt,minn,rt,k,a,b,tol;
void read(int &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')
            f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<3)+(x<<1)+ch-'0';
        ch=getchar();
    }
    x*=f;
}
void init()
{
    inv[1]=1;
    for(int i=2;i<mod;i++)
        inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
}
void get(int x,int y)
{
    if(x>y)
        swap(x,y);
    if(x<a)
        a=x,b=y;
    else if(x==a&&y<b)
        a=x,b=y;
    else if(a==-1||b==-1)
        a=x,b=y;
}
void dfs(int v,int p)
{
    sz[v]=1;
    int res=0;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        dfs(u,v);
        sz[v]+=sz[u];
        res=max(res,sz[u]);
    }
    res=max(res,nt-sz[v]);
    if(res<minn)
    {
        minn=res;
        rt=v;
    }
}
void dfs2(int v,int p,int &cnt,ll d)
{
    dis[++cnt]=make_pair(d*val[v]%mod,v);
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        dfs2(u,v,cnt,1LL*d*val[v]%mod);
    }
}
void solve(int v,int p)
{
    has[val[v]]=v;
    pt[val[v]]=++tol;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        int cnt=0;
        dfs2(u,v,cnt,val[v]);
        for(int j=1;j<=cnt;j++)
        {
            P t=dis[j];
            ll tmp=1LL*k*val[v]%mod*inv[t.first%mod]%mod;
            if(pt[tmp]==tol&&has[tmp])
            {
                int x=has[tmp],y=t.second;
                get(x,y);
            }
        }
        for(int j=1;j<=cnt;j++)
        {
            P t=dis[j];
            if(has[t.first%mod]==0||pt[t.first%mod]!=tol)
                has[t.first%mod]=t.second;
            else
                has[t.first%mod]=min(has[t.first%mod],t.second);
            pt[t.first%mod]=tol;
        }
    }
}
void divide(int v,int p)
{
    solve(v,p);
    vis[v]=1;
    int tnt=nt;
    for(int i=0;i<pic[v].size();i++)
    {
        int u=pic[v][i];
        if(u==p||vis[u])
            continue;
        nt=sz[u]>sz[v]?tnt-sz[v]:sz[u],minn=N;//!!!
        dfs(u,v);
        divide(rt,rt);
    }
}
int main()
{
    int n,x,y;
    init();
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        a=-1,b=-1;
        for(int i=1;i<=n;i++)
        {
            pic[i].clear();//没有清空,一直爆栈
            read(val[i]);
            vis[i]=0;
        }
        for(int i=1;i<n;i++)
        {
            read(x),read(y);
            pic[x].pb(y);
            pic[y].pb(x);
        }
        nt=n,minn=N;
        dfs(1,1);
        divide(rt,rt);
        if(a==-1||b==-1)
            printf("No solution\n");
        else
            printf("%d %d\n",a,b);
    }
    return 0;
}
/*
https://blog.csdn.net/jtjy568805874/article/details/51332768
*/

上一篇:P2371 [国家集训队]墨墨的等式


下一篇:匹配统计