// Counts the unordered pairs among m simple paths of an n-vertex tree whose
// vertex sets intersect in exactly one vertex.
//
// Input (stdin):   n, then the n-1 tree edges "u v" (1-indexed),
//                  then m, then m paths "x y" (the simple path from x to y).
// Output (stdout): the number of such pairs.
//
// Every path with x != y is classified by its LCA g and by the children of g
// it descends into (0 stands for "this endpoint is g itself"). Two paths can
// share exactly one vertex only if that vertex is the LCA of at least one of
// them, so each qualifying pair is counted at the deeper LCA while the tree
// is processed bottom-up. Single-vertex paths (x == y) are handled separately.

#include <algorithm>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

namespace {

// Reads the next decimal integer from stdin, skipping any non-digit
// characters before it; a '-' among the skipped characters makes it negative.
// Returns false if EOF is reached before any digit is found.
bool read_int(long long &out) {
    int c = std::getchar();
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') negative = true;
        c = std::getchar();
    }
    if (c == EOF) return false;
    long long value = 0;
    for (; c >= '0' && c <= '9'; c = std::getchar()) value = value * 10 + (c - '0');
    out = negative ? -value : value;
    return true;
}

// Tree rooted at vertex 1 with a heavy-light decomposition. Everything is
// built with explicit stacks, so path-shaped trees cannot overflow the call
// stack.
struct Tree {
    int n;
    std::vector<std::vector<int>> adj;
    std::vector<int> parent, depth;  // parent[1] == 0, depth[1] == 1
    std::vector<int> order;          // every vertex appears after its parent
    std::vector<int> heavy, top, pos, at;  // at[pos[v]] == v; heavy chains are contiguous in pos

    explicit Tree(int vertex_count) : n(vertex_count), adj(vertex_count + 1) {}

    void add_edge(int u, int v) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Roots the tree at 1 and builds the decomposition.
    // Returns false if the edges do not connect all n vertices.
    bool build() {
        parent.assign(n + 1, 0);
        depth.assign(n + 1, 0);  // depth 0 doubles as "not visited yet"
        order.clear();
        order.reserve(n);
        std::vector<int> stack(1, 1);
        depth[1] = 1;
        while (!stack.empty()) {
            const int v = stack.back();
            stack.pop_back();
            order.push_back(v);
            for (int u : adj[v]) {
                if (depth[u] == 0) {
                    depth[u] = depth[v] + 1;
                    parent[u] = v;
                    stack.push_back(u);
                }
            }
        }
        if (static_cast<int>(order.size()) != n) return false;

        // Subtree sizes and heavy children: reversed `order` visits every
        // vertex after all of its descendants.
        std::vector<int> subtree(n + 1, 1);
        heavy.assign(n + 1, 0);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            const int p = parent[v];
            if (p == 0) continue;
            subtree[p] += subtree[v];
            if (heavy[p] == 0 || subtree[heavy[p]] < subtree[v]) heavy[p] = v;
        }

        // Heavy-first numbering: the heavy child is pushed last, so it is
        // popped immediately after its parent and gets the next position.
        top.assign(n + 1, 0);
        pos.assign(n + 1, 0);
        at.assign(n + 1, 0);
        int timer = 0;
        stack.assign(1, 1);
        top[1] = 1;
        while (!stack.empty()) {
            const int v = stack.back();
            stack.pop_back();
            pos[v] = ++timer;
            at[timer] = v;
            for (int u : adj[v]) {
                if (u != parent[v] && u != heavy[v]) {
                    top[u] = u;
                    stack.push_back(u);
                }
            }
            if (heavy[v] != 0) {
                top[heavy[v]] = top[v];
                stack.push_back(heavy[v]);
            }
        }
        return true;
    }

    // Lowest common ancestor of a and b.
    int lca(int a, int b) const {
        while (top[a] != top[b]) {
            if (depth[top[a]] < depth[top[b]]) std::swap(a, b);
            a = parent[top[a]];
        }
        return depth[a] < depth[b] ? a : b;
    }

    // k-th ancestor of v (k == 0 yields v itself); requires k < depth[v].
    int ancestor(int v, int k) const {
        while (depth[v] - depth[top[v]] + 1 <= k) {
            k -= depth[v] - depth[top[v]] + 1;
            v = parent[top[v]];
        }
        return at[pos[v] - k];
    }

    // Child of the ancestor g that lies on the way down to v, or 0 if v == g.
    int child_towards(int g, int v) const {
        return v == g ? 0 : ancestor(v, depth[v] - depth[g] - 1);
    }
};

// Returns the number of unordered pairs of `paths` that share exactly one
// vertex. `tree` must be built; all endpoints must lie in [1, tree.n].
long long count_single_vertex_pairs(const Tree &tree,
                                    const std::vector<std::pair<int, int>> &paths) {
    const int n = tree.n;
    // Difference array; after accumulation, pass[v] = #paths using edge (v, parent[v]).
    std::vector<long long> pass(n + 1, 0);
    // point[v] = #paths consisting of the single vertex v.
    std::vector<long long> point(n + 1, 0);
    // by_lca[g] = (a, b) for each path with LCA g: the children of g it
    // enters, with a < b and 0 meaning "this endpoint is g".
    std::vector<std::vector<std::pair<int, int>>> by_lca(n + 1);

    for (const auto &path : paths) {
        const int x = path.first;
        const int y = path.second;
        if (x == y) {
            ++point[x];
            continue;
        }
        const int g = tree.lca(x, y);
        int a = tree.child_towards(g, x);
        int b = tree.child_towards(g, y);
        if (a > b) std::swap(a, b);  // b is always a real child since x != y
        ++pass[x];
        ++pass[y];
        pass[g] -= 2;
        by_lca[g].emplace_back(a, b);
    }

    // Scratch: uses[c] = #paths with LCA parent[c] that enter child c.
    // Each entry is reset to 0 before moving on to the next vertex.
    std::vector<long long> uses(n + 1, 0);
    long long answer = 0;

    for (auto it = tree.order.rbegin(); it != tree.order.rend(); ++it) {
        const int v = *it;
        const int up = tree.parent[v];
        for (int c : tree.adj[v]) {
            if (c != up) pass[v] += pass[c];
        }
        // pass[v]: paths containing v whose LCA is a proper ancestor of v.

        std::vector<std::pair<int, int>> &here = by_lca[v];
        for (const auto &p : here) {
            if (p.first != 0) {
                --pass[p.first];
                ++uses[p.first];
            }
            --pass[p.second];
            ++uses[p.second];
        }
        // For each child c, pass[c] now counts only paths that come from
        // above v and descend into c.

        const long long k = static_cast<long long>(here.size());
        const long long s = point[v];
        // Candidate pairs: (path through v from above, path with LCA v) and
        // every pair of paths with LCA v; overlapping ones are removed below.
        answer += pass[v] * k + k * (k - 1) / 2;
        // A single-vertex path at v meets every other path containing v at v only.
        answer += (pass[v] + k) * s + s * (s - 1) / 2;

        // A path from above that descends into child c shares edge (v, c)
        // with every LCA-v path that also enters c.
        for (int c : tree.adj[v]) {
            if (c != up) answer -= pass[c] * uses[c];
        }

        // Two LCA-v paths entering the same child share at least two vertices.
        auto drop_shared = [&](int c) {
            if (c != 0 && uses[c] != 0) {
                answer -= uses[c] * (uses[c] - 1) / 2;
                uses[c] = 0;
            }
        };
        for (const auto &p : here) {
            drop_shared(p.first);
            drop_shared(p.second);
        }

        // Pairs entering the same two children were subtracted twice; add
        // them back once (inclusion-exclusion).
        std::sort(here.begin(), here.end());
        std::size_t i = 0;
        while (i < here.size()) {
            std::size_t j = i + 1;
            while (j < here.size() && here[j] == here[i]) ++j;
            if (here[i].first != 0) {
                const long long run = static_cast<long long>(j - i);
                answer += run * (run - 1) / 2;
            }
            i = j;
        }
    }
    return answer;
}

}  // namespace

// Malformed or truncated input (bad counts, out-of-range vertices, or edges
// that do not form a connected tree) ends the program with exit status 1
// and no output.
int main() {
    long long n = 0;
    if (!read_int(n) || n < 1 || n >= std::numeric_limits<int>::max()) return 1;

    Tree tree(static_cast<int>(n));
    auto read_vertex = [&](int &v) {
        long long raw = 0;
        if (!read_int(raw) || raw < 1 || raw > n) return false;
        v = static_cast<int>(raw);
        return true;
    };

    for (long long i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        if (!read_vertex(u) || !read_vertex(v)) return 1;
        tree.add_edge(u, v);
    }
    if (!tree.build()) return 1;

    long long m = 0;
    if (!read_int(m) || m < 0) return 1;
    std::vector<std::pair<int, int>> paths;
    for (long long i = 0; i < m; ++i) {
        int x = 0;
        int y = 0;
        if (!read_vertex(x) || !read_vertex(y)) return 1;
        paths.emplace_back(x, y);
    }

    std::printf("%lld\n", count_single_vertex_pairs(tree, paths));
    return 0;
}