P6689 序列

Description

https://www.luogu.com.cn/problem/P6689

Solution

首先观察到如果一个序列的左括号相同,则序列出现的概率相同。

记 $f(i,j)$ 表示已经减了 $i$ 个参数,当前右括号的个数为 $j$。枚举从上一个参数右括号变成左括号的个数,有转移:
$$
f(i,j)=\dfrac{n-j+1}{n}\sum _{k\ge j-1} f(i-1,j-k)\prod _{l=j}^k \dfrac{l}{n}
$$
前缀和优化一下即可做到 $O(NK)$。

考虑快速计算给定序列的最长合法括号子序列。记左括号表示 $1$,右括号表示 $-1$,$s_i$ 表示前 $i$ 个位置的前缀和。记 $x$ 为 $s_i$ 最小的位置,则 $x$ 之前有 $-s_x$ 个右括号不能匹配,之后有 $s_n-s_x$ 个不能匹配。所以其最长子序列为 $N-s_n+2s_x$。

考虑枚举 $s_n,s_x$。考虑转化为网格图。记左括号为 $x+1$,右括号为 $y+1$。满足 $x-s_x\ge y$。用经典的反射容斥做即可。注意起点和终点要在斜线一边,否则会算错。

最后复杂度 $O(N^2+NK)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<bits/stdc++.h>
#define ll long long
#define int long long
#define ull unsigned long long
#define maxn 5005
#define put() putchar('\n')
#define Tp template<typename T>
#define Ts template<typename T,typename... Ar>
using namespace std;
Tp void read(T &x){
int f=1;x=0;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
x*=f;
}
namespace Debug{
Tp void _debug(char* f,T t){cerr<<f<<'='<<t<<endl;}
Ts void _debug(char* f,T x,Ar... y){while(*f!=',') cerr<<*f++;cerr<<'='<<x<<",";_debug(f+1,y...);}
#define gdb(...) _debug((char*)#__VA_ARGS__,__VA_ARGS__)
}using namespace Debug;
#define fi first
#define se second
#define mk make_pair
const int mod=998244353;
ll power(ll x,int y=mod-2,int p=mod) {
ll sum=1;x%=p;
while (y) {
if (y&1) sum=sum*x%p;
x=x*x%p;y>>=1;
}
return sum;
}
int n,k;
int suf[maxn],isuf[maxn];
int C(int x,int y) {
if (y<0||x<y) return 0;
return suf[x]*isuf[y]%mod*isuf[x-y]%mod;
}
int ans;
int f[maxn][maxn];
int s[maxn],is[maxn],pre[maxn];
void add(int &x,int y) {x=(x+y)%mod;}
signed main(void){
int i,j;
read(n);read(k);
for (suf[0]=1,i=1;i<=n;i++) suf[i]=suf[i-1]*i%mod;
for (isuf[n]=power(suf[n]),i=n;i>=1;i--) isuf[i-1]=isuf[i]*i%mod;
int invn=power(n);
for (s[0]=1,is[0]=1,i=1;i<=n;i++) s[i]=s[i-1]*i%mod*invn%mod,is[i]=power(s[i]);
f[0][0]=1;
for(i=1;i<=k;i++) {
for (j=n;j>=0;j--) pre[j]=(pre[j+1]+f[i-1][j]*s[j])%mod;
for (j=1;j<=n;j++) f[i][j]=((n-j+1)*invn%mod*is[j-1]%mod*pre[j-1]%mod);
}
int res=0;
for (i=1;i<=n;i++) add(res,f[k][i]);
assert(res==1);
for (i=1;i<=n;i++) {
int X=n-i,Y=i,p=f[k][i]*power(C(n,i))%mod;
for (j=-Y;j<=0;j++) {
if (X-Y>=j) {
int res=C(n,X-(j))-C(n,X-(j-1));
// gdb(i,j,X,Y,res);
// assert(res>=0);
add(ans,res*p%mod*(n-(X-Y)+2*j));
}
}
}
if (ans<0) ans+=mod;
printf("%lld\n",ans);
return 0;
}
//i=begin && g++ $i.cpp -o $i -std=c++14 && ./$i