The 2021 ICPC Asia Regionals Online Contest (II) L Euler Function (数论,线段树)

The 2021 ICPC Asia Regionals Online Contest (II)  L Euler Function  (数论,线段树)

  • 题意:一个长度为\(n\)的序列\(x\),\(m\)次操作,有两种,一种是对区间\([l,r]\)的数乘上\(w\),一种是询问\([l,r]\)的每个数的欧拉函数之和。

  • 题解:首先看数据范围,\(x[i]\)和\(w\)的值都很小,最大才\(100\),根据欧拉函数公式:\(\phi [N]=N*(1-\frac{1}{p_1})*...*(1-\frac{1}{p_n})\),\(p_i\)表示\(N\)质因数分解后的质数,那么有如下的性质:

    对于一个质数P,\(\phi [N*P]=\phi [N]*P\ \ if(N|P)\)和\(\phi [N*P]=\phi [N]*(P-1)\ \ if(N!|P)\)

    那么对于每个\(w\),先质因数分解,然后线段树维护每个质因数,并开25个\(tag\)进行标记,如果当前区间全部包含某个质因数,就可以直接算。

  • 代码

    #include <iostream>
    #include <vector>
    #include <set>
    #include <algorithm>
    #include <unordered_map>
    using namespace std;
    #define ll long long
    #define pb push_back
    const int N=1e5+10;
    const int mod=998244353;
    
    
    int n,m;
    ll x[N];
    ll res[200];  //欧拉函数值
    int cntp[200][200];  //记录每个i的质数j的出现次数
    bool vis[200];
    int prime[100],cnt;
    
    void get_prime(int n){
        for(int i=2;i<=n;++i){
           if(!vis[i]) prime[cnt++]=i;
           for(int j=0;j<cnt && prime[j]<=n/i;++j){
               vis[i*prime[j]]=true;
               if(i%prime[j]==0) break;
           } 
        } 
    }
    
    struct Node{
        int l,r;
        ll val; ll mul;
        bool tag[26];
    }tr[N<<4];
    
    void push_up(int u){
        for(int i=0;i<25;++i){  //如果两个儿子都有这个质因数,父亲才有
            if(tr[u<<1].tag[i] && tr[u<<1|1].tag[i]){
                tr[u].tag[i]=true;
            }
        }
        tr[u].val=(tr[u<<1].val+tr[u<<1|1].val)%mod;
    }
    
    void push_down(int u){
        tr[u<<1].val=(tr[u<<1].val*tr[u].mul)%mod;
        tr[u<<1|1].val=(tr[u<<1|1].val*tr[u].mul)%mod;
        tr[u<<1].mul=(tr[u<<1].mul*tr[u].mul)%mod;
        tr[u<<1|1].mul=(tr[u<<1|1].mul*tr[u].mul)%mod;
        tr[u].mul=1;
    }
    
    void build(int u,int l,int r){
        if(l==r){
            tr[u]={l,r,res[x[l]],1};
            for(int i=0;i<25;++i){
                if(cntp[x[l]][i]) tr[u].tag[i]=true;
            }
            return;
        }
        tr[u]={l,r,1,1};
        int mid=(tr[u].l+tr[u].r)>>1;
        build(u<<1,l,mid);
        build(u<<1|1,mid+1,r);
        push_up(u);
    }
    
    void update(int u,int l,int r,int id,int cnt){
        if(tr[u].l>=l && tr[u].r<=r && tr[u].tag[id]){   //质因数出现过
            for(int i=1;i<=cnt;++i){
                tr[u].val=(tr[u].val*prime[id])%mod;
                tr[u].mul=(tr[u].mul*prime[id])%mod;
            }
            return;
        }
        if(tr[u].l==tr[u].r){    //质因数没有出现过
            tr[u].val=(tr[u].val*(prime[id]-1))%mod; //先算一次
            for(int i=1;i<cnt;++i){
                tr[u].val=(tr[u].val*prime[id])%mod;
            }
            tr[u].tag[id]=true;
            return;
        }
        if(tr[u].mul!=1) push_down(u);
        int mid=(tr[u].l+tr[u].r)>>1;
        if(l<=mid) update(u<<1,l,r,id,cnt);
        if(r>mid) update(u<<1|1,l,r,id,cnt);
        push_up(u);
    }
    
    ll query(int u,int l,int r){
        if(tr[u].l>=l && tr[u].r<=r){
            return tr[u].val;
        }
        if(tr[u].mul!=1) push_down(u);
        int mid=(tr[u].l+tr[u].r)>>1;
        ll sum=0;
        if(l<=mid) sum=(sum+query(u<<1,l,r))%mod;
        if(r>mid) sum=(sum+query(u<<1|1,l,r))%mod;
        return sum;
    }
     
    int main(){
        get_prime(110);
        scanf("%d %d",&n,&m);
        for(int i=1;i<=n;++i){
            scanf("%lld",&x[i]);
        }
        for(int i=1;i<=100;++i){
            int tmp=i;
            res[i]=i;
            for(int j=0;j<25;++j){
                if(tmp%prime[j]==0){
                    res[i]=res[i]/prime[j]*(prime[j]-1);
                    while(tmp%prime[j]==0) tmp/=prime[j],cntp[i][j]++;
                }
            }
        }
        build(1,1,n);
        for(int i=1;i<=m;++i){
            int op;
            scanf("%d",&op);
            if(op==1){
                int l,r;
                scanf("%d %d",&l,&r);
                printf("%lld\n",query(1,l,r));
            }
            else{
                int l,r,x;
                scanf("%d %d %d",&l,&r,&x);
                for(int i=0;i<25;++i){
                    if(cntp[x][i]) update(1,l,r,i,cntp[x][i]);
                }
            }
        }
        return 0;
    }
    
上一篇:关于表格输出的反思


下一篇:跨站跟踪攻击(CST/XST)