1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
// Counts unordered pairs of distinct tree nodes (u, v) that lie together on at
// least one of m given tree paths.
//
// How it works:
//   * Heavy-light decomposition turns any tree path into a few contiguous
//     dfn intervals (query()).
//   * Each path x-y is attached to the nodes on it with a tree difference:
//     +1 at x and y, -1 at g = lca(x, y) and at fa[g]. After solve() merges
//     subtrees bottom-up, node u's structure holds exactly the paths through u
//     (plus u itself).
//   * Each per-node structure is a dynamic segment tree over dfn positions
//     storing a difference array. The coverage of position q is the prefix
//     sum d[1] + ... + d[q]. The root's minimum prefix sum and its multiplicity
//     give the number of uncovered nodes.
//   * Node u contributes (#covered nodes - 1); the total is halved at the end.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

#define ll long long
#define ull unsigned long long
#define maxn 100005
#define put() putchar('\n')
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
using namespace std;

// Reads one (optionally negative) decimal integer from stdin into x.
// getchar() is kept in an int so EOF is distinguishable from a real byte.
// Hitting EOF before any digit leaves x == 0 instead of looping forever.
inline void read(int &x) {
    int f = 1;
    x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

// Debugging helpers (not used by the solution itself).
namespace Debug {
Tp void _debug(char* f, Ty t) { cerr << f << '=' << t << endl; }
Ts void _debug(char* f, Ty x, Ar... y) {
    while (*f != ',') cerr << *f++;
    cerr << '=' << x << ",";
    _debug(f + 1, y...);
}
Tp ostream& operator<<(ostream& os, vector<Ty>& V) {
    os << "[";
    for (auto& vv : V) os << vv << ",";
    os << "]";
    return os;
}
#define gdb(...) _debug((char*)#__VA_ARGS__, __VA_ARGS__)
}  // namespace Debug
using namespace Debug;

#define fi first
#define se second
#define mk make_pair

const int mod = 1e9 + 7;

// Returns x^y modulo `mod`. The default exponent mod-2 yields the modular
// inverse for prime mod. Products are formed in 64-bit arithmetic, because the
// original int products (e.g. x*x with x near 1e9) overflowed. A negative
// exponent is treated as 0 instead of looping forever.
inline int power(int x, int y = mod - 2) {
    ll base = x % mod;
    if (base < 0) base += mod;
    ll result = 1;
    while (y > 0) {
        if (y & 1) result = result * base % mod;
        base = base * base % mod;
        y >>= 1;
    }
    return static_cast<int>(result);
}

int n, m;
vector<int> to[maxn];  // adjacency lists of the tree (nodes are 1..n)
// HLD data: dfn = visit index, p = inverse of dfn, son = heavy child,
// siz = subtree size, fa = parent (0 for the root), top = head of the node's
// heavy chain, deep = depth (the root has depth 1).
int dfn[maxn], dfs_clock, son[maxn], siz[maxn], fa[maxn], top[maxn], deep[maxn], p[maxn];

// First HLD pass: parent, depth, subtree size and heavy child of each node.
inline void dfs1(int x, int pre) {
    siz[x] = 1;
    fa[x] = pre;
    deep[x] = deep[pre] + 1;
    for (int y : to[x]) {
        if (y == pre) continue;
        dfs1(y, x);
        siz[x] += siz[y];
        if (!son[x] || siz[y] > siz[son[x]]) son[x] = y;
    }
}

// Second HLD pass: numbers nodes so the heavy child follows its parent
// (each heavy chain gets consecutive dfn values); u is the current chain head.
inline void dfs2(int x, int pre, int u) {
    top[x] = u;
    dfn[x] = ++dfs_clock;
    p[dfs_clock] = x;
    if (!son[x]) return;
    dfs2(son[x], x, u);
    for (int y : to[x])
        if (y != pre && y != son[x]) dfs2(y, x, y);
}

// Lowest common ancestor of x and y via heavy chains.
inline int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    return deep[x] < deep[y] ? x : y;
}

int ql[maxn], qr[maxn], cnt, root[maxn];
ll ans;

// Splits the tree path x..y into cnt dfn intervals [ql[i], qr[i]] (i = 1..cnt).
// It overwrites the global ql/qr/cnt buffers that update() consumes.
inline void query(int x, int y) {
    cnt = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        ++cnt;
        qr[cnt] = dfn[x];
        ql[cnt] = dfn[top[x]];
        x = fa[top[x]];
    }
    if (deep[x] < deep[y]) swap(x, y);
    ++cnt;
    ql[cnt] = dfn[y];
    qr[cnt] = dfn[x];
}

namespace seg {
// Dynamic segment tree over positions [1, n] storing a difference array d.
// For a node covering [l, r]:
//   suf - d[l] + ... + d[r] (net total of the range)
//   Min - minimum over q in [l, r] of d[l] + ... + d[q]
//   sum - how many q in [l, r] attain Min
// A missing child (index 0) is an all-zero range: Min 0, suf 0, and every
// position in it attains the minimum.
struct node {
    int ls, rs, Min, sum, suf;
} f[maxn * 400];
int total;  // number of allocated nodes; f[0] acts as the null node

// Recomputes f[rt] from its children, which cover [l, mid] and [mid+1, r].
inline void Pushup(int l, int r, int mid, int rt) {
    int lmin, lcnt, ltot;  // left child summary
    int rmin, rcnt, rtot;  // right child summary
    if (f[rt].ls) {
        lmin = f[f[rt].ls].Min;
        lcnt = f[f[rt].ls].sum;
        ltot = f[f[rt].ls].suf;
    } else {
        lmin = 0;
        lcnt = mid - l + 1;
        ltot = 0;
    }
    if (f[rt].rs) {
        rmin = f[f[rt].rs].Min;
        rcnt = f[f[rt].rs].sum;
        rtot = f[f[rt].rs].suf;
    } else {
        rmin = 0;
        rcnt = r - mid;
        rtot = 0;
    }
    // Prefix sums inside the right child are offset by the left child's total.
    const int shifted = ltot + rmin;
    f[rt].Min = min(lmin, shifted);
    f[rt].suf = ltot + rtot;
    if (shifted < lmin) f[rt].sum = rcnt;
    else if (shifted == lmin) f[rt].sum = lcnt + rcnt;
    else f[rt].sum = lcnt;
}

// Adds k to d[head] in the tree rooted at rt, allocating nodes on demand.
inline void Update(int l, int r, int &rt, int head, int k) {
    if (!rt) {
        rt = ++total;
        f[rt].Min = 0;
        f[rt].sum = r - l + 1;
    }
    if (l == r) {
        f[rt].Min += k;
        f[rt].suf += k;
        f[rt].sum = 1;
        return;
    }
    int mid = (l + r) >> 1;
    if (head <= mid) Update(l, mid, f[rt].ls, head, k);
    else Update(mid + 1, r, f[rt].rs, head, k);
    Pushup(l, r, mid, rt);
}

// Adds tree y's difference array into tree x. Subtrees of y may be adopted by
// x and modified later through x, so y must not be used on its own afterwards.
inline void merge(int l, int r, int &x, int y) {
    if (!x || !y) {
        x = x + y;  // whichever side is non-null (or 0 if both are)
        return;
    }
    if (l == r) {
        f[x].Min += f[y].Min;
        f[x].suf += f[y].suf;
        f[x].sum = 1;
        return;
    }
    int mid = (l + r) >> 1;
    merge(l, mid, f[x].ls, f[y].ls);
    merge(mid + 1, r, f[x].rs, f[y].rs);
    Pushup(l, r, mid, x);
}

// Debug dump of the tree rooted at rt. Leaves print "l : Min sum" and ranges
// print "[l , r] = Min sum suf". An empty range prints Min 0 with all r-l+1
// positions attaining it, matching Pushup's convention.
inline void print(int l, int r, int rt) {
    if (!rt) {
        printf("[%d , %d] = %d %d %d\n", l, r, 0, r - l + 1, 0);
        return;
    }
    if (l == r) {
        printf("%d : %d %d\n", l, f[rt].Min, f[rt].sum);
        return;
    }
    printf("[%d , %d] = %d %d %d\n", l, r, f[rt].Min, f[rt].sum, f[rt].suf);
    int mid = (l + r) >> 1;
    print(l, mid, f[rt].ls);
    print(mid + 1, r, f[rt].rs);
}
}  // namespace seg

// Applies the intervals produced by the last query() to node x's tree with
// weight flag: +flag at each interval start, -flag just past its end.
// x == 0 (the parent of the root) is ignored.
inline void update(int x, int flag) {
    if (!x) return;
    for (int i = 1; i <= cnt; i++) {
        seg::Update(1, n, root[x], ql[i], flag);
        if (qr[i] < n) seg::Update(1, n, root[x], qr[i] + 1, -flag);
    }
}

// Post-order pass: merges every child's tree into root[x]. Afterwards root[x]
// describes the dfn positions covered by x itself or by a path through x.
// Adds the number of such nodes other than x to ans.
inline void solve(int x) {
    for (int y : to[x]) {
        if (y == fa[x]) continue;
        solve(y);
        seg::merge(1, n, root[x], root[y]);
    }
    const seg::node &summary = seg::f[root[x]];
    // Coverage counts are never negative, so Min == 0 exactly when some nodes
    // are uncovered, and summary.sum is how many.
    int uncovered = (summary.Min == 0) ? summary.sum : 0;
    ans += n - uncovered - 1;
}

int main() {
    int x, y;
    read(n);
    read(m);
    for (int i = 1; i <= n - 1; i++) {
        read(x);
        read(y);
        to[x].push_back(y);
        to[y].push_back(x);
    }
    dfs1(1, 0);
    dfs2(1, 0, 1);
    // Every node covers itself: +1 at dfn[i] in root[i] and -1 in root[fa[i]],
    // so after merging only node i's own tree keeps the mark.
    for (int i = 1; i <= n; i++) {
        query(i, i);
        update(i, 1);
        update(fa[i], -1);
    }
    // Tree difference per path x-y with g = lca(x, y): +1 at x and y, -1 at g
    // and fa[g]. After merging, exactly the nodes on the path carry the path.
    for (int i = 1; i <= m; i++) {
        read(x);
        read(y);
        if (x == y) continue;  // a single-node path creates no pairs
        int g = lca(x, y);
        query(x, y);
        update(x, 1);
        update(y, 1);
        update(g, -1);
        update(fa[g], -1);
    }
    solve(1);
    printf("%lld", ans / 2);  // every unordered pair was counted from both ends
    return 0;
}
|