树链剖分学习笔记
# 前言
一个随随便便200+行代码的数据结构
能用别的方法就绝不写树链剖分
# 树链剖分
通过dfs序将树“压平”,并使用其他数据结构(一般来说是线段树)来维护一段路径的信息
原理见代码(有注释)
# 例题引入 Qtree1 (洛谷P4114)
给定一棵 个节点的树,有两种操作:
CHANGE i t
把第 条边的边权变成QUERY a b
输出从 到 的路径上最大的边权。当 时,输出
AC代码:
#include <bits/stdc++.h>
#define ll long long
#define Android ios::sync_with_stdio(false), cin.tie(NULL)
using namespace std;
const int max_node = 1e5 + 100;
ll initial[max_node];
struct segment_tree{ // 普普通通的线段树
ll tree[max_node * 4];
#define mid ((l + r) >> 1)
void build(int l, int r, int rt, int * id_node){
if(l == r) {
tree[rt] = initial[id_node[l]];
return;
}
build(l, mid, rt<<1, id_node), build(mid + 1, r, rt<<1|1, id_node);
pushup(rt);
}
void pushup(int rt){
tree[rt] = max(tree[rt<<1], tree[rt<<1|1]);
}
ll query(int l, int r, int ql, int qr, int rt){
if(ql > qr) return 0;
if(ql <= l && r <= qr) return tree[rt];
ll result = 0;
if(ql <= mid) result = query(l, mid, ql, qr, rt<<1);
if(qr > mid) result = max(result, query(mid + 1, r, ql, qr, rt<<1|1));
return result;
}
void change(int l, int r, int pos, ll delta, int rt){
if(l == r){
tree[rt] = delta;
return;
}
if(pos <= mid) change(l, mid, pos, delta, rt<<1);
else change(mid + 1, r, pos, delta, rt<<1|1);
pushup(rt);
}
#undef mid
}st;
struct heavy_light_seg{
// 储存原图
int head[max_node], nxt[max_node * 2], dest[max_node * 2];
int n, edge_cnt, root;
// 原图信息
int size[max_node], fa[max_node], heavy_son[max_node], depth[max_node];
// 重链剖分
int dfn;
int top[max_node], id_node[max_node], node_id[max_node];
// id_node: dfs序号到节点序号的映射
// node_id: 节点序号到dfs序号的映射
void init(int node_cnt){
n = node_cnt;
edge_cnt = dfn = 0;
for(int i=0; i<=n; i++) head[i] = 0, heavy_son[i] = -1;
}
inline void add_edge(int u, int v){ // 原图的边
dest[++edge_cnt] = v;
nxt[edge_cnt] = head[u];
head[u] = edge_cnt;
}
void _dfs1(int u, int parent, int dep){ // 记录重子节点
size[u] = 1;
fa[u] = parent;
depth[u] = dep;
int & hs = heavy_son[u];
for(int x=head[u]; x; x=nxt[x]){
if(dest[x] == parent) continue;
_dfs1(dest[x], u, dep + 1);
size[u] += size[dest[x]];
if(hs == -1 || size[hs] < size[dest[x]]) hs = dest[x];
}
}
void _dfs2(int u, int t){ // 按dfs序编号
top[u] = t;
node_id[u] = ++dfn;
id_node[dfn] = u;
if(heavy_son[u] == -1) return; // 叶子节点
_dfs2(heavy_son[u], t); // 先遍历重子节点,使dfs序连续
for(int x=head[u]; x; x=nxt[x]){
if(dest[x] == fa[u] || dest[x] == heavy_son[u]) continue;
_dfs2(dest[x], dest[x]);
}
}
void build(int rt){
root = rt;
_dfs1(root, -1, 1);
_dfs2(root, root);
}
int lca(int u, int v){
while(top[u] != top[v]){ // 不在同一条链上
// 每次只能选u或v其中一个点往上跳!
if(depth[top[u]] > depth[top[v]]) swap(u, v);
v = fa[top[v]];
}
return depth[u] < depth[v]?u:v;
}
ll query(int u, int v){ // 仿照lca(u,v),上跳过程中顺便统计答案即可
ll ans = 0;
while(top[u] != top[v]){
if(depth[top[u]] > depth[top[v]]) swap(u, v);
ans = max(ans, st.query(0, n, node_id[top[v]], node_id[v], 1));
v = fa[top[v]];
}
if(depth[u] > depth[v]) swap(u, v);
if(u != v) ans = max(ans, st.query(0, n, node_id[u] + 1, node_id[v], 1));
return ans;
}
}hls;
int eu[max_node], ev[max_node], ew[max_node];
char buffer[70];
void solve(){
int n;
cin >> n;
hls.init(n);
for(int i=1; i<n; i++){
cin >> eu[i] >> ev[i] >> ew[i];
hls.add_edge(eu[i], ev[i]);
hls.add_edge(ev[i], eu[i]);
}
hls.build(1);
for(int i=1; i<n; i++){
if(hls.depth[eu[i]] > hls.depth[ev[i]]) swap(eu[i], ev[i]);
initial[ev[i]] = ew[i]; // 取子节点id作为边的id
}
st.build(0, n, 1, hls.id_node);
int x, t;
while(cin >> buffer){
if(buffer[0] == 0 || buffer[0] == 'D') break;
if(buffer[0] == 'C'){
cin >> x >> t;
st.change(0, n, hls.node_id[ev[x]], t, 1);
}else{
cin >> x >> t;
cout << hls.query(x, t) << '\n';
}
}
}
signed main(){
Android;
solve();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# 练习题
# 洛谷P3384
#include <bits/stdc++.h>
#define ll long long
#define Android ios::sync_with_stdio(false), cin.tie(NULL)
using namespace std;
const ll inf = 0x3f3f3f3f;
const int maxn = 1e5 + 100;
ll mod;
ll initial[maxn];
struct heavy_light_seg{
static const int max_node = 1e5 + 100;
// 储存原图
int head[max_node], nxt[max_node * 2], dest[max_node * 2];
int n, edge_cnt, root;
// 原图信息
int size[max_node], fa[max_node], heavy_son[max_node], depth[max_node];
// 重链剖分
int dfn;
int top[max_node], id_node[max_node], node_id[max_node];
// 维护重链信息
struct segment_tree{
ll tree[max_node * 4], lazy[max_node * 4];
#define mid ((l + r) >> 1)
void build(int l, int r, int rt, heavy_light_seg * item){
lazy[rt] = 0;
if(l == r) {
tree[rt] = initial[item->id_node[l]];
return;
}
build(l, mid, rt<<1, item), build(mid + 1, r, rt<<1|1, item);
pushup(rt);
}
void pushup(int rt){
tree[rt] = (tree[rt<<1] + tree[rt<<1|1]) % mod;
}
void pushdown(int l, int r, int rt){
if(lazy[rt] == 0) return;
lazy[rt<<1] = (lazy[rt<<1] + lazy[rt]) % mod;
lazy[rt<<1|1] = (lazy[rt<<1|1] + lazy[rt]) % mod;
tree[rt<<1] = (tree[rt<<1] + (mid - l + 1) * lazy[rt] % mod) % mod;
tree[rt<<1|1] = (tree[rt<<1|1] + (r - mid) * lazy[rt] % mod) % mod;
lazy[rt] = 0;
}
ll query(int l, int r, int ql, int qr, int rt){
if(ql <= l && r <= qr) return tree[rt] % mod;
pushdown(l, r, rt);
ll result = 0;
if(ql <= mid) result = query(l, mid, ql, qr, rt<<1);
if(qr > mid) result = (result + query(mid + 1, r, ql, qr, rt<<1|1)) % mod;
return result;
}
void add(int l, int r, int ql, int qr, ll delta, int rt){
if(ql <= l && r <= qr){
lazy[rt] = (lazy[rt] + delta) % mod;
tree[rt] = (tree[rt] + delta * (r - l + 1) % mod) % mod;
return;
}
pushdown(l, r, rt);
if(ql <= mid) add(l, mid, ql, qr, delta, rt<<1);
if(qr > mid) add(mid + 1, r, ql, qr, delta, rt<<1|1);
pushup(rt);
}
#undef mid
}st;
void init(int node_cnt);// 省略
inline void add_edge(int u, int v);// 省略
void _dfs1(int u, int parent, int dep);// 省略
void _dfs2(int u, int t);// 省略
void build(int rt){
root = rt;
_dfs1(root, -1, 1);
_dfs2(root, root);
st.build(0, n, 1, this);
}
int lca(int u, int v){
while(top[u] != top[v]){ // 不在同一条链上
if(depth[top[u]] > depth[top[v]]) swap(u, v);
v = fa[top[v]];
}
return depth[u] < depth[v]?u:v;
}
void op1(int u, int v, ll delta){
while(top[u] != top[v]){
if(depth[top[u]] > depth[top[v]]) swap(u, v);
st.add(0, n, node_id[top[v]], node_id[v], delta, 1);
v = fa[top[v]];
}
if(node_id[u] > node_id[v]) swap(u, v);
st.add(0, n, node_id[u], node_id[v], delta, 1);
}
ll op2(int u, int v){
ll ret = 0;
while(top[u] != top[v]){
if(depth[top[u]] > depth[top[v]]) swap(u, v);
ret = (ret + st.query(0, n, node_id[top[v]], node_id[v], 1)) % mod;
v = fa[top[v]];
}
if(node_id[u] > node_id[v]) swap(u, v);
ret = (ret + st.query(0, n, node_id[u], node_id[v], 1)) % mod;
return ret;
}
void op3(int u, ll delta){
st.add(0, n, node_id[u], node_id[u] + size[u] - 1, delta, 1);
}
ll op4(int u){
return st.query(0, n, node_id[u], node_id[u] + size[u] - 1, 1);
}
}hls;
ll n, m, r;
void solve(){
cin >> n >> m >> r >> mod;
hls.init(n);
for(int i=1; i<=n; i++) cin >> initial[i];
for(int i=1, u, v; i<n; i++){
cin >> u >> v;
hls.add_edge(u, v);
hls.add_edge(v, u);
}
hls.build(r);
for(ll i=0, op, x, y, z; i<m; i++){
cin >> op;
if(op == 1){
cin >> x >> y >> z;
hls.op1(x, y, z);
}
if(op == 2){
cin >> x >> y;
cout << hls.op2(x, y) << '\n';
}
if(op == 3){
cin >> x >> y;
hls.op3(x, y);
}
if(op == 4){
cin >> x;
cout << hls.op4(x) << '\n';
}
}
}
signed main(){
Android;
solve();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# SPOJ-QTREE2
这道题完全不需要线段树(前缀和),甚至不需要树剖(树上路径和+倍增lca)
#include <bits/stdc++.h>
#define ll long long
#define Android ios::sync_with_stdio(false), cin.tie(NULL)
using namespace std;
const int max_node = 1e4 + 100;
ll initial[max_node];
struct segment_tree{
ll tree[max_node * 4];
#define mid ((l + r) >> 1)
void build(int l, int r, int rt, int * id_node){
if(l == r) {
tree[rt] = initial[id_node[l]];
return;
}
build(l, mid, rt<<1, id_node), build(mid + 1, r, rt<<1|1, id_node);
pushup(rt);
}
void pushup(int rt){
tree[rt] = tree[rt<<1] + tree[rt<<1|1];
}
ll query(int l, int r, int ql, int qr, int rt){
if(ql > qr) return 0;
if(ql <= l && r <= qr) return tree[rt];
ll result = 0;
if(ql <= mid) result = query(l, mid, ql, qr, rt<<1);
if(qr > mid) result = result + query(mid + 1, r, ql, qr, rt<<1|1);
return result;
}
#undef mid
}st;
struct heavy_light_seg{
// 储存原图
int head[max_node], nxt[max_node * 2], dest[max_node * 2];
int n, edge_cnt, root;
// 原图信息
int size[max_node], fa[max_node], heavy_son[max_node], depth[max_node];
// 重链剖分
int dfn;
int top[max_node], id_node[max_node], node_id[max_node];
void init(int node_cnt);// 省略
inline void add_edge(int u, int v);// 省略
void _dfs1(int u, int parent, int dep);// 省略
void _dfs2(int u, int t);// 省略
void build(int rt);// 省略
int lca(int u, int v){
while(top[u] != top[v]){ // 不在同一条链上
if(depth[top[u]] > depth[top[v]]) swap(u, v);
v = fa[top[v]];
}
return depth[u] < depth[v]?u:v;
}
ll dist(int u, int v){
ll ans = 0;
while(top[u] != top[v]){
if(depth[top[u]] > depth[top[v]]) swap(u, v);
ans += st.query(0, n, node_id[top[v]], node_id[v], 1);
v = fa[top[v]];
}
if(depth[u] > depth[v]) swap(u, v);
if(u != v) ans += st.query(0, n, node_id[u] + 1, node_id[v], 1);
return ans;
}
ll kth(int u, int v, int k){
int lc = lca(u, v);
if(k >= depth[u] - depth[lc] + 1){
k -= depth[u] - depth[lc] + 1;
k = depth[v] - depth[lc] - k + 1;
u = v;
}
k--;
while(true){
if(depth[u] - depth[top[u]] < k){
k -= depth[u] - depth[fa[top[u]]];
u = fa[top[u]];
}else{
return id_node[node_id[u] - k];
}
}
return 0;
}
}hls;
int eu[max_node], ev[max_node], ew[max_node];
char buffer[70];
void solve(){
int n;
// 同QTREE1 省略...
st.build(0, n, 1, hls.id_node);
int a, b, c;
while(cin >> buffer){
if(buffer[0] == 0 || buffer[1] == 'O') break;
if(buffer[0] == 'D'){
cin >> a >> b;
cout << hls.dist(a, b) << '\n';
}else{
cin >> a >> b >> c;
cout << hls.kth(a, b, c) << '\n';
}
}
}
signed main(){
Android;
int t; cin >> t;
while(t--) solve();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
上次更新: 2021/02/24, 03:37:30