dsu on tree 是一种处理树上不带修改的离线子树问题的算法,复杂度是 O(nlogn)
一般来说就是把小的集合往大的上合并的意思,这个“大”指的就是重儿子
流程:对于 u
- 处理好 u 所有轻儿子及其子树的答案,然后删除刚刚算出来的轻儿子子树信息对于 u 的贡献
- 处理好重儿子及其子树的贡献,算好后不删除
- 暴力统计所有轻儿子及其子树及 u 本身的贡献,与上一步算出来的重儿子的贡献合并,就得到了 u 的答案
例题:CF600E
- n 个节点以 1 为根的有根树,每个节点有一个颜色 
- 如果一种颜色在 x 为根的子树中出现最多,称其占主导地位,可能多种颜色占主导地位 
- 求出每个节点为根的子树中,主导颜色的编号之和 
- n, c <= 1e5 
- 开 - cnt[N]数组表示“当前”子树 i 颜色的数量- 第一步时,“当前”指以 u 的轻儿子为根的子树
- 最后统计时,“当前”指以 u 为根的子树
 
- 递归重儿子后保留贡献,然后把轻儿子的贡献往里加,就实现了小的往大的合并 
c++
int h[N], e[N * 2], ne[N * 2], idx;
int c[N], sz[N], son[N], n, a, b;
LL res[N], sum;
int cnt[N], flag, maxc; //cnt存放某颜色在“当前”子树中的数量 flag用于标记重儿子,maxc用于更新最大值
 
void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
 
void dfs(int u, int fa)
{
    sz[u] = 1;
    for (int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;
        dfs(j, u);
        sz[u] += sz[j];
        if (sz[j] > sz[son[u]]) son[u] = j;
    }
}
//TODO 统计某结点及其所有轻儿子的贡献
void count(int u, int fa, int v)
{
    cnt[c[u]] += v; //val为正为负可以控制是增加贡献还是删除贡献
    if (cnt[c[u]] > maxc) //找最大值,基操吧
    {
        maxc = cnt[c[u]];
        sum = c[u];
    }
    else if (cnt[c[u]] == maxc) //这样做是因为如果两个颜色数量相同那都得算
        sum += c[u];
    for (int i = h[u]; i != -1; i = ne[i]) //排除被标记的重儿子,统计其它儿子子树信息
    {
        int j = e[i];
        if (j == fa || j == flag) continue; //不能写if(v==f||v==son[u]) continue;
        count(j, u, v);
    }
}
 
void dfs(int u, int fa, bool keep)
{
     //* 第一步:搞轻儿子及其子树算其答案删贡献
    for (int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if (j == fa || j == son[u]) continue;
        dfs(j, u, false);
    }
    //* 第二步:搞重儿子及其子树算其答案不删贡献
    if (son[u])
    {
        dfs(son[u], u, true);
        flag = son[u];
    }
    //* 第三步:暴力统计u及其所有轻儿子的贡献合并到刚算出的重儿子信息里
    count(u, fa, 1);
    flag = 0;
    res[u] = sum;
    //* 把需要删除贡献的删一删
    if (!keep)
    {
        count(u, fa, -1);
        sum = maxc = 0; //这是因为count函数中会改变这两个变量值
    }
}
 
int main()
{
    scanf("%d", &n);
    memset(h, -1, sizeof h);
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    for (int i = 1; i < n; i++)
    {
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    dfs(1, 0);
    dfs(1, 0, 0);
    for (int i = 1; i <= n; i++) 
        printf("%lld\n", res[i]);
    return 0;
}每个节点有一个权值,返回每个节点为根的子树中,最小缺失的权值是多少。
权值范围为 [1,1e5]
(都是传 map ,set 之类的集合了,使用 std::move() 提速,或者开 vector<Set>[n],使用 v[a].swap(v[b]) 来合并)
c++
class Solution {
public:
    vector<int> smallestMissingValueSubtree(vector<int>& parents, vector<int>& nums) {
        int n = parents.size();
        vector<int> son[n];
        for (int i = 1; i < n; i++)
            son[parents[i]].push_back(i);
        vector<int> mex(n); // min excluded
        function<unordered_set<int>(int)> dfs = [&](int u)
        {
            unordered_set<int> fa;
            mex[u] = 1;
            for (int x: son[u])
            {
                auto s = std::move(dfs(x));
                if (s.size() > fa.size())
                    swap(s, fa);
                for (int x: s)
                    fa.insert(x);
                if (mex[x] > mex[u]) mex[u] = mex[x];
            }
            fa.insert(nums[u]);
            while (fa.count(mex[u])) mex[u] ++;
            return fa;
        };
        dfs(0);
        return mex;
    }
};如果一棵树中存在的每种颜色的结点个数都相同,则我们称它是一棵颜色平衡树。
这道蓝桥杯的题目就可以秒掉了,但是需要开启 O2 优化,否则 TLE
c++
vector<int> g[N];
int col[N], n, c, f, res;
unordered_map<int, int> cnt[N];
void dfs(int u)
{
    
    for (int v: g[u])
    {
        dfs(v);
        if (cnt[v].size() > cnt[u].size()) cnt[u].swap(cnt[v]);
        for (auto &[x, y]: cnt[v]) 
            cnt[u][x] += y;
    }
    cnt[u][col[u]] ++;
    int all = -1, st = 1;
    for (auto &[_, y]: cnt[u]) 
        if (all == -1) all = y;
        else if (y != all) 
        {
            st = 0;
            break;
        }
    if (st) res ++;
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%d%d", &c, &f);
        g[f].push_back(i);
        col[i] = c;
    }
    dfs(1);
    printf("%d\n", res);
    return 0;
}更好的写法
c++
vector<int> g[N];
int c[N], sz[N], son[N], cnt[N], ccnt[N], n, res, f;
// cnt[N] 记录每种颜色出现的次数
// ccnt[i] 表示有多少种颜色出现 i 次
// 判断逻辑是 cnt[c[u]] * ccnt[cnt[c[u]]] == sz[u]
void dfs(int u)
{
    sz[u] = 1;
    for (int v: g[u])
    {
        dfs(v);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
void count(int u, int val)
{
    -- ccnt[cnt[c[u]]];
    cnt[c[u]] += val;
    ++ ccnt[cnt[c[u]]];
    for (int v: g[u]) count(v, val);
}
void dfs(int u, bool keep)
{
    for (int v: g[u])
        if (v != son[u]) dfs(v, false);
    if (son[u]) dfs(son[u], true);
    -- ccnt[cnt[c[u]]]; // 原来的
    ++ cnt[c[u]]; // 新的
    ++ ccnt[cnt[c[u]]];
    for (int v: g[u])
        if (v != son[u]) count(v, 1);
    if (cnt[c[u]] * ccnt[cnt[c[u]]] == sz[u]) res ++;
    if (!keep) count(u, -1);
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%d%d", &c[i], &f);
        g[f].push_back(i);
    }
    dfs(1);
    dfs(1, 0);
    printf("%d\n", res);
    return 0;
}