📜  重构段树(1)

📅  最后修改于: 2023-12-03 15:28:32.683000             🧑  作者: Mango

重构段树

段树(Segment Tree)是一种常见的数据结构,常用于区间查询和修改。但是,在实际应用中,在不同的场景下,需要对段树进行不同的优化,以满足不同的需求。因此,重构段树是程序员必备的技能之一。

什么是段树

段树是一种基于二叉树的数据结构,用于处理区间查询和修改。一般情况下,段树的叶子结点存储的是原始数据,而非叶子结点则是对应区间的统计信息。

重构段树的原因

在实际应用中,不同的场景下,需要对段树进行不同的优化,以满足不同的需求。常见的优化方式有以下几种:

延迟标记

如果需要对某个区间进行修改,我们可以遍历段树,逐一修改叶子结点,然后逐层向上更新非叶子结点。然而,这种做法的时间复杂度是O(NlogN),其中N为数据规模,对于大规模数据区间修改而言效率较低。

为了优化这种情况,我们引入了“延迟标记”这一概念。即对每个非叶子结点维护一个“标记”,表示该结点的子结点需要进行的修改操作,而该结点本身暂不进行操作。等到需要进行查询或修改时,再将该结点上的标记下传到其子结点。这样就可以省去逐个修改叶子结点的时间,时间复杂度可以优化到O(logN)。

线段树

线段树是一种特殊的段树,用于解决一维问题。它将一维区间分割成若干个区间,每个区间都有对应的结点。这种数据结构常用于处理数组相关的问题,如求某个区间的最值、和等。

树状数组

树状数组也是一种处理数组区间问题的数据结构,它不需要建树,而是利用了数学上的技巧,在数组上实现了基于树的数据结构。树状数组的空间复杂度更低,时间复杂度与线段树相当,但具有实现简单的优点。

重构段树的实现
延迟标记
void pushdown(int l, int r, int p) {
    if (tag[p]) {
        int mid = (l+r)>>1;
        sum[p<<1] += tag[p]*(mid-l+1);
        sum[p<<1|1] += tag[p]*(r-mid);
        tag[p<<1] += tag[p];
        tag[p<<1|1] += tag[p];
        tag[p] = 0;
    }
}
void update(int ql, int qr, int l, int r, int p, int val) {
    if (ql<=l && r<=qr) {
        sum[p] += (r-l+1)*val;
        tag[p] += val;
        return;
    }
    pushdown(l, r, p);
    int mid = (l+r)>>1;
    if (ql<=mid) update(ql, qr, l, mid, p<<1, val);
    if (qr>mid) update(ql, qr, mid+1, r, p<<1|1, val);
    sum[p] = sum[p<<1] + sum[p<<1|1];
}
int query(int ql, int qr, int l, int r, int p) {
    if (ql<=l && r<=qr) return sum[p];
    pushdown(l, r, p);
    int mid = (l+r)>>1, ans = 0;
    if (ql<=mid) ans += query(ql, qr, l, mid, p<<1);
    if (qr>mid) ans += query(ql, qr, mid+1, r, p<<1|1);
    return ans;
}
线段树
void build(int l, int r, int p) {
    if (l==r) {
        sum[p] = arr[l];
        return;
    }
    int mid = (l+r)>>1;
    build(l, mid, p<<1);
    build(mid+1, r, p<<1|1);
    sum[p] = max(sum[p<<1], sum[p<<1|1]);
}
void update(int pos, int val, int l, int r, int p) {
    if (l==r) {
        sum[p] = val;
        return;
    }
    int mid = (l+r)>>1;
    if (pos<=mid) update(pos, val, l, mid, p<<1);
    else update(pos, val, mid+1, r, p<<1|1);
    sum[p] = max(sum[p<<1], sum[p<<1|1]);
}
int query(int ql, int qr, int l, int r, int p) {
    if (ql<=l && r<=qr) return sum[p];
    int mid = (l+r)>>1, ans = 0;
    if (ql<=mid) ans = max(ans, query(ql, qr, l, mid, p<<1));
    if (qr>mid) ans = max(ans, query(ql, qr, mid+1, r, p<<1|1));
    return ans;
}
树状数组
int lowbit(int x) {
    return x&(-x);
}
void update(int pos, int val) {
    while (pos<=n) {
        c[pos] += val;
        pos += lowbit(pos);
    }
}
int query(int pos) {
    int ans = 0;
    while (pos>0) {
        ans += c[pos];
        pos -= lowbit(pos);
    }
    return ans;
}