CF1486F Pairs of Paths

CF1486F Pairs of Paths

Description

给定一棵包含 $n$ 个节点(编号 $1$ 到 $n$),我们同时选定该树上的 $m$ 条简单路径,求多少对无序二元组满足两条路径有且仅有一个焦点。

$1\leq n,m\leq3\times10^5$

6s , 512MB

Solution

先不考虑单独一个节点的路径。

考虑对某个节点分析,我们统计两条路径的产生贡献且交在这个节点上的数量,钦定这个节点为 $1$。

对于路径 $i$ ,$j$ 有两种情况:

  • 两者路径的 lca 都为 $i$,且四个端点分别处于 $i$ 的四个不同子树中。
  • 一条路径的 lca 为 $i$,$j$ 的一端在 $i$ 的子树外,一端在不同于 $i$ 的两个端点所在的子树内。

第一种情况用容斥计算。所有的 $-$ 至少一个端点在相同子树的个数 $+$ 两个端点都在相同子树中。

第二种情况也可以容斥。所有的 $-$ 一个端点相同的。

最后加上单独一个节点的路径。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define int long long
#define maxn 300005
#define put() putchar('\n')
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
using namespace std;
inline void read(int &x){
int f=1;x=0;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
x*=f;
}
namespace Debug{
Tp void _debug(char* f,Ty t){cerr<<f<<'='<<t<<endl;}
Ts void _debug(char* f,Ty x,Ar... y){while(*f!=',') cerr<<*f++;cerr<<'='<<x<<",";_debug(f+1,y...);}
Tp ostream& operator<<(ostream& os,vector<Ty>& V){os<<"[";for(auto& vv:V) os<<vv<<",";os<<"]";return os;}
#define gdb(...) _debug((char*)#__VA_ARGS__,__VA_ARGS__)
}using namespace Debug;
vector<int>to[maxn];
int n,m;
int deep[maxn],fa[maxn],dfn[maxn],son[maxn],top[maxn],siz[maxn],times,id[maxn],pp[maxn];
inline void dfs(int x,int pre) {
int i;
deep[x]=deep[pre]+1;fa[x]=pre;siz[x]=1;son[x]=0;
for (auto y:to[x]) if (y^pre) {
dfs(y,x);
siz[x]+=siz[y];
if (!son[x]||siz[son[x]]<siz[y]) son[x]=y;
}
}
inline void dfs2(int x,int pre,int u) {
int i;top[x]=u;id[x]=++times;pp[times]=x;
if (!son[x]) return ;
dfs2(son[x],x,u);
for (auto y:to[x]) if (y^pre&&y^son[x]) dfs2(y,x,y);
}
inline int query(int x,int y) {
while (top[x]^top[y]) {
if (deep[top[x]]<deep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if (deep[x]<deep[y]) return x;else return y;
}
inline int getk(int x,int k) {
while (deep[x]-deep[top[x]]+1<=k) k-=deep[x]-deep[top[x]]+1,x=fa[top[x]];
return pp[id[x]-k];
}
struct node{
int x,y,g,xx,yy;
}e[maxn];
vector<node>O[maxn];
int d[maxn];
int t[maxn],total[maxn];
ll ans;
map<ll,int>mp;
inline void solve(int x,int pre) {
int i;
for (auto y:to[x]) if (y^pre) solve(y,x),d[x]+=d[y];
for (auto tmp:O[x]) d[tmp.xx]--,d[tmp.yy]--;
int nums=O[x].size();
ans+=1ll*d[x]*nums+1ll*nums*(nums-1)/2;
ans+=1ll*(d[x]+nums)*total[x]+1ll*total[x]*(total[x]-1)/2;

for (auto tmp:O[x]) {
t[tmp.xx]++,t[tmp.yy]++;
if (tmp.xx&&tmp.yy) mp[1ll*tmp.xx*(n+1)+tmp.yy]++;
}
for (auto y:to[x]) if (y^pre) ans-=1ll*d[y]*t[y];
t[0]=0;
for (auto tmp:O[x]) {
if (t[tmp.xx]) ans-=1ll*t[tmp.xx]*(t[tmp.xx]-1)/2,t[tmp.xx]=0;
if (t[tmp.yy]) ans-=1ll*t[tmp.yy]*(t[tmp.yy]-1)/2,t[tmp.yy]=0;
ll now=1ll*tmp.xx*(n+1)+tmp.yy;
ans+=1ll*mp[now]*(mp[now]-1)/2;mp[now]=0;
}
}
signed main(void){
// freopen("1.in","r",stdin);
int i,x,y;
read(n);
for (i=1;i<n;i++) read(x),read(y),to[x].push_back(y),to[y].push_back(x);
dfs(1,0);
dfs2(1,0,1);
read(m);
for (i=1;i<=m;i++) {
read(e[i].x),read(e[i].y);
if (e[i].x==e[i].y) {total[e[i].x]++;continue;}
e[i].g=query(e[i].x,e[i].y);
if (e[i].x^e[i].g) e[i].xx=getk(e[i].x,deep[e[i].x]-deep[e[i].g]-1);
if (e[i].y^e[i].g) e[i].yy=getk(e[i].y,deep[e[i].y]-deep[e[i].g]-1);
if (e[i].xx>e[i].yy) swap(e[i].xx,e[i].yy);
d[e[i].x]++,d[e[i].y]++,d[e[i].g]-=2;
O[e[i].g].push_back(e[i]);
}
solve(1,0);
printf("%lld",ans);
return 0;
}