模板
- "1 l r x":把 a[l] 到 a[r] 都增加 x。
- "2 l r x":把 b[l] 到 b[r] 都增加 x。
- "3 l r":输出 (a[l]*b[l] + a[l+1]*b[l+1] + ... + a[r]*b[r]) % 998244353。
线段树维护: 区间 a[i] 之和 sa。
区间 b[i] 之和 sb。
区间 a[i]*b[i] 之和 sab。
lazy tag 是区间内的 a[i] 的统一增加量,区间内的 b[i] 的统一增加量。
应用 lazy tag(pushdown): 如果把 a[i] 增加 x:
- sa 增加 sz * x,其中 sz 是该节点的区间大小。
- (a[i] + x) * b[i] = a[i] * b[i] + x * b[i],也就是 sab 增加量等于 x * sb。
对于 b[i] 增加 x 同理。
可以先增加 a[i],再增加 b[i],无需考虑同时增加。
struct Node {
int l, r;
int sa, sb, sab;
int ta, tb;
} tr[N << 2];
void pushup(int u) {
tr[u].sa = (tr[u << 1].sa + tr[u << 1 | 1].sa) % mod;
tr[u].sb = (tr[u << 1].sb + tr[u << 1 | 1].sb) % mod;
tr[u].sab = (tr[u << 1].sab + tr[u << 1 | 1].sab) % mod;
}
void build(int u, int l, int r) {
tr[u] = {l, r};
if (l == r) {
tr[u].sa = a[l] % mod, tr[u].sb = b[l] % mod;
tr[u].sab = 1LL * a[l] * b[l] % mod;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void helper(int u, int ta, int tb) {
int sz = (tr[u].r - tr[u].l + 1);
tr[u].sa = (tr[u].sa + 1LL * sz * ta % mod) % mod;
tr[u].sab = (tr[u].sab + 1LL * tr[u].sb * ta % mod) % mod;
tr[u].sb = (tr[u].sb + 1LL * sz * tb % mod) % mod;
tr[u].sab = (tr[u].sab + 1LL * tr[u].sa * tb % mod) % mod;
tr[u].ta = (tr[u].ta + ta) % mod;
tr[u].tb = (tr[u].tb + tb) % mod;
}
void pushdown(int u) {
helper(u << 1, tr[u].ta, tr[u].tb);
helper(u << 1 | 1, tr[u].ta, tr[u].tb);
tr[u].ta = tr[u].tb = 0;
}
void update(int u, int L, int R, int ta, int tb) {
if (L <= tr[u].l && tr[u].r <= R) {
helper(u, ta, tb);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (L <= mid) {
update(u << 1, L, R, ta, tb);
}
if (mid < R) {
update(u << 1 | 1, L, R, ta, tb);
}
pushup(u);
}
int query(int u, int L, int R) {
if (L <= tr[u].l && tr[u].r <= R) {
return tr[u].sab;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if (L <= mid) {
res = query(u << 1, L, R);
}
if (mid < R) {
res = (res + query(u << 1 | 1, L, R)) % mod;
}
return res;
}
本题要求:
- 将区间 内每一个数都加上
- 输出 的区间和
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m;
int w[N];
struct Node {
int l, r;
LL sum, add; // 区间和 懒标记
} tr[N << 2];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add += root.add, left.sum += (LL)(left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (LL)(right.r - right.l + 1) * root.add;
root.add = 0;
}
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = {l, r, w[r], 0};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void update(int u, int l, int r, int d)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
tr[u].add += d;
}
else
{
// 分裂
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) update(u << 1, l, r, d);
if (r > mid) update(u << 1 | 1, l, r, d);
pushup(u);
}
}
LL query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if (l <= mid) sum += query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
build(1, 1, n);
while (m -- )
{
char op[2];
scanf("%s", op);
if (op[0] == 'C')
{
int x, y, k;
scanf("%d%d%d", &x, &y, &k);
update(1, x, y, k);
}
else
{
int x, y;
scanf("%d%d", &x, &y);
printf("%lld\n", query(1, x, y));
}
}
return 0;
}
动态开点线段树
一道动态开点线段树模板题,也是实现区间更新和查询,参考宫水三叶题解的模板
单次操作最多创建 的点,空间复杂度为
按需创建区间,估算点数为 ,其中 分别代表值域大小和查询次数
const int N = 1e9 + 10, M = 500010;
class RangeModule {
public:
int cnt = 1;
struct Node {
int ls, rs, sum, add;
} tr[M];
void pushdown(int u, int len)
{
if (tr[u].ls == 0) tr[u].ls = ++ cnt;
if (tr[u].rs == 0) tr[u].rs = ++ cnt;
if (tr[u].add == 0) return;
if (tr[u].add == -1) tr[tr[u].ls].sum = tr[tr[u].rs].sum = 0;
else
{
tr[tr[u].ls].sum = len - len / 2;
tr[tr[u].rs].sum = len / 2;
}
tr[tr[u].ls].add = tr[tr[u].rs].add = tr[u].add;
tr[u].add = 0;
}
void pushup(int u)
{
tr[u].sum = tr[tr[u].ls].sum + tr[tr[u].rs].sum;
}
void update(int u, int lc, int rc, int l, int r, int v)
{
int len = rc - lc + 1;
if (l <= lc && rc <= r)
{
tr[u].sum = v == 1 ? len : 0;
tr[u].add = v;
return;
}
pushdown(u, len);
int mid = lc + rc >> 1;
if (l <= mid) update(tr[u].ls, lc, mid, l, r, v);
if (r > mid) update(tr[u].rs, mid + 1, rc, l, r, v);
pushup(u);
}
int query(int u, int lc, int rc, int l, int r)
{
if (l <= lc && rc <= r) return tr[u].sum;
pushdown(u, rc - lc + 1);
int mid = lc + rc >> 1, res = 0;
if (l <= mid) res += query(tr[u].ls, lc, mid, l, r);
if (r > mid) res += query(tr[u].rs, mid + 1, rc, l, r);
return res;
}
RangeModule() {
memset(tr, 0, sizeof tr);
}
void addRange(int left, int right) {
update(1, 1, N - 1, left, right - 1, 1);
}
bool queryRange(int left, int right) {
return query(1, 1, N - 1, left, right - 1) == right - left;
}
void removeRange(int left, int right) {
update(1, 1, N - 1, left, right - 1, -1);
}
};
题意可转换为线段树维护区间最大值,模板采用指针法,更新时直接赋值,而非累加差值的方式
class Solution {
public:
struct Node {
Node *ls, *rs;
int val, add;
};
void update(Node *node, int lc, int rc, int l, int r, int v)
{
if (l <= lc && rc <= r)
{
node->val = v;
node->add = v;
return;
}
pushdown(node);
int mid = lc + rc >> 1;
if (l <= mid) update(node->ls, lc, mid, l, r, v);
if (r > mid) update(node->rs, mid + 1, rc, l, r, v);
pushup(node);
}
int query(Node *node, int lc, int rc, int l, int r)
{
if (l <= lc && rc <= r) return node->val;
pushdown(node);
int mid = lc + rc >> 1, res = 0;
if (l <= mid) res = query(node->ls, lc, mid, l, r);
if (r > mid) res = max(res, query(node->rs, mid + 1, rc, l, r));
return res;
}
void pushdown(Node *node)
{
if (node->ls == nullptr) node->ls = new Node();
if (node->rs == nullptr) node->rs = new Node();
if (node->add == 0) return;
node->ls->add = node->add, node->rs->add = node->add;
node->ls->val = node->add, node->rs->val = node->add;
node->add = 0;
}
void pushup(Node *node)
{
node->val = max(node->ls->val, node->rs->val);
}
vector<int> fallingSquares(vector<vector<int>>& positions) {
vector<int> res;
int N = 1e9;
Node *root = new Node();
for (auto &t: positions)
{
int x = t[0], h = t[1], cur = query(root, 0, N, x, x + h - 1);
update(root, 0, N, x, x + h - 1, cur + h);
res.push_back(root->val);
}
return res;
}
};
练习题:CF915E
动态开点线段树,指针法会超时,用数组法的模板(与 715 的代码基本相同)可以过,但时间比较极限,还是用数组法比较好,就是需要估计点数比较麻烦,一般开成 M*50 即可
单点更新 无需懒标记
本题在 LIS 基础上新增相邻元素差不超过 的限制,自然想到 ,线段树的区间不再是下标,而是 的数
这个模板不用记录懒标记
const int N = 100010;
struct Node {
int l, r;
int val;
} tr[N * 4];
void pushup(int u)
{
tr[u].val = max(tr[u << 1].val, tr[u << 1 | 1].val);
}
void build(int u, int l, int r) // 这里不用 pushup,一开始都设为 0 即可
{
tr[u] = {l, r};
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
void update(int u, int x, int v) // 修改当前值作为结尾的最长上升子序列的长度
{
if (tr[u].l == x && tr[u].r == x) tr[u].val = max(tr[u].val, v);
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) update(u << 1, x, v);
else update(u << 1 | 1, x, v);
pushup(u);
}
}
int query(int u, int l, int r)
{
if (l > r) return 0; // 做一个特判,代码更简洁
if (l <= tr[u].l && tr[u].r <= r) return tr[u].val;
else
{
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (l <= mid) res = max(res, query(u << 1, l, r));
if (r > mid) res = max(res, query(u << 1 | 1, l, r));
return res;
}
}
class Solution {
public:
int lengthOfLIS(vector<int>& nums, int k) {
int mx = *max_element(nums.begin(), nums.end());
build(1, 1, mx);
for (int x: nums)
{
int len = query(1, max(x - k, 1), x - 1);
update(1, x, len + 1); // 以值 x 结尾的最长上升子序列的长度更新为 len + 1
}
return query(1, 1, mx); // 返回的就是以 1 ~ mx 作为结尾的最长上升子序列长度的最大值
}
};
转换为线段树操作
给你两个下标从 0 开始的数组
nums1
和nums2
,和一个二维数组queries
表示一些操作。总共有 3 种类型的操作:
- 操作类型 1 为
queries[i] = [1, l, r]
。你需要将nums1
从下标l
到下标r
的所有0
反转成1
或将1
反转成0
。l
和r
下标都从 0 开始。- 操作类型 2 为
queries[i] = [2, p, 0]
。对于0 <= i < n
中的所有下标,令nums2[i] = nums2[i] + nums1[i] * p
。- 操作类型 3 为
queries[i] = [3, 0, 0]
。求nums2
中所有元素的和。请你返回一个数组,包含所有第三种操作类型的答案。
分析:区间和就是区间内 1
的个数,每次翻转区间和会变成 ,根据这点修改与懒标记有关的逻辑就可以过了,实现上用异或表示翻转次数
typedef long long LL;
typedef pair<int, int> PII;
#define x first
#define y second
const int MOD = 1e9 + 7;
const int N = 100010;
class Solution {
public:
vector<int> nums;
struct Node {
int l, r;
LL sum, add;
} tr[N << 2];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add ^= 1, left.sum = left.r - left.l + 1 - left.sum;
right.add ^= 1, right.sum = right.r - right.l + 1 - right.sum;
root.add ^= 1;
}
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = {l, r, nums[r], 0};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void update(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum = tr[u].r - tr[u].l + 1 - tr[u].sum;
tr[u].add ^= 1;
}
else
{
// 分裂
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) update(u << 1, l, r);
if (r > mid) update(u << 1 | 1, l, r);
pushup(u);
}
}
LL query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if (l <= mid) sum += query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
vector<long long> handleQuery(vector<int>& nums1, vector<int>& nums2, vector<vector<int>>& que) {
int n = nums1.size();
nums1.insert(nums1.begin(), 0);
nums = nums1;
build(1, 1, n);
vector<LL> res;
LL s = accumulate(nums2.begin(), nums2.end(), 0LL);
for (auto& t: que)
{
if (t[0] == 1) update(1, t[1] + 1, t[2] + 1);
else if (t[0] == 2) s += query(1, 1, n) * t[1];
else res.push_back(s);
}
return res;
}
};
线段树二分
给你一个下标从 0 开始的正整数数组
heights
,其中heights[i]
表示第i
栋建筑的高度。如果一个人在建筑
i
,且存在i < j
的建筑j
满足heights[i] < heights[j]
,那么这个人可以移动到建筑j
。给你另外一个数组
queries
,其中queries[i] = [ai, bi]
。第i
个查询中,Alice 在建筑ai
,Bob 在建筑bi
。请你能返回一个数组
ans
,其中ans[i]
是第i
个查询中,Alice 和 Bob 可以相遇的 最左边的建筑 。如果对于查询i
,Alice 和 Bob 不能相遇,令ans[i]
为-1
。
class Solution {
public:
vector<int> mx;
void build(int u, int l, int r, vector<int> &w)
{
if (l == r)
{
mx[u] = w[l - 1];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid, w);
build(u << 1 | 1, mid + 1, r, w);
mx[u] = max(mx[u << 1], mx[u << 1 | 1]);
}
// 返回 [L,n] 中 > v 的最小下标(前三个参数表示线段树的节点信息)
int query(int u, int l, int r, int L, int v)
{
if (mx[u] <= v) return 0;
if (l == r) return l;
int mid = l + r >> 1;
if (L <= mid)
{
int pos = query(u << 1, l, mid, L, v);
if (pos > 0) return pos;
}
return query(u << 1 | 1, mid + 1, r, L, v);
}
vector<int> leftmostBuildingQueries(vector<int>& heights, vector<vector<int>>& queries) {
int n = heights.size();
mx.resize(n * 4);
build(1, 1, n, heights);
vector<int> res;
for (auto &q: queries)
{
int i = q[0], j = q[1];
if (i > j) swap(i, j);
if (i == j || heights[i] < heights[j])
res.push_back(j);
else
{
int pos = query(1, 1, n, j + 1, heights[i]);
res.push_back(pos - 1); // 不存在时刚好得到 -1
}
}
return res;
}
};
class BookMyShow {
public:
int n, m;
vector<int> mn;
vector<LL> sum;
// 将 idx 上的元素值增加 val
void add(int u, int l, int r, int idx, int v)
{
if (l == r)
{
mn[u] += v;
sum[u] += v;
return;
}
int mid = l + r >> 1;
if (idx <= mid) add(u << 1, l, mid, idx, v);
else add(u << 1 | 1, mid + 1, r, idx, v);
mn[u] = min(mn[u << 1], mn[u << 1 | 1]);
sum[u] = sum[u << 1] + sum[u << 1 | 1];
}
// 返回区间 [L,R] 内的元素和
LL query(int u, int l, int r, int L, int R)
{
if (L <= l && r <= R) return sum[u];
LL res = 0;
int mid = l + r >> 1;
if (L <= mid) res += query(u << 1, l, mid, L, R);
if (R > mid) res += query(u << 1 | 1, mid + 1, r, L, R);
return res;
}
// 返回区间 [1,R] 中 <= val 的最靠左的位置,不存在时返回 0
int get(int u, int l, int r, int R, int v)
{
if (mn[u] > v) return 0;
if (l == r) return l;
int mid = l + r >> 1;
if (mn[u << 1] <= v) return get(u << 1, l, mid, R, v);
if (mid < R) return get(u << 1 | 1, mid + 1, r, R, v);
return 0;
}
BookMyShow(int n, int m) {
this->n = n, this->m = m;
mn.resize(n * 4);
sum.resize(n * 4);
}
vector<int> gather(int k, int maxRow) {
int i = get(1, 1, n, maxRow + 1, m - k);
if (i == 0) return {};
int seats = query(1, 1, n, i, i);
add(1, 1, n, i, k);
return {i - 1, seats};
}
bool scatter(int k, int maxRow) {
if ((LL)m * (maxRow + 1) - query(1, 1, n, 1, maxRow + 1) < k)
return false;
// 从第一个没有坐满的排开始占座
for (int i = get(1, 1, n, maxRow + 1, m - 1); ; i++)
{
int left_seats = m - query(1, 1, n, i, i);
if (k <= left_seats)
{
add(1, 1, n, i, k);
return true;
}
k -= left_seats;
add(1, 1, n, i, left_seats);
}
}
};
线段树维护单点修改,区间最大子段和
模板题为:https://www.luogu.com.cn/problem/P4513
新开4个域——max,maxl,maxr,sum,其中sum为该区间的和,max为该区间上的最大子段和,maxl为必须包含左端点的最大子段和,maxr为必须包含右端点的最大子段和。
更新的逻辑见题解,挺好理解
int a[N], n, m, op, x, y;
struct Node {
int maxv, maxl, maxr, sum;
} tr[N << 2];
void pushup(Node &root, const Node &left, const Node &right) // 加 const 才能传右值
{
if (left.maxr < 0 && right.maxl < 0) // 连接点小于 0 取左右儿子的区间中
root.maxv = max(left.maxv, right.maxv);
else // 否则可取 左右的边界段
{
root.maxv = 0;
if (left.maxr > 0) root.maxv += left.maxr;
if (right.maxl > 0) root.maxv += right.maxl;
}
root.maxv = max(root.maxv, max(left.maxv, right.maxv));
root.maxl = max(left.maxl, left.sum + right.maxl);
root.maxr = max(right.maxr, right.sum + left.maxr);
root.sum = left.sum + right.sum;
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u].sum = tr[u].maxv = tr[u].maxl = tr[u].maxr = a[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void update(int p, int v, int u, int l, int r)
{
if (l == r)
{
tr[u].sum = tr[u].maxl = tr[u].maxr = tr[u].maxv = v;
return;
}
int mid = l + r >> 1;
if (p <= mid) update(p, v, u << 1, l, mid);
else update(p, v, u << 1 | 1, mid + 1, r);
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
Node query(int L, int R, int u, int l, int r)
{
if (L <= l && r <= R) return tr[u];
int mid = l + r >> 1;
if (L <= mid && mid < R)
{
Node res; // 每次查询都要做合并操作
pushup(res, query(L, R, u << 1, l, mid), query(L, R, u << 1 | 1, mid + 1, r));
return res;
}
else if (L <= mid) return query(L, R, u << 1, l, mid);
else return query(L, R, u << 1 | 1, mid + 1, r);
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while (m -- )
{
cin >> op >> x >> y;
if (op == 1)
{
if (x > y) swap(x, y);
cout << query(x, y, 1, 1, n).maxv << endl;
}
else update(x, y, 1, 1, n);
}
return 0;
}
线段树维护区间 max,只查询,无需建树
例题:
https://codeforces.com/problemset/problem/1691/D
输入 t(≤1e5) 表示 t 组数据,每组数据输入 n(≤2e5) 和长为 n 的数组 a (-1e9≤a[i]≤1e9)。所有数据的 n 之和不超过 2e5。
请你判断,对数组 a 的每个非空子数组 b,是否都有 max(b) >= sum(b)?
如果是,输出 YES,否则输出 NO。
注:子数组是连续的。
考虑 a[i], 当 a[i] 作为子数组最大值时,子数组的左右端点可以怎么取?维护左边、右边的更大元素,这一点用单调栈维护
假设不满足的区间为 (j,k),其中 j<i<k,转化为 sum(j,i-1)+a[i]+sum(i+1,k)>a[i]
,即两个 sum 至少其中一个大于 0,根据这点可以判断是否不满足了
实现上,维护前缀和以及后缀和,用线段树维护前后缀和的区间最大值,然后判断即可,具体看代码
cin >> n;
for (int i = 0; i < n; i++) cin >> a[i];
_n = n;
while (__builtin_popcount(_n) != 1) _n ++; // Round off n to next power of 2
vector<LL> preTree(2 * _n, -1e18), sufTree(2 * _n, -1e18);
for (int i = 0; i < n; i++)
{
// 初始化 从 _n 开始
preTree[_n + i] = pre[i];
sufTree[_n + i] = suf[i];
}
for (int i = _n - 1; i; i--)
{
// pushup 操作
preTree[i] = max(preTree[2 * i], preTree[2 * i + 1]);
sufTree[i] = max(sufTree[2 * i], sufTree[2 * i + 1]);
}
// 常规 query 操作
function<LL(vector<LL>&, int, int, int, int, int)> query =
[&](vector<LL> &tr, int u, int ns, int ne, int qs, int qe)
{
if (qe < ns || qs > ne) return (LL)-1e18;
if (qs <= ns && ne <= qe) return tr[u];
int mid = ns + ne >> 1;
LL l = query(tr, u << 1, ns, mid, qs, qe);
LL r = query(tr, u << 1 | 1, mid + 1, ne, qs, qe);
return max(l, r);
};
for (int i = 0; i < n; i++)
{
// 注意一下区间是 [0, _n - 1] 因为 1 号点的管辖区间是 [0, _n - 1] 对应初始化
LL rMax = query(preTree, 1, 0, _n - 1, i + 1, ng[i] - 1) - pre[i];
LL lMax = query(sufTree, 1, 0, _n - 1, pg[i] + 1, i - 1) - suf[i];
if (max(lMax, rMax) > 0)
{
st = 0;
break;
}
}
vector 线段树,单点修改,维护前缀最大值
多看看这样的代码
vector<int> tr;
void update(int u, int l, int r, int x, int v) {
if (l == r) {
tr[u] = v;
return;
}
int mid = l + r >> 1;
if (x <= mid) {
update(u << 1, l, mid, x, v);
} else {
update(u << 1 | 1, mid + 1, r, x, v);
}
tr[u] = max(tr[u << 1], tr[u << 1 | 1]);
}
// [0, x] 中最大值
int query(int u, int l, int r, int x) {
if (r <= x) {
return tr[u];
}
int mid = l + r >> 1;
if (x <= mid) {
return query(u << 1, l, mid, x);
}
return max(tr[u << 1], query(u << 1 | 1, mid + 1, r, x));
}
单点更新,维护最大子段和的模板(LC版)
来自 这篇题解
// 模板:线段树维护最大子段和
struct Node {
long long sm, lv, rv, ans;
};
Node tree[n * 4 + 5];
auto merge = [&](Node nl, Node nr) {
return Node {
nl.sm + nr.sm,
max(nl.lv, nl.sm + nr.lv),
max(nl.rv + nr.sm, nr.rv),
max({nl.ans, nr.ans, nl.rv + nr.lv})
};
};
auto initNode = [&](int val) {
return Node { val, val, val, val };
};
auto build = [&](this auto &&build, int id, int l, int r) -> void {
if (l == r) tree[id] = initNode(nums[l]);
else {
int nxt = id << 1, mid = (l + r) >> 1;
build(nxt, l, mid); build(nxt | 1, mid + 1, r);
tree[id] = merge(tree[nxt], tree[nxt | 1]);
}
};
auto modify = [&](this auto &&modify, int id, int l, int r, int pos, int val) -> void {
if (l == r) tree[id] = initNode(val);
else {
int nxt = id << 1, mid = (l + r) >> 1;
if (pos <= mid) modify(nxt, l, mid, pos, val);
else modify(nxt | 1, mid + 1, r, pos, val);
tree[id] = merge(tree[nxt], tree[nxt | 1]);
}
};
// 线段树模板结束
build(1, 0, n - 1);