CF1641D Two Arrays

Description

https://www.luogu.com.cn/problem/CF1641D

Solution

有一个根号分治的 bitset 做法,感觉有点丑陋,就不说了。

考虑怎么优雅的判两个集合的交集。

考虑容斥。取两个集合的所有子集,如果相同且子集大小奇数就 $+1$,偶数就 $-1$。这样如果集合有交集则为 $1$,否则为 $0$。

记两个集合的交集为 $s$,则贡献是 $\sum (-1)^{i+1}\dbinom{s}{i}$,由二项式定理可以推出来。

有了这个东西就可以 $O(2^mm)$ 做一个集合的前缀有多少个无交。

然后先将 $w$ 排个序,在对于每个右指针,找到最近的左指针使得无交。然后更新答案。

观察到如果要对答案有贡献,右指针向右单调时,左指针也必须向左边单调。所以两个指针的移动是 $O(n)$ 的。

最后复杂度是 $O(n2^mm)$,判相同这里用的哈希表 gp_hash_table,感觉远快于 unordered_map

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define int unsigned long long
#define ull unsigned long long
#define maxn 100005
#define put() putchar('\n')
#define LL __int128
#define Tp template<typename T>
#define Ts template<typename T,typename... Ar>
using namespace std;
Tp void read(T &x){
int f=1;x=0;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
x*=f;
}
namespace Debug{
Tp void _debug(char* f,T t){cerr<<f<<'='<<t<<endl;}
Ts void _debug(char* f,T x,Ar... y){while(*f!=',') cerr<<*f++;cerr<<'='<<x<<",";_debug(f+1,y...);}
#define gdb(...) _debug((char*)#__VA_ARGS__,__VA_ARGS__)
}using namespace Debug;
#define fi first
#define se second
#define mk make_pair
int n,m;
const int inf=2e9,base=1e9+7,base2=998244353;
int a[maxn][6],w[maxn],id[maxn];

int ans=2e9+5;
__gnu_pbds::gp_hash_table<ull, int> mp;
// map<ull,int>mp;
bool cmp(int x,int y) {return w[x]<w[y];}
void add(int id,int flag) {
int i,j;
for (i=1;i<(1<<m);i++) {
int cnt=0;
ull res=0,res2=0;
for (j=0;j<m;j++) if ((i>>j)&1) {
res=(1ll*res*base+a[id][j]);
res2=(1ll*res*base2+a[id][j]);
cnt++;
}
if ((flag==1)==(cnt&1)) mp[res*res2]+=1;
else mp[res*res2]-=1;
}
}
int query(int id) {
int i,j,ans=0;
for (i=1;i<(1<<m);i++) {
int cnt=0;
ull res=0,res2=0;
for (j=0;j<m;j++) if ((i>>j)&1) {
res=(1ll*res*base+a[id][j]);
res2=(1ll*res*base2+a[id][j]);
cnt++;
}
ans=ans+mp[res*res2];
}
return ans;
}
signed main(void){
int i,j;
read(n);read(m);
for (i=1;i<=n;i++) {
for (j=0;j<m;j++) read(a[i][j]);
sort(a[i],a[i]+m);
read(w[i]);id[i]=i;
}
sort(id+1,id+1+n,cmp);
for (i=1;i<n;i++) assert(w[id[i]]<=w[id[i+1]]);
int r=1;
while (r<=n) {
if (query(id[r])==r-1) add(id[r],1),r++;
else break;
}
if (r==n+1) return puts("-1"),0;
int l=r-1;
for (;r<=n;r++) {
if (!l) break;
if (query(id[r])==l) continue;
while (l&&query(id[r])<l) add(id[l],-1),l--;
ans=min(ans,w[id[r]]+w[id[l+1]]);
}
printf("%llu\n",ans);
return 0;
}
//i=begin && g++ $i.cpp -o $i -std=c++14 && ./$i