HDU 7055. Yiwen with Sqc

链接

https://acm.hdu.edu.cn/showproblem.php?pid=7055

题意

记 $sqc(s,i,j,c)$ 为字符串 $s$ 区间 $[i,j]$ 内 ASCII 码为 $c$ 的字符的个数。

求 $\sum_{c=97}^{122}\sum_{i=1}^n\sum_{j=i}^nsqc(s,i,j,c)^2$。

思路一

考虑每个字母对答案的贡献 $ans$,记 $s_i$ 为前 $i$ 个字母中当前字母的个数:

$$\begin{aligned}
ans=&\sum_{i=1}^{n}\sum_{j=i}^n(s_j-s_{i-1})^2\
=&\sum_{i=1}^{n}\sum_{j=i}^ns_j^2+\sum_{i=1}^{n}\sum_{j=i}^ns_{i-1}^2-\sum_{i=1}^{n}\sum_{j=i}^n2s_js_{i-1}\
=&\sum_{i=1}^{n}is_i^2+\sum_{i=1}^{n}(n-i)s_{i}^2-(\sum_{i=1}^ns_i)^2+\sum_{i=1}^n{s_i^2}\
=&(n+1)\sum_{i=1}^n{s_i^2}-(\sum_{i=1}^ns_i)^2
\end{aligned}$$

时间复杂度:$O(26n)$

可以发现对于某个字母出线的位置 $pos$,$[pos_i,pos_i+1)$ 之间的下标的 $s$ 是相同的,因此只要记录每个字母出现的位置即可 $O(n)$ 计算答案。

代码一

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <bits/stdc++.h>
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(),(x).end()
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
using namespace std;
typedef double DB;
typedef long double LD;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef vector<PII> VPII;
// head
const LL MOD=998244353;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;cin>>T;
while(T--) {
string s;
cin>>s;
VI pos[26];
for(int i=0;i<SZ(s);i++) {
int x=s[i]-'a';
pos[x].PB(i);
}
LL res=0;
for(int i=0;i<26;i++) if(SZ(pos[i])) {
LL sum1=0,sum2=0;
for(int j=0;j<SZ(pos[i]);j++) {
sum1=(sum1+((j==SZ(pos[i])-1?SZ(s):pos[i][j+1])-pos[i][j])*(j+1)%MOD*(j+1)%MOD)%MOD;
sum2=(sum2+((j==SZ(pos[i])-1?SZ(s):pos[i][j+1])-pos[i][j])*(j+1)%MOD)%MOD;
}
sum1=sum1*(SZ(s)+1)%MOD;
sum2=sum2*sum2%MOD;
res=(res+sum1-sum2+MOD)%MOD;
}
cout<<res<<'\n';
}
return 0;
}

思路二

记录每个字母出现的位置,对于当前要统计答案的字母,记出现了 $cnt$ 次,第 $i$ 次的位置是 $pos_i$。

枚举左端点,左端点在 $(pos_{i},pos_{i+1}]$ 之间的答案相同,记为 $ans_i=\sum_{k=i+1}^{cnt}{len_k*(k-i)^2}$ $(0\le i <cnt)$

记 $dif_i=ans_{i}-ans_{i-1}$:

$$\begin{aligned}
dif_i=&\sum_{k=i+1}^{cnt}{len_k*(k-i)^2}-\sum_{k=i}^{cnt}{len_k*(k-i+1)^2}\
=&\sum_{k=i+1}^{cnt}{len_k*(k-i)^2}-\sum_{k=i+1}^{cnt}{len_k*(k-i+1)^2}-len_i\
=&\sum_{k=i+1}^{cnt}{len_k*(-2k+2i-1)}-len_i\
\end{aligned}$$

记 $diff_i=dif_i-dif_{i-1}$:

$$\begin{aligned}
diff_i=&\sum_{k=i+1}^{cnt}{len_k*(-2k+2i-1)}-\sum_{k=i}^{cnt}{len_k*(-2k+2i-3)}-len_i+len_{i-1}\
=&\sum_{k=i+1}^{cnt}{len_k*(-2k+2i-1)}-\sum_{k=i+1}^{cnt}{len_k*(-2k+2i-3)}+2len_i+len_{i-1}\
=&2\sum_{k=i+1}^{cnt}{len_k}+2len_i+len_{i-1}
\end{aligned}
$$

维护 $diff$,进而维护 $dif$,进而维护 $ans$。

时间复杂度:$O(n)$

代码二

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(),(x).end()
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
using namespace std;
typedef double DB;
typedef long double LD;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef vector<PII> VPII;
// head
const LL MOD=998244353;
const int N=1e5+5;
VI pos[30];
LL len[N],sum[N],diff[N],dif[N],ans[N];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;cin>>T;
while(T--) {
string s;cin>>s;
for(int i=0;i<26;i++) pos[i].clear();
for(int i=0;i<SZ(s);i++) {
int x=s[i]-'a';
pos[x].PB(i);
}
LL res=0;
for(int i=0;i<26;i++) if(SZ(pos[i])) {
for(int j=1;j<SZ(pos[i]);j++) len[j]=pos[i][j]-pos[i][j-1];
len[SZ(pos[i])]=SZ(s)-pos[i][SZ(pos[i])-1];
for(int j=1;j<=SZ(pos[i]);j++) sum[j]=(sum[j-1]+len[j])%MOD;
for(int j=1;j<=SZ(pos[i]);j++) diff[j]=(2*((sum[SZ(pos[i])]-sum[j]+MOD)%MOD)%MOD+2*len[j]%MOD+len[j-1])%MOD;
dif[1]=-len[1];
for(int j=2;j<=SZ(pos[i]);j++) dif[1]=(dif[1]+(len[j]*(-2*j+1)%MOD+MOD)%MOD)%MOD;
for(int j=2;j<=SZ(pos[i]);j++) dif[j]=(diff[j]+dif[j-1])%MOD;
ans[0]=0;
for(int j=1;j<=SZ(pos[i]);j++) ans[0]=(ans[0]+len[j]*j%MOD*j%MOD)%MOD;
for(int j=1;j<SZ(pos[i]);j++) ans[j]=(dif[j]+ans[j-1])%MOD;
len[0]=pos[i][0]+1;
for(int j=0;j<SZ(pos[i]);j++) res=(res+ans[j]*len[j]%MOD)%MOD;
len[0]=0;
}
cout<<res<<'\n';
}
return 0;
}