P5327 題解
考慮把貢獻攤到每個點上計算,每個點帶來的貢獻實際上是經過它的路徑并大小,算完求和之后在除以 \(2\) 就得到了答案。
考慮怎么計算路徑并大小。
考慮這樣一個辦法,將所有路徑的起始點和終點按照 DFS 序排序,相鄰兩點(包括第一個會最后一個點)在樹上的距離之和便是其路徑并大小的兩倍。原理的話便是路徑并大小等價于包含所有路徑起始點的最小聯通生成樹。
考慮樹上點差分,然后用線段樹儲存 DFS 序為 \(x\) 的點是否在子樹中,維護節點內最大和最小的存在的 DFS 序就可以在通過一次求 LCA 合并兩個子節點的信息。
那么最后一步通過線段樹合并將子樹的信息合并到父親即可。
時間復雜度 \(O(n \log^2 n)\) 空間復雜度 \(O(n \log n)\),也可以通過寫壓縮 01Trie 或者直接維護線段樹葉子節點的方法做到線性空間,但沒什么必要。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 1e5 + 114;
int fa[maxn][19], lg[maxn];
int dfn[maxn], dfncnt, dep[maxn];
int Node[maxn];
int n, m;
vector<int> edge[maxn];
void dfs1(int u, int father) {
dep[u] = dep[father] + 1;
dfn[u] = ++dfncnt;
Node[dfncnt] = u;
fa[u][0] = father;
for (int i = 1; i <= 17; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int v : edge[u]) {
if (v == father)
continue;
dfs1(v, u);
}
}
int LCA(int u, int v) {
if (dep[u] < dep[v])
swap(u, v);
while (dep[u] > dep[v]) {
u = fa[u][lg[dep[u] - dep[v]]];
}
if (u == v)
return u;
for (int i = 17; i >= 0; i--) {
if (fa[u][i] != fa[v][i]) {
u = fa[u][i], v = fa[v][i];
}
}
return fa[u][0];
}
int dist(int x, int y) {
return dep[x] + dep[y] - 2 * dep[LCA(x, y)];
}
#define ls(cur)(tr[cur].ls)
#define rs(cur)(tr[cur].rs)
int tot;
struct Segment_tree {
int ls, rs;
int mi, mx, sum, cnt;
} tr[maxn * 40];
int root[maxn];
void pushup(int cur) {
tr[cur].cnt = tr[ls(cur)].cnt + tr[rs(cur)].cnt;
if (tr[ls(cur)].cnt == 0 && tr[rs(cur)].cnt == 0)
cur = 0;
else if (tr[ls(cur)].cnt == 0)
tr[cur].sum = tr[rs(cur)].sum, tr[cur].mi = tr[rs(cur)].mi, tr[cur].mx = tr[rs(cur)].mx;
else if (tr[rs(cur)].cnt == 0)
tr[cur].sum = tr[ls(cur)].sum, tr[cur].mi = tr[ls(cur)].mi, tr[cur].mx = tr[ls(cur)].mx;
else {
tr[cur].sum = tr[ls(cur)].sum + tr[rs(cur)].sum + dist(Node[tr[ls(cur)].mx], Node[tr[rs(cur)].mi]);
tr[cur].mi = tr[ls(cur)].mi;
tr[cur].mx = tr[rs(cur)].mx;
}
}
void update(int &cur, int lt, int rt, int pos, int v) {
if (pos < lt || pos > rt)
return ;
if (cur == 0)
cur = ++tot;
if (lt == rt && lt == pos) {
tr[cur].cnt += v;
tr[cur].mi = tr[cur].mx = lt;
tr[cur].sum = 0;
return ;
}
int mid = (lt + rt) >> 1;
update(ls(cur), lt, mid, pos, v);
update(rs(cur), mid + 1, rt, pos, v);
pushup(cur);
}
int merge(int a, int b, int lt, int rt) {
if (a == 0 || b == 0)
return a + b;
if (lt == rt) {
tr[a].cnt += tr[b].cnt;
tr[a].mi = tr[a].mx = lt;
tr[a].sum = 0;
return a;
}
int mid = (lt + rt) >> 1;
tr[a].ls = merge(tr[a].ls, tr[b].ls, lt, mid);
tr[a].rs = merge(tr[a].rs, tr[b].rs, mid + 1, rt);
pushup(a);
return a;
}
vector<int> Ins[maxn], Del[maxn];
int answer;
void dfs2(int u, int father) {
for (int v : edge[u]) {
if (v == father)
continue;
dfs2(v, u);
root[u] = merge(root[u], root[v], 1, n);
}
for (int x : Ins[u])
update(root[u], 1, n, x, 1);
for (int x : Del[u])
update(root[u], 1, n, x, -1);
answer += (tr[root[u]].sum + dist(Node[tr[root[u]].mi], Node[tr[root[u]].mx])) / 2;
}
signed main() {
cin >> n >> m;
lg[1] = 0;
for (int i = 2; i <= n; i++)
lg[i] = lg[i / 2] + 1;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
edge[u].push_back(v);
edge[v].push_back(u);
}
dep[0] = -1;
dfs1(1, 0);
for (int i = 1; i <= m; i++) {
int u, v;
cin >> u >> v;
if (u == v)
continue;
Ins[u].push_back(dfn[u]);
Ins[u].push_back(dfn[v]);
Ins[v].push_back(dfn[v]);
Ins[v].push_back(dfn[u]);
Del[LCA(u, v)].push_back(dfn[u]);
Del[LCA(u, v)].push_back(dfn[v]);
Del[fa[LCA(u, v)][0]].push_back(dfn[u]);
Del[fa[LCA(u, v)][0]].push_back(dfn[v]);
}
dfs2(1, 0);
cout << answer / 2;
}

浙公網安備 33010602011771號