P4007 小 Y 和恐怖的奴隶主

小 Y 和恐怖的奴隶主

Description

https://www.luogu.com.cn/problem/P4007

Solution

看到数据范围想到矩阵乘法。考虑递推,知道状态的概率,再维护期望。

但是很多状态是重复的。例如 $m=3$ 只需要维护 $(a,b,c)$ 表示血量为 $1,2,3$ 的分别有 $a,b,c$ 个,其中 $a+b+c\le K$。搜出来大概不到两百个状态,转移矩阵直接暴力搞出来。设状态数为 $w$。

然后查询的时候预处理出 $2^i$ 的矩阵,矩阵乘向量是 $O(w^2)$ 的,总的复杂度就是 $O(w^3\log n+Tw^2\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<bits/stdc++.h>
#define int long long
#define ull unsigned long long
#define maxn 205
const int mod=998244353;
int base,m,K;
int id[5005],cnt;
int s[5],tot,inv[15],p[maxn];
struct Mat {
int a[maxn][maxn];
void clear(void) {
memset(a,0,sizeof(a));
for (int i=1;i<=base;i++) a[i][i]=1;
}
Mat operator *(const Mat &x) const {
int i,j,k;Mat ans;memset(ans.a,0,sizeof(ans.a));
for (k=1;k<=base;k++) {
for (i=1;i<=base;i++) {
for (j=1;j<=base;j++)
ans.a[i][j]=(ans.a[i][j]+a[i][k]*x.a[k][j])%mod;
}
}
return ans;
}
void print(void) {
printf("size = %d\n",base);
int i,j;
for (i=1;i<=base;i++,put()) for (j=1;j<=base;j++) printf("%d ",a[i][j]);put();
}
}g,pw[61],ans;
int calc(int stat) {
return (stat%10)+(stat/10%10)+(stat/100%10);
}
int cs(int *s) {
return s[3]*100+s[2]*10+s[1];
}
void dfs(int now,int sum) {
if (now==m+1) {
int stat=0,i;
for (i=1;i<=m;i++) stat=stat*10+s[i];
if (!id[stat]) id[stat]=++cnt,p[cnt]=stat;
return ;
}
for (int i=0;i<=K-sum;i++) s[now]=i,dfs(now+1,sum+i);
}
void add(int &x,int y) {x=(x+y)%mod;}
signed main(void){
int T,stat,i,j,l,n;
read(T);read(m);read(K);
dfs(1,0);
for (inv[0]=1,i=1;i<=10;i++) inv[i]=power(i,mod-2);
for (i=1;i<=cnt;i++) {
int tmp=p[i];
s[3]=tmp/100%10,s[2]=tmp/10%10,s[1]=tmp%10;
int nums=s[1]+s[2]+s[3];
for (j=1;j<=m;j++) if (s[j]>0) {
s[j]--;s[j-1]++;if (nums<K&&j>=2) s[m]++;
stat=cs(s);
s[j]++;s[j-1]--;if (nums<K&&j>=2) s[m]--;
g.a[i][id[stat]]+=(s[j])*power(nums+1)%mod;
}
stat=cs(s);
int pus=power(nums+1);
g.a[i][id[stat]]+=power(nums+1)%mod;
add(g.a[i][cnt+1],power(nums+1));
}
g.a[cnt+1][cnt+1]=1;

base=cnt+1;
int now=100;
if (m==1) now=1;else if (m==2) now=10;

pw[0]=g;
for (i=1;i<=60;i++) pw[i]=pw[i-1]*pw[i-1];
while (T--) {
read(n);
memset(ans.a,0,sizeof(ans.a));
ans.a[1][id[now]]=1;
for (l=60;l>=0;l--) if ((n>>l)&1) {
for (i=1;i<=base;i++) g.a[1][i]=0;
for (i=1;i<=base;i++)
for (j=1;j<=base;j++)
add(g.a[1][j],ans.a[1][i]*pw[l].a[i][j]);
for (i=1;i<=base;i++) ans.a[1][i]=g.a[1][i];
}
printf("%lld\n",ans.a[1][base]);
}
return 0;
}