Bzoj4381

Solution

看到 $n\le 50000$ ,考虑根号分治。我们设定一个阈值 $K$。

  • 对于 $c_i\le K$ 的,前缀和预处理出每个 $sum_{i,j}$ 表示点 $j$ ,每次向上跳 $i$ 步,到根的权值。

    复杂度是 $O(Kn)-O(q\log n)$

  • 对于 $c_i>K$ 的,暴力跳。复杂度是 $O(q\dfrac{n}{K}\log n)$ 或者 $O(q\dfrac{n}{K})$ 的如果你会长链剖分反正我是不会

这种情况一般 $K=\sqrt{n}$ ,但是一般来说 $K$ 取更小一点效果更好。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
#define maxn 50005
#define ll long long
#define put() putchar('\n')
using namespace std;
inline void read(int &x){
int f=1;x=0;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
x*=f;
}
int b[maxn],c[maxn],w[maxn],g[505][maxn],ff[maxn],tr[maxn];
int h[maxn],head=1,fa[maxn][21],deep[maxn],lg[maxn*2];
int n;
struct yyy{
int to,z;
inline void add(int x,int y) {
to=y;z=h[x];h[x]=head;
}
}a[maxn*2];
inline void ins(int x,int y) {
a[++head].add(x,y);
a[++head].add(y,x);
}
int block;
inline void dfs1(int x,int pre) {
int i;deep[x]=deep[pre]+1;fa[x][0]=pre;
for (i=1;i<=lg[deep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (i=h[x];i;i=a[i].z) if (a[i].to^pre) dfs1(a[i].to,x);
}
inline int lca(int x,int y) {
int i;if (deep[x]<deep[y]) swap(x,y);
while (deep[x]>deep[y]) x=fa[x][lg[deep[x]-deep[y]]];
if (x==y) return x;
for (i=lg[deep[x]];i>=0;i--) if (fa[x][i]^fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
inline void dfs2(int x,int pre,int j) {
int i;g[j][x]=g[j][ff[x]]+w[x];
for (i=h[x];i;i=a[i].z) if (a[i].to^pre) dfs2(a[i].to,x,j);
}
inline int getfa(int x,int k) {
int i;
for (i=lg[k];i>=0;i--) if ((k>>i)&1) x=fa[x][i];
return x;
}
signed main(void){
freopen("1.in","r",stdin);
int i,j,now,sum,z,x,y,tmp,pus;
read(n);
block=sqrt(n)/3;
for (i=2;i<=n*2;i++) lg[i]=lg[i/2]+1;
for (i=1;i<=n;i++) read(w[i]);
for (i=1;i<n;i++) {
read(x);read(y);
ins(x,y);
}
for (i=1;i<=n;i++) read(b[i]);
for (i=1;i<n;i++) read(c[i]);
dfs1(1,0);
for (i=1;i<=n;i++) ff[i]=i;
for (j=1;j<=block;j++) {
for (i=1;i<=n;i++) tr[i]=ff[i];
for (i=1;i<=n;i++) ff[i]=fa[tr[i]][0];
dfs2(1,0,j);
}
for (i=1;i<n;i++) {
x=b[i],y=b[i+1];
if (c[i]<=block) {
z=lca(x,y);sum=0;
if (deep[x]<deep[y]) swap(x,y);
tmp=(deep[x]-deep[z]+c[i])/c[i]*c[i];sum+=g[c[i]][x]-g[c[i]][pus=getfa(x,tmp)];//printf("%d %d %d %d\n",c[i],x,pus,tmp);
if (y^z) tmp=(deep[y]-deep[z]-1+c[i])/c[i]*c[i],sum+=g[c[i]][y]-g[c[i]][pus=getfa(y,tmp)];//printf("%d %d %d %d\n",c[i],y,pus,tmp);
printf("%d\n",sum);
}
else {
sum=0;z=lca(x,y);
now=x;while (deep[now]>=deep[z]) {
sum+=w[now];
if (deep[now]>c[i]) now=getfa(now,c[i]);
else break;
}
now=y;while (deep[now]>deep[z]) {
sum+=w[now];
if (deep[now]>c[i]) now=getfa(now,c[i]);
else break;
}
// if (y==z) sum-=w[z];
printf("%d\n",sum);
}
}
return 0;
}