[CF191C] Fools and Roads - LCA,树上差分

Description

有一颗 \(n\) 个节点的树,\(k\) 次旅行,问每一条边被走过的次数。

Solution

模板题,复习一下 LCA 与树上差分的基本用法。

首先,对于每条边被走过的次数,我们把它转化为深度较大的那个点被经过的次数。

对于每一次旅行 \(begin,end\),计算 \(lca=LCA(begin,end)\),则每次操作转化为对 \(lca \sim begin\) 修改一次,对 \(lca \sim end\) 修改一次,对 \(lca\) 自身反向修改两次。

LCA 的求法很多,这里复习一下 ST 表求 LCA。

首先我们需要预处理出每个点在它的 Euler 遍历序上第一次出现的位置 \(first_occur[p]\) 和最后一次出现的位置 \(last_occur[p]\),那么对于一 \((p,q)\),其 LCA 即为 \(last_occur[p] \sim first_occfur[q]\) 这一的 Euler 遍历序列中出现过的深度最小的点。

所谓 Euler 遍历序列是这样一种 DFS 序列,每次我们访问一个点(包括回溯到到它)都将其扔进队列的尾部,因此这样我们得到的序列长度是总的度数 \(+1\)。

ST 表就被用来支持 Euler 遍历序列的 RMQ 问题,相当于我们构造了一种 Sequence,其支持对其任意一段以对应关键字的 RMQ,其 key 即为树上的深度 depth,显然我们需要维护的是深度最小的点。

在 Sequence 的实现中,我们预处理出 \(f[i][j]\) 表示在 \([i,i+2^j)\) 这个区间中的 \(key\) 最小的点是哪一个。

整个实现的结构是这样,Sequence 提供一个支持区间最小关键字位置询问的操作接口,Tree 类用于树的存储,自己进行 DFS,并构建一个 Euler 遍历序列以 Sequence 形式存储,预处理的时候调用 Sequence 类的预处理,在遇到 LCA 询问的时候去调用 Sequence 提供的 RMQ 接口。对答案的计算放在主程序中进行。

#include <bits/stdc++.h>
using namespace std;

#define dbg(x) ;

struct Sequence
{
private:
    int n;
    vector<int> a;
    vector<int> key;
    vector<vector<int>> f;

public:
    Sequence(int n = 1) : n(n)
    {
        a.resize(n + 2);
        f.resize(n + 2);
        key.resize(n + 2);
        int tmp = log2(n) + 2;
        for (auto &i : f)
        {
            i.resize(tmp);
        }
    }

    void Resize(int n_)
    {
        n = n_;

        a.resize(n + 2);
        f.resize(n + 2);
        key.resize(n + 2);
        int tmp = log2(n) + 2;
        for (auto &i : f)
        {
            i.resize(tmp);
        }
    }

    void Set(int id, int val, int valkey)
    {
        a[id] = val;
        key[id] = valkey;
    }

    void Presolve()
    {
        for (int i = 1; i <= n; i++)
        {
            f[i][0] = i;
        }

        int tmp = log2(n) + 2;

        for (int j = 1; j < tmp; j++)
        {
            int disp = 1 << (j - 1);
            for (int i = 1; i + disp <= n; i++)
            {
                if (key[f[i][j - 1]] < key[f[i + disp][j - 1]])
                {
                    f[i][j] = f[i][j - 1];
                }
                else
                {
                    f[i][j] = f[i + disp][j - 1];
                }
            }
            for (int i = n - disp + 1; i <= n; i++)
            {
                f[i][j] = f[i][j - 1];
            }
        }
    }

    int QueryPos(int l, int r)
    {
        int lg2 = log2(r - l + 1);
        int mid = r - (1 << lg2) + 1;
        if (key[f[l][lg2]] < key[f[mid][lg2]])
        {
            return f[l][lg2];
        }
        else
        {
            return f[mid][lg2];
        }
    }

    int QueryVal(int l, int r)
    {
        return a[QueryPos(l, r)];
    }

    int QueryKey(int l, int r)
    {
        return key[QueryPos(l, r)];
    }

    int bfquery(int l, int r)
    {
        int ans = l;
        for (int i = l + 1; i <= r; i++)
        {
            if (key[i] < key[ans])
            {
                ans = i;
            }
        }
        return ans;
    }

    bool check(int l, int r)
    {
        if (bfquery(l, r) != QueryPos(l, r))
        {
            return false;
        }
        return true;
    }
};

struct Tree
{
    vector<vector<int>> g;
    vector<int> dep;
    vector<int> euler_seq;
    vector<int> euler_begin;
    vector<int> euler_end;
    vector<int> father;

    Sequence seq;

    int n;
    Tree(int n) : n(n)
    {
        g.resize(n + 2);
        dep.resize(n + 2);
        father.resize(n + 2);
        euler_begin.resize(n + 2);
        euler_end.resize(n + 2);
        euler_seq.resize(1);
    }

public:
    void Make(int p, int q)
    {
        g[p].push_back(q);
        g[q].push_back(p);
    }

private:
    void dfs(int p, int fa)
    {

        euler_seq.push_back(p);
        euler_begin[p] = euler_seq.size() - 1;

        for (int q : g[p])
        {
            if (q != fa)
            {
                dep[q] = dep[p] + 1;
                father[q] = p;
                dfs(q, p);
                euler_seq.push_back(p);
            }
        }
        euler_end[p] = euler_seq.size() - 1;
    }

private:
    void dfs()
    {
        dep[1] = 1;
        dfs(1, 0);
    }

public:
    void dbg_eulerseqtest()
    {

    }

public:
    void Presolve()
    {
        this->dfs();
        // this->dbg_eulerseqtest();

        int len = euler_seq.size() - 1;

        seq.Resize(len);

        for (int i = 1; i <= len; i++)
        {
            seq.Set(i, euler_seq[i], dep[euler_seq[i]]);
        }

        seq.Presolve();
    }

    int QueryLCA(int p, int q)
    {
        if (euler_end[p] <= euler_begin[q])
        {
            return seq.QueryVal(euler_end[p], euler_begin[q]);
        }
        if (euler_end[q] <= euler_begin[p])
        {
            return seq.QueryVal(euler_end[q], euler_begin[p]);
        }
        if (euler_begin[p] <= euler_begin[q])
        {
            return seq.QueryVal(euler_begin[p], euler_begin[q]);
        }
        if (euler_begin[q] <= euler_begin[p])
        {
            return seq.QueryVal(euler_begin[q], euler_begin[p]);
        }
        return -1;
    }

private:
    void dfs2(int p, int fa, vector<int> &a)
    {
        for (int q : g[p])
        {
            if (q != fa)
            {
                dfs2(q, p, a);
                a[p] += a[q];
            }
        }
    }

public:
    void CalculateSum(vector<int> &a)
    {
        dfs2(1, 0, a);
    }

public:
    int Edge2Vertex(int p, int q)
    {
        if (dep[p] > dep[q])
        {
            return p;
        }
        else
        {
            return q;
        }
    }

public:
    int GetFather(int p)
    {
        return father[p];
    }
};

void dbg_sequencetest()
{
    Sequence seq(10);
    for (int i = 1; i <= 10; i++)
    {
        seq.Set(i, rand(), rand());
    }
    seq.Presolve();
    for (int i = 1; i <= 100; i++)
    {
        int l = rand() % 10 + 1;
        int r = rand() % 10 + 1;
        if (l > r)
            swap(l, r);
        if (seq.check(l, r) == 0)
        {
        }
    }
}

int main()
{
    // dbg_sequencetest();

    int n;
    cin >> n;

    vector<pair<int, int>> edge(n + 2);

    Tree tree(n);

    for (int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        tree.Make(u, v);
        edge[i] = {u, v};
    }

    tree.Presolve();

    int k;
    cin >> k;

    vector<int> cnt;
    cnt.resize(n + 2);

    for (int i = 1; i <= k; i++)
    {
        int u, v;
        cin >> u >> v;
        int lca = tree.QueryLCA(u, v);
      

        cnt[u]++;
        cnt[v]++;
        cnt[(lca)] -=2;
    }


    tree.CalculateSum(cnt);

    for (int i = 1; i < n; i++)
    {
        int id = tree.Edge2Vertex(edge[i].first, edge[i].second);
        cout << cnt[id] << " ";
    }
}
上一篇:邮件读取协议POP3和IMAP


下一篇:java大作业--邮件系统