HDU 2243. 考研路茫茫——单词情结

链接

http://acm.hdu.edu.cn/showproblem.php?pid=2243

题意

字符集为小写字母,有 $n$ $(n<6)$ 个模式串,长度不超过 $5$,求长度不超过 $m$ $(m<2^{31})$ 至少包含一个模式串的字符串个数。

思路

建立 AC 自动机,判断每个节点的后缀是否是模式串。

先求长度不超过 $m$ 不包含模式串的字符串个数,记 $f_i$ 为长度不超过 $i$ 的字符串个数:

$f_i=f_{i-1}*26+26
\Longrightarrow \begin{bmatrix}f_i & 1\end{bmatrix}=\begin{bmatrix}f_{i-1} & 1\end{bmatrix}*\begin{bmatrix}26 & 0 \ 26 & 1\end{bmatrix}
\Longrightarrow f_i=(\begin{bmatrix}26 & 1\end{bmatrix}\begin{bmatrix}26 & 0 \ 26 & 1\end{bmatrix}^{i-1})_{0,0}$

记 $dp_{i,j}$ 为在 Trie 上从根节点出发走了不超过 $i$ 步到达节点 $j$ 且没有经过模式串的方案数,等价于长度不超过 $i$ 以节点 $j$ 代表的字符串为后缀且没有经过模式串的字符串个数。

记 $a_{i,j}$ 为从 Trie 上节点 $i$ 走一步到达节点 $j$ 的方案数,Trie 除根节点外有 $k$ 个节点:

$\begin{aligned}
& dp_{i,j}=dp_{i-1,0}*a_{0,j}+dp_{i-1,1}*a_{1,j}+\cdots + dp_{i-1,k}*a_{k,j}+a_{0,j} \
\Longrightarrow & \begin{bmatrix}dp_{i,0} & dp_{i,1} & \cdots & dp_{i,k} & 1\end{bmatrix}
=\begin{bmatrix}dp_{i,0} & dp_{i,1} & \cdots & dp_{i,k} & 1\end{bmatrix}
*\begin{bmatrix}a_{0,0} & a_{0,1} & \cdots & a_{0,k} & 0 \
a_{1,0} & a_{1,1} & \cdots & a_{1,k} & 0 \
\vdots & \vdots & \ddots & \vdots & \vdots \
a_{k,0} & a_{k,1} & \cdots & a_{k,k} & 0 \
a_{0,0} & a_{0,1} & \cdots & a_{0,k} & 1
\end{bmatrix} \
\Longrightarrow & \begin{bmatrix}dp_{i,0} & dp_{i,1} & \cdots & dp_{i,k} & 1\end{bmatrix}
= \begin{bmatrix}a_{0,0} & a_{0,1} & \cdots & a_{0,k} & 1\end{bmatrix}
*\begin{bmatrix}a_{0,0} & a_{0,1} & \cdots & a_{0,k} & 0 \
a_{1,0} & a_{1,1} & \cdots & a_{1,k} & 0 \
\vdots & \vdots & \ddots & \vdots & \vdots \
a_{k,0} & a_{k,1} & \cdots & a_{k,k} & 0 \
a_{0,0} & a_{0,1} & \cdots & a_{0,k} & 1
\end{bmatrix}^{i-1}
\end{aligned}$

最终答案为 $\sum_{i=0}^{k}dp_{m,i}$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <bits/stdc++.h>
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(),(x).end()
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
using namespace std;
typedef double DB;
typedef long double LD;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef vector<PII> VPII;
// head
const int N=35;
int tr[N][26],st[N],fail[N],cnt;
struct M {
ULL a[N][N];
void clear() {memset(a,0,sizeof a);}
M() {clear();}
M operator * (const M &T) const {
M ret;
for(int i=0;i<=cnt;i++)
for(int k=0;k<=cnt;k++) {
ULL t=a[i][k];
if(!t) continue;
for(int j=0;j<=cnt;j++) {
if(!T.a[k][j]) continue;
ret.a[i][j]+=T.a[k][j]*t;
}
}
return ret;
}
M operator ^ (ULL b) const {
M ret,bas;
for(int i=0;i<=cnt;i++) ret.a[i][i]=1;
for(int i=0;i<=cnt;i++)
for(int j=0;j<=cnt;j++)
bas.a[i][j]=a[i][j];
while(b) {
if(b&1) ret=ret*bas;
bas=bas*bas;
b>>=1;
}
return ret;
}
}a;
void init() {
for(int i=0;i<=cnt;i++) for(int j=0;j<26;j++) tr[i][j]=0;
for(int i=0;i<=cnt;i++) st[i]=fail[i]=0;
cnt=0;
}
void insert(const string &s) {
int u=0;
for(auto c:s) {
int v=c-'a';
if(!tr[u][v]) tr[u][v]=++cnt;
u=tr[u][v];
}
st[u]=1;
}
void build() {
queue<int> q;
for(int i=0;i<26;i++) if(tr[0][i]) q.push(tr[0][i]);
while(SZ(q)) {
int u=q.front();
q.pop();
for(int v=0;v<26;v++) {
if(tr[u][v]) fail[tr[u][v]]=tr[fail[u]][v],q.push(tr[u][v]);
else tr[u][v]=tr[fail[u]][v];
}
st[u]|=st[fail[u]];
}
}
void get_matrix() {
a.clear();
for(int i=0;i<=cnt;i++) {
if(st[i]) continue;
for(int j=0;j<26;j++) {
if(st[tr[i][j]]) continue;
++a.a[i][tr[i][j]];
}
}
++cnt;
for(int i=0;i<cnt;i++) a.a[cnt][i]=a.a[0][i];
a.a[cnt][cnt]=1;
}
ULL qpow(ULL a,ULL b) {
ULL ret=1;
while(b) {
if(b&1) ret*=a;
a*=a;
b>>=1;
}
return ret;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
ULL m;
M t;
t.a[0][0]=26,t.a[0][1]=0;
t.a[1][0]=26,t.a[1][1]=1;
while(cin>>n>>m) {
init();
for(int i=1;i<=n;i++) {
string s;
cin>>s;
insert(s);
}
build();
M tt;
tt.a[0][0]=26,tt.a[0][1]=1;
tt=tt*(t^(m-1));
ULL res=tt.a[0][0];
get_matrix();
M b;
for(int i=0;i<cnt;i++) b.a[0][i]=a.a[0][i];
b.a[0][cnt]=1;
b=b*(a^(m-1));
for(int i=0;i<cnt;i++) res-=b.a[0][i];
cout<<res<<'\n';
}
return 0;
}