// Counting vertex orderings on a coloured tree (competitive-programming style).
//
// Input : n k, then the colours c[1..n], then n-1 undirected edges.
// Output:
//   k == 1 -> one number: the sum over every start vertex v of the number of
//             orderings of all n vertices in which each vertex comes after its
//             neighbour on the path towards v (tree re-rooted at v).
//   k >= 2 -> one number per colour i = 1..k. For every connected component
//             left after deleting all colour-i vertices, the code counts the
//             vertices outside it and a product of interleaved ordering counts
//             of the cut-off pieces. It keeps the components with the fewest
//             outside vertices and prints their summed counts (mod 1e9+7).
//
// NOTE(review): the problem statement is not part of this file; the
// description above is reconstructed from what the code computes.
// NOTE(review): dfs/dfs2/update/clear_subtree/solve recurse to the depth of
// the tree (up to n on a path). This needs a large stack; it may overflow a
// default 8 MB stack for n near MAXN. Confirm the target judge's limits.
#include <cstdio>
#include <vector>
using namespace std;

typedef long long ll;

// Array capacity; vertex ids and colour ids must both be < MAXN.
static const int MAXN = 500005;
static const ll MOD = 1000000007LL;
// Sentinel meaning "no component recorded yet"; real counts are <= n < MAXN.
static const ll INF = 1000000000LL;

static int n, k;
static vector<int> to[MAXN];          // adjacency lists
static int c[MAXN];                   // colour of each vertex
static int siz[MAXN];                 // subtree size (tree rooted at vertex 1)
static int son[MAXN];                 // heavy (largest) child, 0 for leaves
static ll f[MAXN];                    // parent-before-child orderings of subtree(x), starting at x
static ll g[MAXN];                    // same for the n - siz[x] vertices outside subtree(x), starting at parent(x)
static ll fact[MAXN], inv_fact[MAXN]; // factorials and inverse factorials mod MOD
static ll best[MAXN];                 // per colour: fewest vertices outside a colour-free component
static ll ways[MAXN];                 // per colour: summed counts for components achieving best[]
static bool vis[MAXN];                // colours already seen on the current path inside update()
static vector<int> t[MAXN];           // per colour: topmost vertices of that colour in the current subtree

// Reads one optionally negative decimal integer from stdin into x.
// Non-digit characters before the number are skipped. If EOF is reached
// before any digit, x is left as 0 instead of looping forever.
template <typename T>
static void read(T &x) {
    int sign = 1;
    int ch = getchar(); // int, not char, so EOF is distinguishable
    x = 0;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= sign;
}

// x^y mod MOD by binary exponentiation. The default exponent MOD-2 yields the
// modular inverse of x (Fermat; MOD is prime, x must not be 0 mod MOD).
static ll power(ll x, ll y = MOD - 2) {
    ll res = 1;
    x %= MOD;
    while (y) {
        if (y & 1) res = res * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return res;
}

// x += y (mod MOD).
static void add(ll &x, ll y) { x = (x + y) % MOD; }

// Binomial coefficient a choose b mod MOD; requires 0 <= b <= a <= n.
static ll C(ll a, ll b) { return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD; }

// Post-order pass: fills siz[x], son[x] (heavy child) and f[x].
static void dfs(int x, int parent) {
    siz[x] = 1;
    f[x] = 1;
    for (int y : to[x]) {
        if (y == parent) continue;
        dfs(y, x);
        // Interleave y's siz[y] vertices with the siz[x]-1 non-root vertices placed so far.
        f[x] = f[x] * f[y] % MOD * C(siz[x] + siz[y] - 1, siz[y]) % MOD;
        siz[x] += siz[y];
        if (!son[x] || siz[son[x]] < siz[y]) son[x] = y;
    }
}

// Pre-order re-rooting pass: derives g[y] for each child y from g[x]
// (g[1] must be set to 1 by the caller).
static void dfs2(int x, int parent) {
    for (int y : to[x]) {
        if (y == parent) continue;
        // f[x] with child y's factor and interleaving divided back out.
        ll without_y = f[x] * power(f[y] * C(siz[x] - 1, siz[y]) % MOD) % MOD;
        // Attach the rest of x's subtree and everything above x, rooted at x.
        g[y] = g[x] * without_y % MOD * C(n - siz[y] - 1, n - siz[x]) % MOD;
        dfs2(y, x);
    }
}

// Appends to t[] every vertex of subtree(x) whose colour does not already
// appear on the path from x down to it (the topmost vertex of each colour).
static void update(int x, int parent) {
    bool first = !vis[c[x]];
    if (first) {
        vis[c[x]] = true;
        t[c[x]].push_back(x);
    }
    for (int y : to[x])
        if (y != parent) update(y, x);
    if (first) vis[c[x]] = false;
}

// Empties t[col] for every colour occurring in subtree(x).
static void clear_subtree(int x, int parent) {
    t[c[x]].clear();
    for (int y : to[x])
        if (y != parent) clear_subtree(y, x);
}

// Folds the topmost colour-`col` vertices listed in t[col] into the running
// totals. Each one cuts off its whole subtree, whose count f[] is interleaved.
static void fold_tops(int col, ll &removed, ll &cnt_ways) {
    for (int y : t[col]) {
        removed += siz[y];
        cnt_ways = cnt_ways * f[y] % MOD * C(removed, siz[y]) % MOD;
    }
}

// Records a candidate for colour `col` with `removed` vertices outside the
// component. Keeps the minimum and sums the counts of ties, except that ties
// at removed == n (an empty component) are not accumulated.
static void consider(int col, ll removed, ll cnt_ways) {
    if (best[col] > removed) {
        best[col] = removed;
        ways[col] = cnt_ways;
    } else if (best[col] == removed && removed != n) {
        add(ways[col], cnt_ways);
    }
}

// DSU-on-tree over subtree(x): light children are solved and cleared, the
// heavy child's lists are kept, then the light children are re-added.
// On return t[col] lists the topmost colour-col vertices of subtree(x).
// If parent has colour i, it evaluates the colour-i-free component whose top
// vertex is x.
static void solve(int x, int parent) {
    for (int y : to[x]) {
        if (y != parent && y != son[x]) {
            solve(y, x);
            clear_subtree(y, x);
        }
    }
    if (son[x]) solve(son[x], x);
    for (int y : to[x])
        if (y != parent && y != son[x]) update(y, x);
    t[c[x]].clear();
    t[c[x]].push_back(x);

    // The root has no parent colour (previously c[0] was read and slot 0 used
    // as scratch). Components containing the root are handled in main().
    if (!parent) return;
    int col = c[parent];
    ll removed = 0, cnt_ways = 1;
    fold_tops(col, removed, cnt_ways);
    removed += n - siz[x]; // everything outside subtree(x) is cut off too
    cnt_ways = cnt_ways * g[x] % MOD * C(removed, n - siz[x]) % MOD;
    consider(col, removed, cnt_ways);
}

// k == 1 case: prints the sum over all start vertices v of the ordering count
// of the tree re-rooted at v.
static void solve_single_colour(void) {
    ll total = 0;
    for (int v = 1; v <= n; v++) add(total, f[v] * g[v] % MOD * C(n - 1, n - siz[v]) % MOD);
    printf("%lld\n", total);
}

int main(void) {
    read(n);
    read(k);
    for (int v = 1; v <= n; v++) read(c[v]);
    // Colours run up to k, which may exceed n. Previously only slots 1..n were
    // initialised, so a colour > n kept best == 0 and printed 0.
    // NOTE(review): assumes n, k and every colour are < MAXN.
    int lim = n > k ? n : k;
    for (int i = 1; i <= lim; i++) {
        best[i] = INF;
        ways[i] = 1;
    }
    for (int e = 1; e < n; e++) {
        int x, y;
        read(x);
        read(y);
        to[x].push_back(y);
        to[y].push_back(x);
    }

    fact[0] = 1;
    for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;
    inv_fact[n] = power(fact[n]);
    for (int i = n; i >= 1; i--) inv_fact[i - 1] = inv_fact[i] * i % MOD;

    dfs(1, 0);
    g[1] = 1;
    dfs2(1, 0);

    if (k == 1) {
        solve_single_colour();
        return 0;
    }

    solve(1, 0);
    // Components containing the root (only possible when the root is not colour i).
    for (int i = 1; i <= k; i++) {
        if (i == c[1]) continue;
        ll removed = 0, cnt_ways = 1;
        fold_tops(i, removed, cnt_ways);
        consider(i, removed, cnt_ways);
    }
    for (int i = 1; i <= k; i++) printf("%lld\n", ways[i]);
    return 0;
}