Bỏ qua

Hash cây

Khi cần kiểm tra hai cây có đồng cấu (isomorphic) hay không, ta thường chuyển cây thành giá trị băm (hash) để giảm độ phức tạp.

Tree Hash rất linh hoạt, có thể thiết kế nhiều kiểu băm khác nhau; tuy nhiên nếu thiết kế tùy tiện, dễ bị "hack". Dưới đây là một phương pháp dễ cài đặt và khó bị hack.

Phương pháp

Phương pháp này cần một hàm băm cho đa tập. Giá trị hash của cây con gốc tại một đỉnh là giá trị hash của đa tập các giá trị hash của các cây con gốc tại các con của nó, tức là:

\[ h_x = f(\{ h_i \mid i \in son(x) \}) \]

Trong đó \(h_x\) là hash của cây con gốc \(x\), \(f\) là hàm băm đa tập.

Ví dụ hàm hash dùng trong code:

\[ f(S) = \left( c + \sum_{x \in S} g(x) \right) \bmod m \]

Với \(c\) là hằng số (thường lấy \(1\)), \(m\) là modulus (thường dùng \(2^{32}\) hoặc \(2^{64}\) để tràn tự nhiên, hoặc số nguyên tố lớn). \(g\) là ánh xạ số nguyên sang số nguyên, ví dụ dùng xor shift, hoặc các hàm khác (không nên dùng đa thức). Để tránh bị hack, có thể xor thêm một số ngẫu nhiên trước/sau khi ánh xạ.

Cách hash này rất dễ viết. Nếu cần đổi gốc, chỉ cần DP lần hai, trừ đi hash cây con là được.

Bài tập ví dụ

UOJ #763. Tree Hash

Bài mẫu. Chỉ cần DFS từ gốc \(1\) là xong.

Code tham khảo
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <cctype>
#include <iostream>
#include <random>
#include <set>
#include <vector>

using ull = unsigned long long;

const ull mask = std::mt19937_64(time(nullptr))();

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

constexpr int N = 1e6 + 10;

int n;
ull hash[N];
std::vector<int> edge[N];
std::set<ull> trees;

void getHash(int x, int p) {
  hash[x] = 1;
  for (int i : edge[x]) {
    if (i == p) {
      continue;
    }
    getHash(i, x);
    hash[x] += shift(hash[i]);
  }
  trees.insert(hash[x]);
}

using std::cin;
using std::cout;

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n;
  for (int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    edge[u].push_back(v);
    edge[v].push_back(u);
  }
  getHash(1, 0);
  cout << trees.size();
}

[BJOI2015] Đồng cấu cây

Ở đây đồng cấu là vô gốc, còn phương pháp trên là cho cây có gốc. Do đó, chỉ khi chọn cùng gốc thì hai cây vô gốc mới có hash giống nhau. Với dữ liệu nhỏ, có thể brute-force hash với mọi gốc, rồi sort so sánh.

Nếu dữ liệu lớn, có thể dùng DP đổi gốc, duyệt hai lần để tính hash với mọi gốc. Hoặc, dùng hàm hash đa tập: lưu hash của mọi gốc vào một đa tập, rồi hash đa tập này để so sánh (cách 1).

Cũng có thể tối ưu bằng cách tìm trọng tâm (centroid) của cây. Một cây có tối đa hai trọng tâm, chỉ cần hash với các trọng tâm làm gốc. Sau đó, so sánh từng hash (cách 2), hoặc nếu có một trọng tâm thì lấy hash đó làm hash toàn cây, nếu có hai thì lấy min/max.

Cách 1
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <iostream>
#include <map>
#include <random>
#include <vector>

using ull = unsigned long long;

constexpr int N = 60, M = 998244353;
const ull mask = std::mt19937_64(time(nullptr))();

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

std::vector<int> edge[N];
ull sub[N], root[N];
std::map<ull, int> trees;

void getSub(int x) {
  sub[x] = 1;
  for (int i : edge[x]) {
    getSub(i);
    sub[x] += shift(sub[i]);
  }
}

void getRoot(int x) {
  for (int i : edge[x]) {
    root[i] = sub[i] + shift(root[x] - shift(sub[i]));
    getRoot(i);
  }
}

using std::cin;
using std::cout;

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int m;
  cin >> m;
  for (int t = 1; t <= m; t++) {
    int n, rt = 0;
    cin >> n;
    for (int i = 1; i <= n; i++) {
      int fa;
      cin >> fa;
      if (fa) {
        edge[fa].push_back(i);
      } else {
        rt = i;
      }
    }
    getSub(rt);
    root[rt] = sub[rt];
    getRoot(rt);
    ull hash = 1;
    for (int i = 1; i <= n; i++) {
      hash += shift(root[i]);
    }
    if (!trees.count(hash)) {
      trees[hash] = t;
    }
    cout << trees[hash] << '\n';
    for (int i = 1; i <= n; i++) {
      edge[i].clear();
    }
  }
}
Cách 2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <iostream>
#include <map>
#include <random>
#include <vector>

using ull = unsigned long long;
using Hash2 = std::pair<ull, ull>;

constexpr int N = 60, M = 998244353;
const ull mask = std::mt19937_64(time(nullptr))();

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

int n;
int size[N], weight[N], centroid[2];
std::vector<int> edge[N];
std::map<Hash2, int> trees;

void getCentroid(int x, int fa) {
  size[x] = 1;
  weight[x] = 0;
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    getCentroid(i, x);
    size[x] += size[i];
    weight[x] = std::max(weight[x], size[i]);
  }
  weight[x] = std::max(weight[x], n - size[x]);
  if (weight[x] <= n / 2) {
    int index = centroid[0] != 0;
    centroid[index] = x;
  }
}

ull getHash(int x, int fa) {
  ull hash = 1;
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    hash += shift(getHash(i, x));
  }
  return hash;
}

using std::cin;
using std::cout;

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int m;
  cin >> m;
  for (int t = 1; t <= m; t++) {
    cin >> n;
    for (int i = 1; i <= n; i++) {
      int fa;
      cin >> fa;
      if (fa) {
        edge[fa].push_back(i);
        edge[i].push_back(fa);
      }
    }
    getCentroid(1, 0);
    Hash2 hash;
    hash.first = getHash(centroid[0], 0);
    if (centroid[1]) {
      hash.second = getHash(centroid[1], 0);
      if (hash.first > hash.second) {
        std::swap(hash.first, hash.second);
      }
    } else {
      hash.second = hash.first;
    }
    if (!trees.count(hash)) {
      trees[hash] = t;
    }
    cout << trees[hash] << '\n';
    for (int i = 1; i <= n; i++) {
      edge[i].clear();
    }
    centroid[0] = centroid[1] = 0;
  }
}

HDU 6647 Bracket Sequences on Tree

Yêu cầu đếm số chuỗi ngoặc khác nhau sinh ra từ các cách duyệt cây vô gốc.

Nhận xét: hai cây có gốc không đồng cấu sẽ không sinh ra cùng chuỗi ngoặc. Đầu tiên, với cây có gốc, gọi \(f(u)\) là số cách sinh chuỗi ngoặc khác nhau từ cây con gốc \(u\). Khi duyệt các con của \(u\) theo mọi thứ tự, có \(|son(u)|!\) cách, mỗi con \(v\)\(f(v)\) cách, nên \(f(u)=|son(u)|! \cdot \prod_{v \in son(u)} f(v)\). Tuy nhiên, nếu có các cây con đồng cấu, sẽ bị trùng, nên cần chia cho tích giai thừa số lần xuất hiện của mỗi loại cây con (giống đếm hoán vị đa tập).

DP như trên sẽ tính được số cách cho gốc. Dùng DP đổi gốc để tính cho mọi đỉnh.

Code tham khảo
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <iostream>
#include <map>
#include <random>
#include <vector>

using ull = unsigned long long;

constexpr int N = 1e5 + 10, M = 998244353;
const ull mask = std::mt19937_64(time(nullptr))();

struct Tree {
  ull hash, deg, ans;
  std::map<ull, ull> son;

  Tree() { clear(); }

  void add(Tree& o);
  void remove(Tree& o);
  void clear();
};

ull inv(ull x) {
  ull y = M - 2, z = 1;
  while (y) {
    if (y & 1) {
      z = z * x % M;
    }
    x = x * x % M;
    y >>= 1;
  }
  return z;
}

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

void Tree::add(Tree& o) {
  ull temp = shift(o.hash);
  hash += temp;
  ans = ans * ++deg % M * inv(++son[temp]) % M * o.ans % M;
}

void Tree::remove(Tree& o) {
  ull temp = shift(o.hash);
  hash -= temp;
  ans = ans * inv(deg--) % M * son[temp]-- % M * inv(o.ans) % M;
}

void Tree::clear() {
  hash = 1;
  deg = 0;
  ans = 1;
  son.clear();
}

std::vector<int> edge[N];
Tree sub[N], root[N];
std::map<ull, ull> trees;

void getSub(int x, int fa) {
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    getSub(i, x);
    sub[x].add(sub[i]);
  }
}

void getRoot(int x, int fa) {
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    root[x].remove(sub[i]);
    root[i] = sub[i];
    root[i].add(root[x]);
    root[x].add(sub[i]);
    getRoot(i, x);
  }
  trees[root[x].hash] = root[x].ans;
}

using std::cin;
using std::cout;

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int t, n;
  cin >> t;
  while (t--) {
    cin >> n;
    for (int i = 1; i < n; i++) {
      int u, v;
      cin >> u >> v;
      edge[u].push_back(v);
      edge[v].push_back(u);
    }
    getSub(1, 0);
    root[1] = sub[1];
    getRoot(1, 0);
    ull tot = 0;
    for (auto p : trees) {
      tot = (tot + p.second) % M;
    }
    cout << tot << '\n';
    for (int i = 1; i <= n; i++) {
      edge[i].clear();
      sub[i].clear();
      root[i].clear();
    }
    trees.clear();
  }
}

Tài liệu tham khảo

Phương pháp hash trong bài tham khảo và mở rộng từ blog 一种好写且卡不掉的树哈希.