P5327 [ZJOI2019] 语言

[ZJOI2019] 语言

Description

https://www.luogu.com.cn/problem/P5327

Solution

考虑暴力。对于每种语言,用树剖把对应的路径拆成 $O(\log n)$ 个 dfn 区间,并对路径上的每个点都做这 $O(\log n)$ 次区间覆盖。最后统计每个点一共覆盖到了多少个点,除去它自己;把所有点的结果求和再除以 $2$,即为可以互通的点对个数。

这相当于对路径上的一堆点做相同的操作,考虑树上差分 + 线段树合并。

线段树上存的是覆盖次数的差分数组,需要统计前缀和取到最小值的位置个数(最小值为 $0$ 时即未被覆盖的点数)。每个节点维护区间和、前缀和的最小值、取到该最小值的位置个数,左右儿子的合并是平凡的。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define maxn 100005
#define put() putchar('\n')
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
using namespace std;
// Reads one (optionally negative) decimal integer from stdin into x.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If the input ends before a digit is
// found, x is set to 0 and the function returns instead of spinning forever.
inline void read(int &x){
    int sign=1;
    // getchar() returns int so that EOF can be told apart from every real
    // character; storing it in a char made the EOF test platform dependent.
    int c=getchar();
    x=0;
    while (c!=EOF&&(c<'0'||c>'9')) {
        if (c=='-') sign=-1;
        c=getchar();
    }
    while (c>='0'&&c<='9') {
        x=x*10+(c-'0');
        c=getchar();
    }
    x*=sign;
}
namespace Debug{
// Debug-only helpers: gdb(a,b,...) prints "a=<value>,b=<value>" to stderr.

// Streams a vector as "[e1,e2,...,]".
// It has to be declared before the _debug templates: those templates look up
// operator<< at their point of definition, and argument-dependent lookup only
// searches namespace std for std::vector, so a later declaration in this
// namespace would never be found and gdb(vec) would not compile.
template<typename Ty>
std::ostream& operator<<(std::ostream& os,const std::vector<Ty>& V){
    os<<"[";
    for(const auto& vv:V) os<<vv<<",";
    os<<"]";
    return os;
}
// Last (or only) argument: f is the remaining label text, printed verbatim.
template<typename Ty>
void _debug(char* f,Ty t){std::cerr<<f<<'='<<t<<std::endl;}
// General case: print the label up to the next comma, the value, then recurse
// on the remaining labels and values.
template<typename Ty,typename... Ar>
void _debug(char* f,Ty x,Ar... y){
    while(*f!=',') std::cerr<<*f++;
    std::cerr<<'='<<x<<",";
    _debug(f+1,y...);
}
// The stringified argument list supplies the labels; the cast is needed
// because _debug takes a non-const char*.
#define gdb(...) _debug((char*)#__VA_ARGS__,__VA_ARGS__)
}using namespace Debug;
#define fi first
#define se second
#define mk make_pair
const int mod=1e9+7;
// Modular exponentiation by repeated squaring: returns x^y modulo `mod`.
// With the default exponent (mod-2) this is the modular inverse of x, because
// mod is prime. Exponents <= 0 yield 1.
inline int power(int x,int y=mod-2) {
    // 64-bit intermediates: the product of two values close to mod does not
    // fit in a 32-bit int.
    long long result=1;
    long long base=x;
    while (y>0) {
        if (y&1) result=result*base%mod;
        base=base*base%mod;
        y>>=1;
    }
    return static_cast<int>(result);
}
// n: number of tree nodes, m: number of paths; both are read in main.
int n,m;
// Adjacency lists of the undirected tree; main adds every edge in both directions.
vector<int>to[maxn];
// Heavy-light decomposition state, filled by dfs1 and dfs2:
//   siz[x]  subtree size            fa[x]   parent (0 for the root)
//   deep[x] depth                   son[x]  child with the largest subtree (0 if none)
//   top[x]  head of x's chain       dfn[x]  visiting index assigned by dfs2
//   p[i]    node whose dfn is i     times   counter used to hand out dfn values
int dfn[maxn],times,son[maxn],siz[maxn],fa[maxn],top[maxn],deep[maxn],p[maxn];
// First decomposition pass: records parent, depth and subtree size of every
// node below x and picks the child with the largest subtree as son[x].
// pre is the node we arrived from (0 for the root).
inline void dfs1(int x,int pre) {
    fa[x]=pre;
    deep[x]=deep[pre]+1;
    siz[x]=1;
    for (const int child:to[x]) {
        if (child==pre) continue;
        dfs1(child,x);
        siz[x]+=siz[child];
        const int heavy=son[x];
        // A later child only replaces the current choice if it is strictly larger.
        if (heavy==0||siz[heavy]<siz[child]) son[x]=child;
    }
}
// Second decomposition pass: assigns dfn numbers and chain heads.
// u is the head of the chain x belongs to. The heavy child is visited first
// so that every chain occupies a contiguous range of dfn values; each other
// child starts a new chain headed by itself.
inline void dfs2(int x,int pre,int u) {
    times+=1;
    dfn[x]=times;
    p[times]=x;
    top[x]=u;
    const int heavy=son[x];
    if (heavy==0) return;
    dfs2(heavy,x,u);
    for (const int child:to[x]) {
        if (child==pre||child==heavy) continue;
        dfs2(child,x,child);
    }
}
// Lowest common ancestor of x and y by climbing chains: while the two nodes
// are on different chains, the one whose chain head is deeper jumps to the
// parent of that head. Once both are on one chain the shallower node is the
// answer.
inline int lca(int x,int y) {
    while (top[x]!=top[y]) {
        if (deep[top[x]]<deep[top[y]]) y=fa[top[y]];
        else x=fa[top[x]];
    }
    if (deep[x]<deep[y]) return x;
    return y;
}
// ql[1..cnt] / qr[1..cnt]: dfn intervals of the most recent path, written by query().
// root[v]: root index of node v's segment tree (0 = empty).
int ql[maxn],qr[maxn],cnt,root[maxn];
// Sum accumulated by solve(); main prints ans/2.
ll ans;
// Splits the tree path x..y into dfn intervals, one per chain it touches, and
// stores them in ql[1..cnt] / qr[1..cnt] (cnt is reset first).
inline void query(int x,int y) {
    cnt=0;
    while (top[x]!=top[y]) {
        // Climb from whichever endpoint has the deeper chain head.
        int &low=(deep[top[x]]<deep[top[y]])?y:x;
        const int head=top[low];
        ++cnt;
        ql[cnt]=dfn[head];
        qr[cnt]=dfn[low];
        low=fa[head];
    }
    // Both endpoints now share a chain: the shallower one has the smaller dfn.
    const bool xIsUpper=deep[x]<deep[y];
    const int upper=xIsUpper?x:y;
    const int lower=xIsUpper?y:x;
    ++cnt;
    ql[cnt]=dfn[upper];
    qr[cnt]=dfn[lower];
}
// Dynamically allocated segment tree over positions [1,n] that stores a
// difference array: point updates add to single positions, and each node
// summarises the prefix sums of the differences inside its range.
// Index 0 means "no node"; a missing subtree stands for all-zero differences.
namespace seg{
struct node {
    int ls,rs; // indices of the left / right child in f[], 0 if absent
    int Min;   // smallest prefix sum of the differences within this range
    int sum;   // how many positions in the range attain that smallest prefix sum
    int suf;   // total of all differences in the range (what a right sibling's prefixes are shifted by)
}f[maxn*400];
int total; // number of nodes handed out so far
// Recomputes node rt, which covers [l,r] and is split at mid, from its children.
inline void Pushup(int l,int r,int mid,int rt) {
    const int lc=f[rt].ls;
    const int rc=f[rt].rs;
    // An absent child behaves like all zeros: minimum 0 attained at every
    // position of its range, total 0.
    const int leftMin=lc?f[lc].Min:0;
    const int leftCnt=lc?f[lc].sum:(mid-l+1);
    const int leftTot=lc?f[lc].suf:0;
    const int rightCnt=rc?f[rc].sum:(r-mid);
    const int rightTot=rc?f[rc].suf:0;
    // Prefix sums inside the right half start from the total of the left half.
    const int rightMin=leftTot+(rc?f[rc].Min:0);
    f[rt].Min=std::min(leftMin,rightMin);
    f[rt].suf=leftTot+rightTot;
    if (rightMin<leftMin) f[rt].sum=rightCnt;
    else if (rightMin==leftMin) f[rt].sum=leftCnt+rightCnt;
    else f[rt].sum=leftCnt;
}
// Adds k to the difference at position head, creating nodes on the way down.
// rt is a reference so that a newly created node is linked into its parent
// (or into the caller's root variable).
inline void Update(int l,int r,int &rt,int head,int k) {
    if (!rt) {
        // Fresh node: all zeros, so the minimum 0 is attained everywhere.
        rt=++total;
        f[rt].Min=0;
        f[rt].sum=r-l+1;
    }
    if (l==r) {
        f[rt].Min+=k;
        f[rt].suf+=k;
        f[rt].sum=1;
        return;
    }
    int mid=(l+r)>>1;
    if (head<=mid) Update(l,mid,f[rt].ls,head,k);
    else Update(mid+1,r,f[rt].rs,head,k);
    Pushup(l,r,mid,rt);
}
// Adds tree y into tree x position by position. When only one side has a node
// for a range, x simply takes that node over, so nodes of y may end up being
// part of x afterwards.
inline void merge(int l,int r,int &x,int y) {
    if (!x||!y) {
        x=x+y;
        return;
    }
    if (l==r) {
        f[x].Min+=f[y].Min;
        f[x].suf+=f[y].suf;
        f[x].sum=1;
        return;
    }
    int mid=(l+r)>>1;
    merge(l,mid,f[x].ls,f[y].ls);
    merge(mid+1,r,f[x].rs,f[y].rs);
    Pushup(l,r,mid,x);
}
// Debug dump of the subtree rooted at rt to stdout.
// Inner nodes print "[l , r] = Min sum suf", leaves print "l : Min sum".
inline void print(int l,int r,int &rt) {
    if (!rt) {
        // Absent subtree: minimum 0 attained at all r-l+1 positions, matching
        // what Pushup assumes for a missing child.
        printf("[%d , %d] = %d %d\n",l,r,0,r-l+1);
        return;
    }
    if (l==r) {
        printf("%d : %d %d\n",l,f[rt].Min,f[rt].sum);
        return;
    }
    // One conversion per argument; suf used to be passed without a specifier.
    printf("[%d , %d] = %d %d %d\n",l,r,f[rt].Min,f[rt].sum,f[rt].suf);
    int mid=(l+r)>>1;
    print(l,mid,f[rt].ls);print(mid+1,r,f[rt].rs);
}
}
// Applies the intervals most recently produced by query() to node x's
// segment tree with weight flag: +flag at each interval start and -flag just
// past each interval end (skipped when the interval ends at n).
// x == 0 (no such node) is ignored.
inline void update(int x,int flag) {
    if (x==0) return;
    int &tree=root[x];
    for (int k=1;k<=cnt;++k) {
        const int lo=ql[k];
        const int hi=qr[k];
        seg::Update(1,n,tree,lo,flag);
        if (hi<n) seg::Update(1,n,tree,hi+1,-flag);
    }
}
inline void solve(int x) {
for (auto y:to[x]) if (y^fa[x]) {
solve(y);
seg::merge(1,n,root[x],root[y]);
}
ans+=n-(seg::f[root[x]].Min==0)*seg::f[root[x]].sum-1;
}
// Reads the tree and the m paths, places the difference marks for every path
// on the per-node segment trees, merges them bottom-up and prints ans/2.
signed main(void){
    read(n);read(m);
    for (int e=1;e<n;++e) {
        int u,v;
        read(u);read(v);
        to[u].push_back(v);
        to[v].push_back(u);
    }
    dfs1(1,0);
    dfs2(1,0,1);
    // Every node covers itself: +1 at the node, cancelled again at its parent.
    for (int v=1;v<=n;++v) {
        query(v,v);
        update(v,1);
        update(fa[v],-1);
    }
    for (int k=1;k<=m;++k) {
        int x,y;
        read(x);read(y);
        if (x==y) continue;
        const int g=lca(x,y);
        query(x,y);
        // Tree difference: +1 at both endpoints, -1 at the LCA and at its parent.
        update(x,1);
        update(y,1);
        update(g,-1);
        update(fa[g],-1);
    }
    solve(1);
    printf("%lld",ans/2);
    return 0;
}