Link Cut Tree
動的木
パスのsumを計算したり, パスに対する作用を遅延伝搬できる.
Spec
-
struct node
- link cut treeで扱うノードの構造体
- この中に載せたいデータを載せる
-
fix(node * n)
- ノードの情報の再計算をする
-
reverse(node * n)
- 平衡二分木の反転
- モノイドの演算順序が反転するのでその処理を書く
- モノイドが可換であれば問題ない
-
lazy(node * n, i64 l)
- 遅延伝搬するときの演算
expose(n); lazy(n, x)
をすると,[root, n]
のパスにx
を作用させることになる
-
push(node* n)
- 遅延伝搬
- lazyを変えている場合はここも変更
-
expose(node* n)
n
をLink Cut Treeの根として, その木が[root, n]
のパスをあらわすようになる
-
link(node* p, node* c)
p
を親,c
を子として繋げる
-
cut(node* c)
c
の親とつながっている辺を切る
-
evert(node* t)
-
t
を親にする
Code
パス加算, パスsumを処理している
include <cstdint>
#include <utility>
#include <string>
#include <iostream>
using i64 = long long;
namespace lctree {
struct R {
int a;
R(): a(0) {}
R(int a): a(a) {}
};
struct V {
int a;
V(): a(0) {}
V(int a): a(a) {}
};
inline V compress(const V& a, const V& b) { return V(a.a + b.a); }
inline V rake_merge(const V& a, const R& b) { return V(a.a + b.a); }
inline V reverse(const V& a) { return a; }
inline void rake_plus(R& a, const V& b) { a.a += b.a; }
inline void rake_minus(R& a, const V& b) { a.a -= b.a; }
struct node;
extern struct node n[505050];
extern int ni;
using node_index = std::uint_least32_t;
using size_type = std::size_t;
struct node {
node_index c[3];
V v; V f; R r;
bool rev;
node(): rev(false) { c[0] = c[1] = c[2] = 0; }
node& operator[](int d) { return n[c[d]]; }
};
inline node_index new_node(V v) { n[ni].v = v; n[ni].f = v; return ni++; }
inline void reverse(node_index i) {
n[i].v = reverse(n[i].v);
n[i].f = reverse(n[i].f);
n[i].rev ^= true;
}
inline void push(node_index i) {
if(n[i].rev) {
std::swap(n[i].c[0], n[i].c[1]);
if(n[i].c[0]) reverse(n[i].c[0]);
if(n[i].c[1]) reverse(n[i].c[1]);
n[i].rev = false;
}
}
inline void fix(node_index i) {
push(i);
n[i].f = compress(compress(n[i][0].f, n[i].v), rake_merge(n[i][1].f, n[i].r));
}
inline int child_dir(node_index i) {
if(n[i].c[2]) {
if(n[i][2].c[0] == i) { return 0; }
else if(n[i][2].c[1] == i) { return 1; }
}
return 3;
}
inline void rotate(node_index x, size_type dir) {
node_index p = n[x].c[2];
int x_dir = child_dir(x);
node_index y = n[x].c[dir ^ 1];
n[n[y][dir].c[2] = x].c[dir ^ 1] = n[y].c[dir];
n[n[x].c[2] = y].c[dir] = x;
n[y].c[2] = p;
if(x_dir < 2) n[p].c[x_dir] = y;
fix(n[x].c[dir ^ 1]);
fix(x);
}
void splay(node_index i) {
push(i);
int i_dir;
int j_dir;
while(child_dir(i) < 2) {
node_index j = n[i].c[2];
if(child_dir(j) < 2) {
node_index k = n[j].c[2];
push(k), push(j), push(i);
i_dir = child_dir(i);
j_dir = child_dir(j);
if(i_dir == j_dir) rotate(k, j_dir ^ 1), rotate(j, i_dir ^ 1);
else rotate(j, i_dir ^ 1), rotate(k, j_dir ^ 1);
}
else push(j), push(i), rotate(j, child_dir(i) ^ 1);
}
fix(i);
}
node_index expose(node_index i) {
node_index right = 0;
node_index ii = i;
while(i) {
splay(i);
rake_minus(n[i].r, n[right].f);
rake_plus(n[i].r, n[i][1].f);
n[i].c[1] = right;
fix(i);
right = i;
i = n[i].c[2];
}
splay(ii);
return ii;
}
void link(node_index i, node_index j) {
if(!i || !j) return;
expose(i);
expose(j);
n[n[j].c[2] = i].c[1] = j;
fix(i);
}
void cut(node_index i) {
if(!i) return;
expose(i);
node_index p = n[i].c[0];
n[i].c[0] = n[p].c[2] = 0;
fix(i);
}
void evert(node_index i) {
if(!i) return;
expose(i);
reverse(i);
push(i);
}
node n[505050];
int ni = 1;
int all_tree(node_index i) {
expose(i);
return n[i].f.a;
}
}
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define rep(i,s,e) for(i64 (i) = (s);(i) < (e);(i)++)
#define all(x) x.begin(),x.end()
template<class T>
static inline std::vector<T> ndvec(size_t&& n, T val) noexcept {
return std::vector<T>(n, std::forward<T>(val));
}
template<class... Tail>
static inline auto ndvec(size_t&& n, Tail&&... tail) noexcept {
return std::vector<decltype(ndvec(std::forward<Tail>(tail)...))>(n, ndvec(std::forward<Tail>(tail)...));
}
template<class T, class Cond>
struct chain {
Cond cond; chain(Cond cond) : cond(cond) {}
bool operator()(T& a, const T& b) const {
if(cond(a, b)) { a = b; return true; }
return false;
}
};
template<class T, class Cond>
chain<T, Cond> make_chain(Cond cond) { return chain<T, Cond>(cond); }
#include <cstdio>
namespace niu {
char cur;
struct FIN {
static inline bool is_blank(char c) { return c <= ' '; }
inline char next() { return cur = getc_unlocked(stdin); }
inline char peek() { return cur; }
inline void skip() { while(is_blank(next())){} }
#define intin(inttype) \
FIN& operator>>(inttype& n) { \
bool sign = 0; \
n = 0; \
skip(); \
while(!is_blank(peek())) { \
if(peek() == '-') sign = 1; \
else n = (n << 1) + (n << 3) + (peek() & 0b1111); \
next(); \
} \
if(sign) n = -n; \
return *this; \
}
intin(int)
intin(long long)
} fin;
char tmp[128];
struct FOUT {
static inline bool is_blank(char c) { return c <= ' '; }
inline void push(char c) { putc_unlocked(c, stdout); }
FOUT& operator<<(char c) { push(c); return *this; }
FOUT& operator<<(const char* s) { while(*s) push(*s++); return *this; }
#define intout(inttype) \
FOUT& operator<<(inttype n) { \
if(n) { \
char* p = tmp + 127; bool neg = 0; \
if(n < 0) neg = 1, n = -n; \
while(n) *--p = (n % 10) | 0b00110000, n /= 10; \
if(neg) *--p = '-'; \
return (*this) << p; \
} \
else { \
push('0'); \
return *this; \
} \
}
intout(int)
intout(long long)
} fout;
}
int main() {
using niu::fin;
using niu::fout;
i64 N, Q;
fin >> N;
vector<vector<int>> vs(N);
vector<int> co(N);
for(int i = 0;i < N;i++) {
lctree::new_node(1);
int a;
fin >> a;
a--;
co[i] = a;
vs[a].push_back(i);
}
vector<vector<int>> G(N);
for(int i = 0;i + 1 < N;i++) {
i64 a, b;
fin >> a >> b;
a--;
b--;
G[a].push_back(b);
if(co[a] != co[b])
G[b].push_back(a);
lctree::evert(b + 1);
lctree::link(a + 1, b + 1);
}
auto func = [&](i64 ans, i64 a, i64 b) {
i64 A = lctree::all_tree(a);
ans -= A * (A + 1) / 2;
lctree::evert(a);
lctree::cut(b);
i64 B = lctree::all_tree(a);
ans += B * (B + 1) / 2;
i64 C = lctree::all_tree(b);
ans += C * (C + 1) / 2;
//std::cout << A << " " << B << " " << C << std::endl;
return ans;
};
for(int i = 0;i < N;i++) {
i64 ans = (N - vs[i].size()) * ((N - vs[i].size()) + 1) / 2;
for(auto v: vs[i]) {
lctree::expose(v + 1);
lctree::n[v + 1].v.a = 0;
lctree::fix(v + 1);
}
for(auto v: vs[i]) {
for(auto t: G[v]) {
ans = func(ans, v + 1, t + 1);
}
}
//cout << ans << endl;
fout << (N * (N + 1) / 2) - ans << "\n";
for(auto v: vs[i]) {
lctree::expose(v + 1);
lctree::n[v + 1].v.a = 1;
lctree::fix(v + 1);
}
for(auto v: vs[i]) {
for(auto t: G[v]) {
lctree::evert(t + 1);
lctree::link(v + 1, t + 1);
}
}
}
}