1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
// Competitive-programming solution (single translation unit).
// Reads n and k from stdin and prints one integer in [0, 998244353).
//
// The code iterates a table f for k steps:
//   f_0[0] = 1, f_0[j] = 0 for j > 0,
//   f_i[0] = 0 and, for 1 <= j <= n,
//   f_i[j] = (n - j + 1)/n * (1 / s[j-1]) * sum_{t >= j-1} f_{i-1}[t] * s[t],
// with s[t] = t! / n^t (mod p). The answer is a weighted sum of
// binomial-coefficient differences over f_k.
// NOTE(review): the combinatorial meaning of f and of the final sum is not
// stated in the visible source; the formulas above only transcribe the code.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

#define ll long long
#define int long long  // every `int` below is 64-bit; hence `signed main`
#define ull unsigned long long
#define put() putchar('\n')
#define Tp template<typename T>
#define Ts template<typename T, typename... Ar>
using namespace std;

// Reads one integer from stdin into x. A '-' seen among the non-digit
// characters preceding the number makes it negative. At EOF the scan stops
// (x is left as 0) instead of looping forever: the character is kept in a
// type wider than char so that EOF can be told apart from real input.
Tp void read(T &x) {
    int sign = 1;
    x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}

namespace Debug {
// gdb(a, b, c) prints "a=<a>, b=<b>, c=<c>" to stderr.
Tp void _debug(const char *f, T t) { cerr << f << '=' << t << endl; }
Ts void _debug(const char *f, T x, Ar... y) {
    while (*f != ',') cerr << *f++;
    cerr << '=' << x << ",";
    _debug(f + 1, y...);
}
#define gdb(...) _debug(#__VA_ARGS__, __VA_ARGS__)
}  // namespace Debug
using namespace Debug;

#define fi first
#define se second
#define mk make_pair

const int mod = 998244353;

// Returns x^y mod p by binary exponentiation. With the default arguments this
// is the modular inverse of x modulo the prime `mod` (Fermat's little theorem).
ll power(ll x, int y = mod - 2, int p = mod) {
    ll sum = 1;
    x %= p;
    while (y) {
        if (y & 1) sum = sum * x % p;
        x = x * x % p;
        y >>= 1;
    }
    return sum;
}

int n, k;
// suf[i] = i! mod p and isuf[i] = (i!)^{-1} mod p for 0 <= i <= n.
// Sized from n at run time, so there is no fixed upper bound on n.
vector<int> suf, isuf;

// Binomial coefficient C(x, y) mod p; 0 when y < 0 or y > x.
// Requires 0 <= x <= n (suf/isuf hold n + 1 entries).
int C(int x, int y) {
    if (y < 0 || x < y) return 0;
    return suf[x] * isuf[y] % mod * isuf[x - y] % mod;
}

int ans;

// x = (x + y) mod p, kept in [0, p) even when y is negative.
void add(int &x, int y) { x = ((x + y) % mod + mod) % mod; }

signed main(void) {
    read(n);
    read(k);

    // Factorials and inverse factorials up to n.
    suf.assign(n + 1, 0);
    isuf.assign(n + 1, 0);
    suf[0] = 1;
    for (int i = 1; i <= n; i++) suf[i] = suf[i - 1] * i % mod;
    isuf[n] = power(suf[n]);
    for (int i = n; i >= 1; i--) isuf[i - 1] = isuf[i] * i % mod;

    const int invn = power(n);
    // s[i] = i! / n^i (mod p) and is[i] = s[i]^{-1} (mod p).
    vector<int> s(n + 1), is(n + 1);
    s[0] = is[0] = 1;
    for (int i = 1; i <= n; i++) {
        s[i] = s[i - 1] * i % mod * invn % mod;
        is[i] = power(s[i]);
    }

    // Row i of f depends only on row i - 1, so two rolling rows replace a
    // (k + 1) x (n + 1) table. The old fixed 5005 x 5005 long long table was
    // about 200 MB and overflowed for k >= 5005.
    vector<int> prv(n + 1, 0), cur(n + 1, 0), pre(n + 2, 0);
    prv[0] = 1;  // f_0
    for (int step = 1; step <= k; step++) {
        // pre[j] = sum_{t >= j} f_{step-1}[t] * s[t]; pre[n + 1] stays 0.
        for (int j = n; j >= 0; j--) pre[j] = (pre[j + 1] + prv[j] * s[j]) % mod;
        cur[0] = 0;  // f_i[0] is 0 for every i >= 1
        for (int j = 1; j <= n; j++)
            cur[j] = (n - j + 1) * invn % mod * is[j - 1] % mod * pre[j - 1] % mod;
        swap(prv, cur);
    }
    const vector<int> &fk = prv;  // f_k (f_0 when k == 0)

    // Sanity check kept from the original: f_k[1..n] must sum to 1 (mod p),
    // presumably because it is a probability distribution.
    int total = 0;
    for (int i = 1; i <= n; i++) add(total, fk[i]);
    assert(total == 1);

    for (int i = 1; i <= n; i++) {
        const int X = n - i, Y = i;
        const int p = fk[i] * power(C(n, i)) % mod;
        // The original loop ran j over [-Y, 0] and skipped every j > X - Y.
        const int hi = min<int>(0, X - Y);
        for (int j = -Y; j <= hi; j++) {
            const int diff = C(n, X - j) - C(n, X - j + 1);  // may be negative
            add(ans, diff * p % mod * (n - (X - Y) + 2 * j));
        }
    }
    printf("%lld\n", ans);  // `add` keeps ans in [0, mod)
    return 0;
}
|