Bzoj4543

不会做。

首先我们我们可以观察到这三个点 $(x,y,z)$ ,存在一个点 $o$,使得 $(x,o)=(y,o)=(z,o)$。

这启发我们有两种思路:

  • 对于一个点 $x$ ,求到 $x$ 距离相等的点的个数。对于子树内部还好,但是对于整棵树来说比较难做。

  • 对于一个点 $x$,以 $x$ 的中继点,找到在其子树内部的 $o$,类似于:

设 $f_{i,j}$ 表示在 $i$ 的子树内深度为 $j$ 的节点个数,$g_{i,j}$ 表示在 $i$ 的子树内,$\forall x,y,x\not = y,(x,lca(x,y))=(y,lca(x,y))=(i,lca(x,y))+j$

转移就比较简单了。

1
2
3
4
5
6
7
ans += f[i][j - 1] * g[to][j];
ans += f[to][j] * g[i][j + 1];

g[i][j + 1] += f[to][j] * f[i][j + 1];
g[i][j - 1] += g[to][j];

f[i][j + 1] += f[to][j];

以深度为下标,用长链剖分优化:

code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<bits/stdc++.h>
#define maxn 200005
#define int long long
#define put() putchar('\n')
using namespace std;
inline void read(int &x){
int f=1;x=0;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
x*=f;
}
int deep[maxn],son[maxn],n,h[maxn],head=1;
struct yyy{
int to,z;
inline void add(int x,int y) {
to=y;z=h[x];h[x]=head;
}
}a[maxn*2];
int *f[maxn*2],*g[maxn*2],stac[maxn*4],*id=stac;
int ans;
inline void dfs1(int x,int pre) {
int i;deep[x]=1;
for (i=h[x];i;i=a[i].z)
if (a[i].to^pre) {
dfs1(a[i].to,x);
if (deep[x]<deep[a[i].to]+1) deep[x]=deep[a[i].to]+1,son[x]=a[i].to;
}
}
inline void solve(int x,int pre) {
int i,j;
f[x][0]=1;g[x][0]=0;
if (son[x]) f[son[x]]=f[x]+1,g[son[x]]=g[x]-1,solve(son[x],x);
else return ;
ans+=g[x][0];
for (i=h[x];i;i=a[i].z)
if (a[i].to!=pre&&a[i].to!=son[x]) {
f[a[i].to]=id;id+=deep[a[i].to]*2+2;g[a[i].to]=id;id+=deep[a[i].to]*2+2;solve(a[i].to,x);
for (j=0;j<=deep[a[i].to];j++) {
if (j) ans+=g[a[i].to][j]*f[x][j-1];
ans+=f[a[i].to][j]*g[x][j+1];
}
for (j=0;j<=deep[a[i].to];j++){
g[x][j+1]+=f[a[i].to][j]*f[x][j+1];
if (j) g[x][j-1]+=g[a[i].to][j];
}
for (j=1;j<=deep[a[i].to];j++) f[x][j]+=f[a[i].to][j-1];
}
}
signed main(void){
int i,x,y;
read(n);
for (i=1;i<n;i++) {
read(x);read(y);
a[++head].add(x,y);
a[++head].add(y,x);
}
dfs1(1,0);
f[1]=id;id+=deep[1]*2;g[1]=id;id+=deep[1]*2;
solve(1,0);
printf("%lld\n",ans);
return 0;
}