线段树

本文最后更新于:6 天前

花了半天多的时间看了线段树的理论,并造了一个大大的轮子
–2022.2.26 18:04

今天听完了y总版本的线段树。y总的代码风格真的清晰,编码思路与本文的思路略有不同(主要区别在于方法的参数数量较少,原因是使用单独的结构表示了线段树节点,从而在结构中存储了更多信息)。

如果以后变勤快了的话,再把y总版本加到这里吧。

​ --2022.3.31 17:24

1. 背景

假设数组data={1,6,3,4,8,2,9}data = \{1, 6, 3, 4, 8, 2, 9\},当我们需要对数组进行频繁的元素更新、查询某一区间的最大/最小值时,暴力的方法允许在O(n)O(n)时间内完成单点修改与区间查询。

我们可以使用线段树存储数组元素,此时更新与区间查询的时间复杂度均为O(logn)O(logn).

2. 概述

线段树(segment tree),是用来存放给定区间内对应信息的一种非常灵活的数据结构。使用数组来存储树型结构,支持区间查询与单点修改。允许在对数时间内从数组中找到最小值、最大值、总和、最大公约数、最小公倍数等。

3. 线段树的特点与原理

线段树是一个平衡二叉树,但不一定是完全二叉树。当数组datadata的长度是2的整数次幂时,线段树成为满二叉树。同时,这也就决定了线段树可以采用顺序的数组结构存储。

线段树的节点:线段树的每个非叶子节点存储了一段区间的信息(如最大值/最小值),每个叶子节点对应数组datadata的一个元素(叶子节点也可以看作是区间信息,只不过这个区间的左右端点值相等)。

具体来说,线段树的根节点代表整个数组所在区间的信息,即data[0:N1]data[0: N - 1](含N)所对应的信息。

将区间均分成两半,mid=N12mid = \frac{N - 1}{2},根节点的左孩子节点存储data[0:mid]data[0:mid]区间的信息;右孩子节点存储data[mid+1:N1]data[mid+1:N-1]区间的信息。也正是因为这样的均分策略,最终构造出的树才是平衡二叉树。

线段树的额外空间:我们已经知道了线段树的叶子节点存储的是datadata中的元素值。如果说平衡二叉树有nn个叶子节点的话,那么需要4n4n的连续空间存储整棵树。

一棵层数为nn的满二叉树,其第nn层的节点数为2n12^{n-1},前n1n-1层的节点总数2n112^{n - 1} - 1,也就是说,整棵树的节点数大约是最后一层节点数的2倍。
若存在第n+1n+1层,则第n+1n+1层的节点数也将会是第nn层的2倍。r

由于线段树是平衡二叉树,不一定是满二叉树,因此若有nn个叶子节点的话,需要多开辟一层的存储空间,即4倍叶子节点数的额外空间大小。

4. 建树

线段树的创建是自下而上进行的。一个节点所存储的信息完全由其左右孩子决定。当左右孩子所存储信息确定时,该节点所存储的信息即可确定。

假如左孩子存储的是区间[a,b][a, b]的最小值xx,右孩子存储的是区间[b,c][b, c]的最小值yy,那么父节点存储的区间[a,c][a, c]的最小值就可以由min(x,y)min(x, y)唯一确定。由此自底向上,就可以完成建树了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
private void buildSegmentTree(int treeNode, int dataLeft, int dataRight, T[] data){
//此时是叶子节点
if(dataLeft == dataRight){
tree[treeNode] = data[dataLeft];
return;
}

//得到左右子树的顺序索引
int left = getLeft(treeNode);
int right = getRight(treeNode);
//得到数据的中间索引,如果begin和end很大,相加求和除二的方法可能会溢出。
int mid = dataLeft + (dataRight - dataLeft) / 2;
//构建左子树
buildSegmentTree(left, dataLeft, mid, data);
//构建右子树
buildSegmentTree(right, mid + 1, dataRight, data);
//修改当前节点值
tree[treeNode] = merger.merge(tree[left], tree[right]);
}

在构造方法中,可以调用如下语句递归的创建线段树。

1
buildSegmentTree(0, 0, data.length - 1, data);

5. 更新

当我们想更新原始数据数组datadata的某一处的值时,首先要在线段树中确定该值对应的叶子节点索引,之后自底向上的更新路径上的其他节点信息。

1
private void _update(int treeNode, int dataIndex, T newValue, int dataLeft, int dataRight)
  • treeNode是当前树节点索引——调用更新时应初始化为0
  • dataIndex是想要更新的原始数据索引
  • newValue是更新后的值
  • dataLeft与dataRight是当前树节点treeNode所表示的区间信息的左右端点。

因此,调用线段树更新操作的对外接口update方法体如下。从根节点treeNode=0处,将原始数据索引为dataIndex的数据更新为newValue。由于根节点所表示的区间为data[0:dataLen1]data[0:dataLen-1],因此后两个参数也就确定了。

1
2
3
public void update(int dataIndex, T newValue){
_update(0, dataIndex, newValue, 0, dataLen - 1);
}

对于_update方法,方法体如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
private void _update(int treeNode, int dataIndex, T newValue, int dataLeft, int dataRight){
if(dataLeft == dataRight){
tree[treeNode] = newValue;
return;
}

int leftChildNode = getLeft(treeNode);
int rightChildNode = getRight(treeNode);
int mid = dataLeft + (dataRight - dataLeft) / 2;
if(dataIndex <= mid){
_update(leftChildNode, dataIndex, newValue, dataLeft, mid);
} else {
_update(rightChildNode, dataIndex, newValue, mid + 1, dataRight);
}
//修改完子树更新当前节点
tree[treeNode] = merger.merge(tree[leftChildNode], tree[rightChildNode]);
}

从根节点开始,找到左右孩子所表示的区间信息,若要更新的节点落在左孩子节点所表示的区间内,则以左孩子为根节点递归的调用_update方法。反之则更新右子树。最后修改当前节点的相关信息。

递归直至dataLeft与dataRight相等终止,即当前节点treeNode是叶子节点,此时寻找到了待更新节点。

6. 查询

查询query方法对外的调用接口以及_query的方法的声明如下。

1
2
3
4
5
6
7
//query方法
public T query(int dataLeft, int dataRight){
return _query(0, 0, dataLen - 1, dataLeft, dataRight);
}

//_query方法声明
private T _query(int treeNode, int treeLeft, int treeRight, int dataLeft, int dataRight);
  • treeNode: 线段树的某一节点索引,查询开始时应从根节点0开始。
  • treeLeft与treeRight:线段树当前节点所表示的区间左右端点,查询开始时应初始化为0与dataLen - 1.
  • dataLeft与dataRight:要查询的区间信息的左右端点。

由于线段树特殊的区间均分方法,使得查询时出现了四种情况。

1)如果查询区间与当前节点所表示的区间完全吻合,则可以直接返回当前节点的信息。

1
2
3
if(treeLeft == dataLeft && treeRight == dataRight){
return tree[treeNode];
}

2)如果查询区间完全位于当前节点左孩子所表示区间中,则以左子树为根节点,递归地调用查询方法。需要注意的是,此时_query方法的参数2与参数3分别为treeLeft、mid,即左孩子所表示的区间左右端点。

1
2
3
4
5
6
7
8
int leftChildNode = getLeft(treeNode);
int rightChildNode = getRight(treeNode);
int mid = treeLeft + (treeRight - treeLeft) / 2;

//2.查询区间完全在左子树所表示区间,则去左子树查询
if(dataRight <= mid){
return _query(leftChildNode, treeLeft, mid, dataLeft, dataRight);
}

3)如果查询区间完全位于当前节点右孩子所表示区间中,则以右子树为根节点,递归调用查询方法。

1
2
3
if(dataLeft >= mid + 1){
return _query(rightChildNode, mid + 1, treeRight, dataLeft, dataRight);
}

4)最后,如果查询区间同时覆盖了左右子树所表示区间的部分,则应该在左子树中查询区间[dataLeft,mid][dataLeft, mid],在右子树中查询区间[mid+1,dataRight][mid + 1, dataRight],最后将查询到的结果合并

1
2
3
T leftValue = _query(leftChildNode, treeLeft, mid, dataLeft, mid);
T rightValue = _query(rightChildNode, mid + 1, treeRight, mid + 1, dataRight);
return merger.merge(leftValue, rightValue);

7. 完整代码

最后,线段树的代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package com.chen.datastructure;

import com.chen.datastructure.SegmentTree.Merger;

class SegmentTree<T> {

private T[] tree;
private int dataLen;
private Merger<T> merger;

/**
* 对于自定义数据类型T的线段树,需要给出合并时的动作
* @param <T> 线段树的数据类型
*/
public interface Merger<T>{
T merge(T a, T b);
}

public SegmentTree(T[] data, Merger<T> merger){
this.merger = merger;
this.tree = (T[]) new Object[data.length * 4];
this.dataLen = data.length;
buildSegmentTree(0, 0, data.length - 1, data);
}

/**
* 从节点index开始,对原始数据数据构建线段树。
* @param treeNode 当前节点索引
* @param dataLeft 原始数据数组的左端点
* @param dataRight 原始数据数组的右端点
* @param data 原始数据数组
*/
private void buildSegmentTree(int treeNode, int dataLeft, int dataRight, T[] data){
if(dataLeft == dataRight){
tree[treeNode] = data[dataLeft];
return;
}
//得到左右子树的顺序索引
int left = getLeft(treeNode);
int right = getRight(treeNode);
//得到数据的中间索引,如果begin和end很大,相加求和除二的方法可能会溢出。
int mid = dataLeft + (dataRight - dataLeft) / 2;
//构建左子树
buildSegmentTree(left, dataLeft, mid, data);
//构建右子树
buildSegmentTree(right, mid + 1, dataRight, data);
//修改当前节点值
tree[treeNode] = merger.merge(tree[left], tree[right]);
}

private int getLeft(int index){
return 2 * index + 1;
}
private int getRight(int index){
return 2 * index + 2;
}

/**
* 对原始索引为dataIndex的数据值更改。
* @param dataIndex 要更改的原始数据索引
* @param newValue 更改后的值
*/
public void update(int dataIndex, T newValue){
_update(0, dataIndex, newValue, 0, dataLen - 1);
}

/**
* 内部修改逻辑
* @param treeNode 当前节点的索引
* @param dataIndex 要修改的数据的原始索引
* @param newValue 要修改的值
* @param dataLeft 数据区间左端点
* @param dataRight 数据区间右端点
*/
private void _update(int treeNode, int dataIndex, T newValue, int dataLeft, int dataRight){
if(dataLeft == dataRight){
tree[treeNode] = newValue;
return;
}

int leftChildNode = getLeft(treeNode);
int rightChildNode = getRight(treeNode);
int mid = dataLeft + (dataRight - dataLeft) / 2;
if(dataIndex <= mid){
_update(leftChildNode, dataIndex, newValue, dataLeft, mid);
} else {
_update(rightChildNode, dataIndex, newValue, mid + 1, dataRight);
}
//修改完子树更新当前节点
tree[treeNode] = merger.merge(tree[leftChildNode], tree[rightChildNode]);
}

/**
* 线段树的区间查询
* @param dataLeft 查询区间左端点
* @param dataRight 查询区间右端点
* @return 查询结果
*/
public T query(int dataLeft, int dataRight){
return _query(0, 0, dataLen - 1, dataLeft, dataRight);
}

/**
* 内部的区间查询逻辑
* @param treeNode 当前treeNode的索引
* @param treeLeft 当前treeNode所表示区间的左端点
* @param treeRight 当前treeNode所表示区间的右端点
* @param dataLeft 查询区间左端点
* @param dataRight 查询区间的右端点
* @return 查询结果
*/
private T _query(int treeNode, int treeLeft, int treeRight,
int dataLeft, int dataRight){
//1.查询区间完全与当前node节点所表示区间重合
if(treeLeft == dataLeft && treeRight == dataRight){
return tree[treeNode];
}

int leftChildNode = getLeft(treeNode);
int rightChildNode = getRight(treeNode);
int mid = treeLeft + (treeRight - treeLeft) / 2;

//2.查询区间完全在左子树所表示区间,则去左子树查询
if(dataRight <= mid){
return _query(leftChildNode, treeLeft, mid, dataLeft, dataRight);
}
//3.查询区间完全在右子树所表示区间,则去右子树查询
if(dataLeft >= mid + 1){
return _query(rightChildNode, mid + 1, treeRight, dataLeft, dataRight);
}
//4.查询区间一般在左子树,一半在右子树
//此时,应该去左子树查询dataLeft至左子树所表示区间的右端点——mid,去右子树查询右子树所表示区间的左端点——mid+1至dataRight
//并merge
T leftValue = _query(leftChildNode, treeLeft, mid, dataLeft, mid);
T rightValue = _query(rightChildNode, mid + 1, treeRight, mid + 1, dataRight);
return merger.merge(leftValue, rightValue);
}


@Override
public String toString() {
for(T num: tree){
System.out.print(num + " ");
}
System.out.println();
return "";
}
}

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!