Bỏ qua

Phân tách cây (Tree Divide)

Điểm phân chia (Point Centroid Decomposition)

Điểm phân chia (hay còn gọi là phân chia theo trọng tâm, centroid decomposition) rất phù hợp để xử lý các bài toán về thông tin đường đi trên cây quy mô lớn.

Ví dụ 1 Luogu P3806【Mẫu】Điểm phân chia 1

Cho một cây có \(n\) đỉnh, \(m\) truy vấn, mỗi truy vấn cho một số \(k\), hỏi có tồn tại một đường đi độ dài \(k\) hay không.

\(n\le 10000,m\le 100,k\le 10000000\)

Ta chọn một đỉnh bất kỳ làm gốc \(\mathit{rt}\), mọi đường đi hoàn toàn nằm trong cây con của nó có thể chia thành hai loại: loại đi qua gốc hiện tại và loại không đi qua gốc. Với các đường đi qua gốc, lại chia thành hai loại: một đầu là gốc, hoặc cả hai đầu không phải gốc (loại sau có thể ghép từ hai đường đi loại trước). Do đó, với mỗi gốc \(rt\), ta tính trước đóng góp của các đường đi qua \(rt\), sau đó đệ quy xử lý các cây con cho các đường đi không qua \(rt\).

Trong bài này, với các đường đi qua \(rt\), ta liệt kê từng con \(ch\) của \(rt\), tính khoảng cách từ mọi đỉnh trong cây con \(ch\) đến \(rt\). Gọi \(\mathit{dist}_i\) là khoảng cách từ \(i\) đến \(rt\), \(\mathit{tf}_d\) là mảng đánh dấu xem đã có đỉnh nào ở các cây con trước có khoảng cách \(d\) đến \(rt\) chưa. Nếu với truy vấn \(k\)\(tf_{k-\mathit{dist}_i}=true\), tức là tồn tại một đường đi độ dài \(k\). Sau khi xử lý xong cây con \(ch\), cập nhật các giá trị mới vào \(\mathit{tf}\).

Khi xóa mảng \(\mathit{tf}\), không nên dùng memset, mà nên lưu lại các vị trí đã dùng vào một hàng đợi để xóa, đảm bảo đúng độ phức tạp.

Trong mỗi tầng của phân chia, tổng số lần xử lý mỗi đỉnh là \(1\), nếu tổng cộng đệ quy \(h\) tầng thì độ phức tạp \(O(hn)\).

Nếu mỗi lần chọn trọng tâm (centroid) làm gốc, số tầng tối đa là \(O(\log n)\), tổng độ phức tạp \(O(n\log n)\). Vì vậy, phương pháp này còn gọi là trọng tâm phân chia (centroid decomposition).

Lưu ý: mỗi lần chọn lại gốc phải tính lại kích thước cây con, nếu không sẽ sai độ phức tạp hoặc sai kết quả.

Code tham khảo
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/* size 处理子树 d[], 连通块大小 cnt
   dp 最大子树 f[], 树的重心 rot
   get 计算出点到重心的距离 t[], top
   calc 点分治 bu[] 长度桶
   hd to nx wg 链式前向星存图
   ak[] as[] 离线处理询问
   ok[] 点分治中已成为重心的点
 */
#include <iostream>
const int N = 1e4 + 4, M = 105, Q = 1e7 + 7;
int n, m, hd[N], to[N * 2], nx[N * 2], wg[N * 2];
int ak[M], d[N], f[N], t[N], top, cnt, rot;
bool as[M], ok[N], bu[Q];

int size(int u, int pa) {
  cnt++, d[u] = 1;
  for (int p = hd[u]; ~p; p = nx[p])
    if (to[p] != pa && !ok[to[p]]) d[u] += size(to[p], u);
  return d[u];
}

void dp(int u, int pa) {
  f[u] = cnt - d[u];
  for (int p = hd[u]; ~p; p = nx[p])
    if (to[p] != pa && !ok[to[p]]) {
      f[u] = std::max(f[u], d[to[p]]);
      dp(to[p], u);
    }
  if (f[u] < f[rot]) rot = u;
}

void get(int u, int pa, int dis) {
  t[top++] = dis;
  for (int p = hd[u]; ~p; p = nx[p])
    if (to[p] != pa && !ok[to[p]]) get(to[p], u, dis + wg[p]);
}

void calc(int u) {
  cnt = 0, size(u, u);
  rot = u, dp(u, u);
  bu[0] = true, t[0] = 0, top = 1;
  for (int p = hd[rot], i; ~p; p = nx[p])
    if (!ok[to[p]]) {
      i = top, get(to[p], rot, wg[p]);
      for (int q = 0; q < m; q++)
        for (int j = i; j < top && !as[q]; j++)
          if (ak[q] >= t[j]) as[q] = bu[ak[q] - t[j]];
      --i;
      while (++i < top)
        if (t[i] < Q) bu[t[i]] = true;
    }
  while (top--)
    if (t[top] < Q) bu[t[top]] = false;
  ok[rot] = true;
  for (int p = hd[rot]; ~p; p = nx[p])
    if (!ok[to[p]]) calc(to[p]);
}

int main() {
  std::cin >> n >> m;
  for (int i = 1; i <= n; i++) hd[i] = -1;
  for (int i = 0, u, v; i + 2 < n * 2;) {
    std::cin >> u >> v >> wg[i];
    wg[i + 1] = wg[i];
    to[i] = v, nx[i] = hd[u], hd[u] = i++;
    to[i] = u, nx[i] = hd[v], hd[v] = i++;
  }
  for (int i = 0; i < m; i++) std::cin >> ak[i];
  calc(1);
  for (int i = 0; i < m; i++) std::cout << (as[i] ? "AYE\n" : "NAY\n");
}
Ví dụ 2 Luogu P4178 Tree

Cho một cây có \(n\) đỉnh, \(k\), hỏi số cặp đỉnh có khoảng cách không vượt quá \(k\).

\(n\le 40000,k\le 20000,w_i\le 1000\)

Vì ở đây hỏi số cặp đỉnh có khoảng cách không vượt quá \(k\), ta dùng segment tree để hỗ trợ truy vấn và cập nhật.

Code tham khảo
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
constexpr long long MAXN = 2000010;
constexpr long long inf = 2e9;
long long n, a, b, c, q, rt, siz[MAXN], maxx[MAXN], dist[MAXN];
long long cur, h[MAXN], nxt[MAXN], p[MAXN], w[MAXN], ret;
bool vis[MAXN];

void add_edge(long long x, long long y, long long z) {
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
  w[cur] = z;
}

long long sum;

void calcsiz(long long x, long long fa) {
  siz[x] = 1;
  maxx[x] = 0;
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      calcsiz(p[j], x);
      maxx[x] = max(maxx[x], siz[p[j]]);
      siz[x] += siz[p[j]];
    }
  maxx[x] = max(maxx[x], sum - siz[x]);
  if (maxx[x] < maxx[rt]) rt = x;
}

long long dd[MAXN], cnt;

void calcdist(long long x, long long fa) {
  dd[++cnt] = dist[x];
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]])
      dist[p[j]] = dist[x] + w[j], calcdist(p[j], x);
}

queue<long long> tag;

struct segtree {
  long long cnt, rt, lc[MAXN], rc[MAXN], sum[MAXN];

  void clear() {
    while (!tag.empty()) update(rt, 1, 20000000, tag.front(), -1), tag.pop();
    cnt = 0;
  }

  void print(long long o, long long l, long long r) {
    if (!o || !sum[o]) return;
    if (l == r) {
      cout << l << ' ' << sum[o] << '\n';
      return;
    }
    long long mid = (l + r) >> 1;
    print(lc[o], l, mid);
    print(rc[o], mid + 1, r);
  }

  void update(long long& o, long long l, long long r, long long x,
              long long v) {
    if (!o) o = ++cnt;
    if (l == r) {
      sum[o] += v;
      if (!sum[o]) o = 0;
      return;
    }
    long long mid = (l + r) >> 1;
    if (x <= mid)
      update(lc[o], l, mid, x, v);
    else
      update(rc[o], mid + 1, r, x, v);
    sum[o] = sum[lc[o]] + sum[rc[o]];
    if (!sum[o]) o = 0;
  }

  long long query(long long o, long long l, long long r, long long ql,
                  long long qr) {
    if (!o) return 0;
    if (r < ql || l > qr) return 0;
    if (ql <= l && r <= qr) return sum[o];
    long long mid = (l + r) >> 1;
    return query(lc[o], l, mid, ql, qr) + query(rc[o], mid + 1, r, ql, qr);
  }
} st;

void dfz(long long x, long long fa) {
  // tf[0]=true;tag.push(0);
  st.update(st.rt, 1, 20000000, 1, 1);
  tag.push(1);
  vis[x] = true;
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      dist[p[j]] = w[j];
      calcdist(p[j], x);
      for (long long k = 1; k <= cnt; k++)
        if (q - dd[k] >= 0)
          ret += st.query(st.rt, 1, 20000000, max(0ll, 1 - dd[k]) + 1,
                          max(0ll, q - dd[k]) + 1);
      for (long long k = 1; k <= cnt; k++)
        st.update(st.rt, 1, 20000000, dd[k] + 1, 1), tag.push(dd[k] + 1);
      cnt = 0;
    }
  st.clear();
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      sum = siz[p[j]];
      rt = 0;
      maxx[rt] = inf;
      calcsiz(p[j], x);
      calcsiz(rt, -1);
      dfz(rt, x);
    }
}

signed main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n;
  for (long long i = 1; i < n; i++)
    cin >> a >> b >> c, add_edge(a, b, c), add_edge(b, a, c);
  cin >> q;
  rt = 0;
  maxx[rt] = inf;
  sum = n;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  dfz(rt, -1);
  cout << ret << '\n';
  return 0;
}
Ví dụ 3 Luogu P2664 Trò chơi trên cây

Một cây mỗi đỉnh được gán một màu, định nghĩa \(s(i,j)\) là số màu trên đường đi từ \(i\) đến \(j\), \(\mathit{sum_{i}}=\sum_{j=1}^n s(i,j)\).Đối với mọi \(1\leq i\leq n\),tìm \(sum_i\).(\(1 \le n, c_i \le 10^5\)

Bài này kiểm tra sâu về tư duy điểm phân chia, rất phù hợp luyện tập nâng cao.

Trước hết, cần chuyển đổi ý nghĩa của \(\mathit{sum_i}\). Nếu tính trực tiếp như đề, rất khó hợp nhất thông tin giữa các cây con. Ta chuyển sang xét đóng góp của từng màu \(j\) cho \(\mathit{sum_i}\), gọi \(\mathit{cnt_j}\) là số đường đi qua \(i\) có màu \(j\), khi đó \(\mathit{sum_i} = \sum \mathit{cnt_j}\). Để tính \(\mathit{cnt_j}\), chỉ cần mỗi khi gặp màu mới thì \(\mathit{cnt_{col_u}}+=\mathit{size_u}\), với \(\mathit{size_u}\) là kích thước cây con gốc \(u\).

Trong quá trình điểm phân chia, cần thống kê:

  1. Đường đi có một đầu là gốc, đóng góp cho gốc.
  2. Đường đi có LCA là gốc, đóng góp cho các đỉnh trong cây con.

Phần 1 dễ xử lý, vì mỗi tầng chỉ cần duyệt toàn bộ cây con, dùng định nghĩa \(\mathit{sum_i}\) để cộng dồn.

Với phần 2, giả sử gốc \(u\) có con \(d\), chọn \(v\) trong cây con \(d\). Khi đó, đáp án cho \(v\) gồm:

  1. Số màu xuất hiện trên đường \((u, v)\), gọi là \(\mathit{num}\), nhân với tổng kích thước các cây con khác \(d\) của \(u\), gọi là \(\mathit{siz1}\), đóng góp là \(\mathit{num}\times \mathit{siz1}\).
  2. Với màu \(j\) không xuất hiện trên \((u, v)\), cộng thêm tổng \(\mathit{cnt_j}\) của các cây con khác \(d\).

Chi tiết xem code tham khảo.

Code tham khảo
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <algorithm>
#include <iostream>
using namespace std;
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
constexpr int N = 200005;
int h[N], nxt[N * 2], to[N * 2], c[N], gr;

void tu(int x, int y) { to[++gr] = y, nxt[gr] = h[x], h[x] = gr; }

using ll = long long;
int n, nn, siz[N], mn, rt;
bool vis[N];

void get_root(int u, int f) {
  siz[u] = 1;
  int mx = 0;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    get_root(d, u);
    siz[u] += siz[d];
    mx = max(mx, siz[d]);
  }
  mx = max(mx, nn - siz[u]);
  if (mx < mn) mn = mx, rt = u;
}

ll ans[N], sum;
int cnt[N], v[N];
// sum实时统计的是cnt[i]的和
int nowrt;

void get_dis(int u, int f, int now) {  // now为当前树链上的颜色数量(不含u)
  siz[u] = 1;
  if (!v[c[u]]) {
    sum -= cnt[c[u]];  // 减去在之前子树中已经出现过的颜色信息
    now++;
  }
  v[c[u]]++;
  ans[u] += sum + now * siz[nowrt];  // 统计过u点的路径对u的贡献
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (d == f || vis[d]) continue;
    get_dis(d, u, now);
    siz[u] += siz[d];
  }
  v[c[u]]--;
  if (!v[c[u]]) {
    sum += cnt[c[u]];  // 回溯
  }
}

void get_cnt(int u, int f) {
  if (!v[c[u]]) {
    cnt[c[u]] += siz[u];
    sum += siz[u];  // 将刚遍历过的子树的信息整合到cnt[i]和sum上去
  }
  v[c[u]]++;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    get_cnt(d, u);
  }
  v[c[u]]--;
}

void clear(int u, int f, int now) {
  if (!v[c[u]]) now++;
  v[c[u]]++;
  ans[u] -= now;
  ans[nowrt] += now;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    clear(d, u, now);
  }
  v[c[u]]--;
  cnt[c[u]] = 0;
}

void clear2(int u, int f) {
  cnt[c[u]] = 0;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    clear2(d, u);
  }
}

int son[N];

void divid(int u) {
  vis[u] = true;
  int tot = 0;
  nowrt = u;
  ans[u]++;
  for (int i = h[u]; i; i = nxt[i]) {
    if (vis[to[i]]) continue;
    son[++tot] = to[i];
  }
  siz[u] = sum = cnt[c[u]] = 1;
  v[c[u]]++;
  rep(i, 1, tot) {  // 统计每个子树和它之前的所有子树中节点组合产生的贡献
    int d = son[i];
    get_dis(d, u, 0);
    get_cnt(d, u);
    siz[u] += siz[d];
    cnt[c[u]] += siz[d];
    sum += siz[d];
  }
  clear2(u, 0);  // 清空数组,记得不可以用memset
  siz[u] = sum = cnt[c[u]] = 1;
  for (int i = tot; i >= 1;
       --i) {  // 统计每个子树和它之后的所有子树中节点组合产生的贡献
    int d = son[i];
    get_dis(d, u, 0);
    get_cnt(d, u);
    siz[u] += siz[d];
    cnt[c[u]] += siz[d];
    sum += siz[d];
  }
  v[c[u]]--;
  clear(u, 0, 0);                      // 清空的同时统计答案
  for (int i = h[u]; i; i = nxt[i]) {  // 继续向下进行点分治
    int d = to[i];
    if (vis[d]) continue;
    nn = siz[d], mn = n + 1, rt = 0;
    get_root(d, u);
    divid(rt);
  }
}

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n;
  int u, v;
  rep(i, 1, n) cin >> c[i];
  rep(i, 2, n) cin >> u >> v, tu(u, v), tu(v, u);
  rt = 0, nn = n, mn = n + 1;
  get_root(1, 0);
  divid(rt);
  rep(i, 1, n) cout << ans[i] << '\n';
  return 0;
}

Phân chia theo cạnh (Edge Centroid Decomposition)

Tương tự điểm phân chia, nhưng chọn một cạnh để chia cây thành hai phần có kích thước gần nhau nhất, rồi đệ quy xử lý hai phần.

Tuy nhiên, cách này không hiệu quả với cây nhiều nhánh như cây sao:

菊花图

Nếu một đỉnh có nhiều con kích thước gần nhau, phân chia theo cạnh sẽ rất tệ.

Nếu cây là nhị phân, sẽ tránh được vấn đề này. Ta có thể chuyển cây đa nhánh thành cây nhị phân như xây segment tree:

建树

Các đỉnh mới gán thông tin phù hợp với bài toán. Ví dụ, khi tính độ dài đường đi, gán trọng số cạnh gốc là \(1\), cạnh mới là \(0\).

Tổng số đỉnh tăng tối đa \(O(n)\), nên độ phức tạp vẫn \(O(n\log n)\).

Hầu hết các bài điểm phân chia đều có thể giải bằng phân chia theo cạnh (thường hằng số lớn hơn, nhưng không bị "hack" nặng), nên không cần ví dụ riêng.

Cây phân chia (Centroid Tree)

Cây phân chia là cây được xây lại từ cây gốc bằng cách phân chia theo trọng tâm, sao cho chiều cao cây mới là \(O(\log n)\).

Thường dùng cho các bài toán có truy vấn cập nhật động, không phụ thuộc hình dạng cây gốc.

Phân tích thuật toán

Mỗi lần tìm trọng tâm, liên kết nó với trọng tâm tầng trước thành cha-con, tạo thành cây mới có tối đa \(\log n\) tầng.

Nhờ vậy, nhiều thuật toán brute-force trên cây gốc sẽ chạy đúng và nhanh trên cây phân chia.

Cài đặt

Một mẹo nhỏ: mỗi lần truyền tổng kích thước tầng trước trừ đi kích thước con nặng nhất, sẽ ra tổng kích thước tầng hiện tại. Như vậy chỉ cần một DFS để tìm trọng tâm.

Code tham khảo
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

using IT = vector<int>::iterator;

struct Edge {
  int to, nxt, val;

  Edge() {}

  Edge(int to, int nxt, int val) : to(to), nxt(nxt), val(val) {}
} e[300010];

int head[150010], cnt;

void addedge(int u, int v, int val) {
  e[++cnt] = Edge(v, head[u], val);
  head[u] = cnt;
}

int siz[150010], son[150010];
bool vis[150010];

int tot, lasttot;
int maxp, root;

void getG(int now, int fa) {
  siz[now] = 1;
  son[now] = 0;
  for (int i = head[now]; i; i = e[i].nxt) {
    int vs = e[i].to;
    if (vs == fa || vis[vs]) continue;
    getG(vs, now);
    siz[now] += siz[vs];
    son[now] = max(son[now], siz[vs]);
  }
  son[now] = max(son[now], tot - siz[now]);
  if (son[now] < maxp) {
    maxp = son[now];
    root = now;
  }
}

struct Node {
  int fa;
  vector<int> anc;
  vector<int> child;
} nd[150010];

int build(int now, int ntot) {
  tot = ntot;
  maxp = 0x7f7f7f7f;
  getG(now, 0);
  int g = root;
  vis[g] = true;
  for (int i = head[g]; i; i = e[i].nxt) {
    int vs = e[i].to;
    if (vis[vs]) continue;
    int tmp = build(vs, ntot - son[vs]);
    nd[tmp].fa = now;
    nd[now].child.push_back(tmp);
  }
  return g;
}

int virtroot;

int main() {
  int n;
  cin >> n;
  for (int i = 1; i < n; i++) {
    int u, v, val;
    cin >> u >> v >> val;
    addedge(u, v, val);
    addedge(v, u, val);
  }
  virtroot = build(1, n);
}