树状数组与线段树

本文最后更新于:2024年2月12日 晚上

数据结构 区间求和 区间最大值 区间修改 单点修改
前缀和
差分
树状数组
线段数

树状数组

单点修改,求区间和

  • 初始化时间复杂度度:$O(N)$
  • 单次修改时间复杂度:$O(logN)$
  • 单次修改时间复杂度:$O(logN)$
  • 空间复杂度:$O(N)$

单点修改,区间求最值:

  • 单点修改:$O((logn)^2)$
  • 区间求最值:$O((logn)^2)$

修改

题目:

1649. 通过指令创建有序数组

307. 区域和检索 - 数组可修改

2179. 统计数组中好三元组数目

6206. 最长递增子序列 II

// 单点更新,区间求和
class BIT {
    int[] arr;
    
    public BIT(int n) {
        this.arr = new int[n + 10];
    }
    
    public int lowbit(int x) {
        return -x & x;
    }
    
    public void add(int x, int val) {
        while (x < arr.length) {
            arr[x] += val;
            x += lowbit(x);
        }
    }
    
    public int query(int x) {
        int res = 0;
        while (x > 0) {
            res += arr[x];
            x -= lowbit(x);
        }
        return res;
    }
}
// 单点更新,区间求最大值
class BIT {
    int[] arr; // 记录数值
    int[] max; // 记录最大值
    
    public BIT(int n) {
        this.arr = new int[n + 10];
        this.max = new int[n + 10];
    }
    
    public int lowbit(int x) {
        return -x & x;
    }
    
    public void update(int x, int val) {
        arr[x] = val; // 单点更新
        while (x < arr.length) {
            max[x] = arr[x];
            int lx = lowbit(x);
            for (int i = 1; i < lx; i <<= 1) {
                max[x] = Math.max(max[x], max[x - i]);
            }
            x += lowbit(x);
        }
    }
    
    public int query(int l, int r) {
        int ans = 0;
        while (l <= r) {
            ans = Math.max(ans, arr[r--]);
            while (r - lowbit(r) >= l) {
                ans = Math.max(ans, max[r]);
                r -= lowbit(r);
            }
        }
        return ans;
    }
}

线段树:Segment Tree

题目:

3805. 环形数组

POJ 3468

线段树可以在$O(logN)$​的时间内实现单点修改、区间修改、区间查询(区间求和、求区间最大值、求区间最小值)等操作。

模板一:单点修改,区间查询

int N = (int) 1e5 + 10;
int[] tree = new int [N]; // 线段树
int[] arr = new int[N]; // 被管理的数组
int size;

// 递归建树:node:树中节点编号
public static void build(int node, int start, int end) {
    if (start == end) {
        tree[node] = arr[start];
    } else {
        int mid = (start + end) / 2;
        int left_node = 2 * node + 1;
        int right_node = 2 * node + 2;
        build(left_node, start, mid);
        build(right_node, mid + 1, end);
        tree[node] = tree[left_node] + tree[right_node];
    }
} 

// 更新单个值
public static void update(int node, int start, int end, int idx, int val) {
    if (start == end) {
        arr[idx] = val;
        tree[node] = val;
    } else {
        int mid = (start + end) / 2;
        int left_node = 2 * node + 1;
        int right_node = 2 * node + 2;
        if (idx >= start && idx <= mid) {
            update(left_node, start, mid, idx, val);
        } else {
            update(right_node, mid + 1, end, idx, val);
        }
        tree[node] = tree[left_node] + tree[right_node];
    }
}

// 区间求和
public static int query(int node, int start, int end, int L, int R) {
    if (R < start || L > end) {
        return 0;
    } else if (L <= start && end <= R) {
        return tree[node];
    } else if (start == end) {
        return tree[node];
    } else {
        int mid = (start + end) / 2;
        int left_node = 2 * node + 1;
        int right_node = 2 * node + 2;
        return query(left_node, start, mid, L, R) + query(right_node, mid + 1, end, L, R);
    }
}

模板二:lazy_tag,区间修改,区间查询

class Node{
    int l, r;
    long sum; // 区间和
    long tag; // 懒惰标记
    public Node() {}
    public Node(int _l, int _r) {
        l = _l;
        r = _r;
    }
}

static int N = (int) 1e5 + 7; // 实际中设为数组长度的4倍
static Node[] tree = new Node[N]; // 线段树
static long[] arr = new long[N]; // 被管理的数组 
static long SUM; // 用于查询,使用前先清零

// 建树
static void build(int l, int r, int idx) {
    int left = idx * 2 + 1;
    int right = idx * 2 + 2;
    tree[idx] = new Node(l, r);
    if (l == r) {
        tree[idx].sum = arr[l];
    } else {
        build(l, (l + r) >> 1, left); // 构建左子树
        build(((l + r) >> 1) + 1, r, right); // 右子树
        pushUp(idx); // 更新父节点的值
    }
}

// 给区间l到r之间的所有数都加上val,idx表示当前节点
static void update(int l, int r, int val, int idx) {
    // 当前区间被包含于要修改的区间
    if (l <= tree[idx].l && r >= tree[idx].r) {
        tree[idx].sum += (tree[idx].r - tree[idx].l + 1) * val;
        tree[idx].tag += val;
        return;
    }
    // 懒惰标记不为零,先将懒惰标记下传
    if (tree[idx].tag != 0) {
        pushDown(idx);
    }
    int mid = (tree[idx].l + tree[idx].r) >> 1;
    int left = idx * 2 + 1;
    int right = idx * 2 + 2;
    if (r <= mid) {
        update(l, r, val, left);
    } else if (l > mid) {
        update(l, r, val, right);
    } else {
        update(l, r, val, left);
        update(l, r, val, right);
    }
    // 操作完成后,跟新父节点的值
    pushUp(idx);
}

// 下传懒惰标记
static void pushDown(int idx) {
    int mid = (tree[idx].l + tree[idx].r) >> 1;
    int left = idx * 2 + 1;
    int right = idx * 2 + 2;
    tree[left].sum += (mid - tree[idx].l + 1) * tree[idx].tag;
    tree[right].sum += (tree[idx].r - mid) * tree[idx].tag;
    tree[left].tag += tree[idx].tag;
    tree[right].tag += tree[idx].tag;
    tree[idx].tag = 0; // 当前节点标记下传完毕,标记清零
}

// 查询区间l到r之间的和
static void query(int l, int r, int idx) {
    if (l <= tree[idx].l && r >= tree[idx].r) {
        SUM += tree[idx].sum;
    } else {
        if (tree[idx].tag != 0) {
            pushDown(idx);
        }
        int mid = (tree[idx].l + tree[idx].r) >> 1;
        int left = idx * 2 + 1;
        int right = idx * 2 + 2;
        if (r <= mid) {
            query(l, r, left);
        } else if (l > mid){
            query(l, r, right);
        } else {
            query(l, r, left);
            query(l, r, right);
        }
    }
}

// 更新父节点
static void pushUp(int idx) {
    int left = idx * 2 + 1;
    int right = idx * 2 + 2;
    tree[idx].sum = tree[left].sum + tree[right].sum;
}

动态开点线段树

class SG {
    class Node {
        Node ls, rs; // 当前区间的左右子节点
        int val; // 区间最大值
        int add; // 懒标记
    }
    
    Node root = new Node(); // 根结点
    
    void update(Node node, int lc, int rc, int l, int r, int v) {
        if (l <= lc && rc <= r) {
            node.add = v;
            node.val = v;
            return;
        }
        pushdown(node);
        int mid = lc + rc >> 1;
        if (l <= mid) update(node.ls, lc, mid, l, r, v);
        if (r > mid) update(node.rs, mid + 1, rc, l, r, v);
        pushup(node);
    }
    
    int query(Node node, int lc, int rc, int l, int r) {
        if (l <= lc && rc <= r) return node.val;
        pushdown(node);
        int mid = lc + rc >> 1, ans = 0;
        if (l <= mid) ans = query(node.ls, lc, mid, l, r);
        if (r < mid) ans = Math.max(ans, query(node.rs, mid + 1, rc, l, r));
        return ans;
    }
    
    void pushdown(Node node) {
        if (node.ls == null) node.ls = new Node();
        if (node.rs == null) node.rs = new Node();
        if (node.add == 0) return ;
        node.ls.add = node.add; node.rs.add = node.add;
        node.ls.val = node.add; node.rs.val = node.add;
        node.add = 0;
    }
    
    void pushup(Node node) {
        node.val = Math.max(node.ls.val, node.rs.val);
    }
}
class RecentCounter {
    Node root;
    public RecentCounter() {
        root = new Node(1, (int)(1e9));
    }
    
    public int ping(int t) {
        root.add(t, 1);
        return root.query(Math.max(1, t - 3000), t);
    }
}

class Node {
    int l, r;
    Node left, right;
    int sum;

    Node(int l, int r) {
        this.l = l;
        this.r = r;
    }

    void add(int idx, int x) {
        if (idx <= l && idx >= r) {
            this.sum += x;
        } else {
            int mid = l + r >> 1;
            if (left == null) left = new Node(l, mid);
            if (right == null) right = new Node(mid + 1, r);
            if (idx <= mid) left.add(idx, x);
            else right.add(idx, x);
            pushup();
        }
    }

    void pushup() {
        this.sum = left.sum + right.sum;
    }

    int query(int l, int r) {
        if (l <= this.l && r >= this.r)
            return sum;
        else {
            int mid = l + r >> 1;
            if (left == null) left = new Node(l, mid);
            if (right == null) right = new Node(mid + 1, r);
            int res = 0;
            if (l <= mid) res += left.query(l, r);
            if (r > mid) res += right.query(l, r);
            return res;
        }
    }

}

/**
 * Your RecentCounter object will be instantiated and called as such:
 * RecentCounter obj = new RecentCounter();
 * int param_1 = obj.ping(t);
 */

参考:

[1]https://wmathor.com/index.php/archives/1176/