dsu on tree 是一种处理树上不带修改的离线子树问题的算法,复杂度是 O(nlogn)
一般来说就是把小的集合往大的上合并的意思,这个“大”指的就是重儿子
流程:对于 u
- 处理好 u 所有轻儿子及其子树的答案,然后删除刚刚算出来的轻儿子子树信息对于 u 的贡献
- 处理好重儿子及其子树的贡献,算好后不删除
- 暴力统计所有轻儿子及其子树及 u 本身的贡献,与上一步算出来的重儿子的贡献合并,就得到了 u 的答案
例题:CF600E
n 个节点以 1 为根的有根树,每个节点有一个颜色
如果一种颜色在 x 为根的子树中出现最多,称其占主导地位,可能多种颜色占主导地位
求出每个节点为根的子树中,主导颜色的编号之和
n, c <= 1e5
开
cnt[N]
数组表示“当前”子树 i 颜色的数量- 第一步时,“当前”指以 u 的轻儿子为根的子树
- 最后统计时,“当前”指以 u 为根的子树
递归重儿子后保留贡献,然后把轻儿子的贡献往里加,就实现了小的往大的合并
c++
int h[N], e[N * 2], ne[N * 2], idx;
int c[N], sz[N], son[N], n, a, b;
LL res[N], sum;
int cnt[N], flag, maxc; //cnt存放某颜色在“当前”子树中的数量 flag用于标记重儿子,maxc用于更新最大值
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void dfs(int u, int fa)
{
sz[u] = 1;
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
dfs(j, u);
sz[u] += sz[j];
if (sz[j] > sz[son[u]]) son[u] = j;
}
}
//TODO 统计某结点及其所有轻儿子的贡献
void count(int u, int fa, int v)
{
cnt[c[u]] += v; //val为正为负可以控制是增加贡献还是删除贡献
if (cnt[c[u]] > maxc) //找最大值,基操吧
{
maxc = cnt[c[u]];
sum = c[u];
}
else if (cnt[c[u]] == maxc) //这样做是因为如果两个颜色数量相同那都得算
sum += c[u];
for (int i = h[u]; i != -1; i = ne[i]) //排除被标记的重儿子,统计其它儿子子树信息
{
int j = e[i];
if (j == fa || j == flag) continue; //不能写if(v==f||v==son[u]) continue;
count(j, u, v);
}
}
void dfs(int u, int fa, bool keep)
{
//* 第一步:搞轻儿子及其子树算其答案删贡献
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (j == fa || j == son[u]) continue;
dfs(j, u, false);
}
//* 第二步:搞重儿子及其子树算其答案不删贡献
if (son[u])
{
dfs(son[u], u, true);
flag = son[u];
}
//* 第三步:暴力统计u及其所有轻儿子的贡献合并到刚算出的重儿子信息里
count(u, fa, 1);
flag = 0;
res[u] = sum;
//* 把需要删除贡献的删一删
if (!keep)
{
count(u, fa, -1);
sum = maxc = 0; //这是因为count函数中会改变这两个变量值
}
}
int main()
{
scanf("%d", &n);
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
for (int i = 1; i < n; i++)
{
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
dfs(1, 0);
dfs(1, 0, 0);
for (int i = 1; i <= n; i++)
printf("%lld\n", res[i]);
return 0;
}
每个节点有一个权值,返回每个节点为根的子树中,最小缺失的权值是多少。
权值范围为 [1,1e5]
(都是传 map ,set 之类的集合了,使用 std::move()
提速,或者开 vector<Set>[n],使用 v[a].swap(v[b]) 来合并)
c++
class Solution {
public:
vector<int> smallestMissingValueSubtree(vector<int>& parents, vector<int>& nums) {
int n = parents.size();
vector<int> son[n];
for (int i = 1; i < n; i++)
son[parents[i]].push_back(i);
vector<int> mex(n); // min excluded
function<unordered_set<int>(int)> dfs = [&](int u)
{
unordered_set<int> fa;
mex[u] = 1;
for (int x: son[u])
{
auto s = std::move(dfs(x));
if (s.size() > fa.size())
swap(s, fa);
for (int x: s)
fa.insert(x);
if (mex[x] > mex[u]) mex[u] = mex[x];
}
fa.insert(nums[u]);
while (fa.count(mex[u])) mex[u] ++;
return fa;
};
dfs(0);
return mex;
}
};
如果一棵树中存在的每种颜色的结点个数都相同,则我们称它是一棵颜色平衡树。
这道蓝桥杯的题目就可以秒掉了,但是需要开启 O2 优化,否则 TLE
c++
vector<int> g[N];
int col[N], n, c, f, res;
unordered_map<int, int> cnt[N];
void dfs(int u)
{
for (int v: g[u])
{
dfs(v);
if (cnt[v].size() > cnt[u].size()) cnt[u].swap(cnt[v]);
for (auto &[x, y]: cnt[v])
cnt[u][x] += y;
}
cnt[u][col[u]] ++;
int all = -1, st = 1;
for (auto &[_, y]: cnt[u])
if (all == -1) all = y;
else if (y != all)
{
st = 0;
break;
}
if (st) res ++;
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%d%d", &c, &f);
g[f].push_back(i);
col[i] = c;
}
dfs(1);
printf("%d\n", res);
return 0;
}
更好的写法
c++
vector<int> g[N];
int c[N], sz[N], son[N], cnt[N], ccnt[N], n, res, f;
// cnt[N] 记录每种颜色出现的次数
// ccnt[i] 表示有多少种颜色出现 i 次
// 判断逻辑是 cnt[c[u]] * ccnt[cnt[c[u]]] == sz[u]
void dfs(int u)
{
sz[u] = 1;
for (int v: g[u])
{
dfs(v);
sz[u] += sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
void count(int u, int val)
{
-- ccnt[cnt[c[u]]];
cnt[c[u]] += val;
++ ccnt[cnt[c[u]]];
for (int v: g[u]) count(v, val);
}
void dfs(int u, bool keep)
{
for (int v: g[u])
if (v != son[u]) dfs(v, false);
if (son[u]) dfs(son[u], true);
-- ccnt[cnt[c[u]]]; // 原来的
++ cnt[c[u]]; // 新的
++ ccnt[cnt[c[u]]];
for (int v: g[u])
if (v != son[u]) count(v, 1);
if (cnt[c[u]] * ccnt[cnt[c[u]]] == sz[u]) res ++;
if (!keep) count(u, -1);
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%d%d", &c[i], &f);
g[f].push_back(i);
}
dfs(1);
dfs(1, 0);
printf("%d\n", res);
return 0;
}