Aho–Corasick (AC) automaton + DP

Preface

After working through a number of AC automaton problems, I found that I particularly like the AC automaton + DP combination. The DP state is usually $dp[i][j]$: the string built so far has length $i$, and its walk through the automaton currently ends at node $j$. Here is a summary.

P4052 [JSOI2007] Text Generator

Problem link: [JSOI2007] Text Generator
Problem summary: given $n$ words, count the texts of length $m$ (over the 26 uppercase letters) that contain at least one of the words, modulo 10007.
Constraints: $1 \le n \le 60$, $1 \le m \le 100$, $1 \le |s_i| \le 100$.
Solution: AC automaton + DP. I bookmarked this problem last year, back when I was a beginner; it no longer seems hard, but I never actually wrote it because I was too weak at the time. Let's analyze it. First, it is natural to use complementary counting: $ans = \#\{\text{texts containing at least one } s_i\} = 26^m - \#\{\text{texts containing no } s_i\}$. The problem now becomes counting the strings of length $m$ that contain none of the $s_i$. Consider a DP where $dp[i][j]$ is the number of such strings of length $i$ whose walk through the automaton ends at node $j$. The DP transition must also handle a second issue: some nodes may not be entered. Because the string must not contain any $s_i$, every node at which some $s_i$ ends, either directly or through its chain of fail links, is forbidden. We precompute this mark in the `build` function. The rest is a simple recurrence; see the code.
AC Code:

#include<bits/stdc++.h>

#define ld long double
#define ll long long
using namespace std;
// Fast signed-integer reader from stdin.
// Skips every character before the first digit; a '-' seen while skipping
// makes the result negative. Stops at the first non-digit after the number
// (that character is consumed). If input ends before any digit, x is set to 0
// instead of looping forever.
template<class T>
void read(T& x)
{
	T res = 0, f = 1;
	// getchar() returns an int: keeping it in an int lets us recognise EOF and
	// guarantees isdigit() only sees EOF or an unsigned-char value (passing a
	// negative char is undefined behaviour).
	int c = getchar();
	while (c != EOF && !isdigit(c)) {
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c)) {
		res = (res << 3) + (res << 1) + (c - '0'); // res * 10 + digit
		c = getchar();
	}
	x = res * f;
}
// From here on every `int` below is really `long long` (contest shortcut).
#define int long long
// Array bound for trie nodes: n <= 60 words of length <= 100 need at most 6001 nodes.
const ll N = 10000 + 10;
const int mod = 1e4 + 7;  // the answer is reported modulo 10007
// nxt: trie edges, completed into a full transition table by build();
// tot: highest node id in use; pd[v]: 1 if a word ends at v or along its fail chain;
// fail: failure links; dp[i][j]: number of word-free strings of length i (i <= m <= 100)
// whose automaton walk ends at node j.
int nxt[N][26],tot,pd[N],fail[N],dp[105][N];
int n,m;  // number of words, text length
// Inserts one word (uppercase letters A-Z) into the trie.
// Missing nodes are created with consecutive ids (++tot); the node where the
// word ends is flagged in pd so the DP can refuse to enter it.
void insert(char* s)
{
	int cur = 0;
	for (char* p = s; *p; ++p)
	{
		int& child = nxt[cur][*p - 'A'];
		if (child == 0)
			child = ++tot;
		cur = child;
	}
	pd[cur] = 1;
}
// Breadth-first construction of the Aho-Corasick automaton over A-Z.
// Afterwards fail[v] is v's failure link, every missing edge nxt[v][c] points
// to nxt[fail[v]][c] (a complete transition table), and pd[v] is 1 whenever a
// word ends at v or anywhere along v's fail chain.
void build()
{
	queue<int> bfs;
	for (int c = 0; c < 26; ++c)
	{
		int child = nxt[0][c];
		if (child)
		{
			fail[child] = 0;  // depth-1 nodes fail to the root
			bfs.push(child);
		}
	}
	while (!bfs.empty())
	{
		int u = bfs.front();
		bfs.pop();
		int link = fail[u];
		for (int c = 0; c < 26; ++c)
		{
			int v = nxt[u][c];
			if (!v)
			{
				// Missing edge: borrow the transition of the fail node.
				nxt[u][c] = nxt[link][c];
				continue;
			}
			fail[v] = nxt[link][c];
			// fail[v] is shallower than v, so its pd flag is already final.
			pd[v] |= pd[fail[v]];
			bfs.push(v);
		}
	}
}
// Buffer for one input word, filled by scanf in main.
char s[N];
// Computes x^y modulo `mod` by square-and-multiply.
// The base is not reduced first; this assumes x*x fits in a long long
// (main only calls it with x = 26) and y >= 0.
int fpow(int x, int y)
{
	int result = 1;
	for (int base = x, e = y; e != 0; e >>= 1)
	{
		if (e & 1)
			result = result * base % mod;
		base = base * base % mod;
	}
	return result;
}
signed main()
{
	//ios::sync_with_stdio(false);
#ifndef ONLINE_JUDGE
	freopen("test.in", "r", stdin);
#endif // ONLINE_JUDGE
	read(n), read(m);
	for (int i = 1; i <= n; i++)
	{
		scanf("%s", s); insert(s);
	}
	build();
	dp[0][0] = 1;//dp[i][j] indicates that it has been long I and matched to node j
	for (int i = 1; i <= m; i++)
	{
		for (int j = 0; j <= tot; j++)
		{
			for (int nx = 0; nx < 26; nx++)
				if (!pd[nxt[j][nx]])
					dp[i][nxt[j][nx]] = (dp[i][nxt[j][nx]] + dp[i - 1][j]) % mod;
		}
	}
	int ans = fpow(26, m);
	for (int j = 0; j <= tot; j++)ans = (ans - dp[m][j]) % mod;
	printf("%lld\n", (ans % mod + mod) % mod);
	return 0;
}

P3041 [USACO12JAN]Video Game G

Problem link: [USACO12JAN] Video Game G
Problem summary: given $n$ words over the letters A, B, C, find the maximum total number of word occurrences (overlapping occurrences included) that a text of length $k$ can contain.
Constraints: $1 \le n \le 20$, $1 \le k \le 10^3$, $1 \le |s_i| \le 15$.
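Solution: the approach is the same as in the previous problem, with two differences: there are no forbidden nodes to handle, and the DP takes a maximum instead of counting. Let $dp[i][j]$ be the maximum number of matches in a text of length $i$ whose walk ends at node $j$. The transition is $dp[i][nxt[j][c]] = \max(dp[i][nxt[j][c]],\ dp[i-1][j] + val[nxt[j][c]])$, where $val[v]$ is the number of words ending at $v$ or along its fail chain. In fact, the DPs of the two problems differ only in how the $dp$ array is initialized and in the state-transition equation.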
AC Code:

#include<bits/stdc++.h>

#define ld long double
#define ll long long
using namespace std;
// Fast signed-integer reader from stdin.
// Skips every character before the first digit; a '-' seen while skipping
// makes the result negative. Stops at the first non-digit after the number
// (that character is consumed). If input ends before any digit, x is set to 0
// instead of looping forever.
template<class T>
void read(T& x)
{
	T res = 0, f = 1;
	// getchar() returns an int: keeping it in an int lets us recognise EOF and
	// guarantees isdigit() only sees EOF or an unsigned-char value (passing a
	// negative char is undefined behaviour).
	int c = getchar();
	while (c != EOF && !isdigit(c)) {
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c)) {
		res = (res << 3) + (res << 1) + (c - '0'); // res * 10 + digit
		c = getchar();
	}
	x = res * f;
}
// From here on every `int` below is really `long long` (contest shortcut).
#define int long long
// Array bound for trie nodes and the word buffer; n <= 20 words of length <= 15
// need at most 301 nodes.
const ll N = 1000 + 10;
const int mod = 1e9 + 7;  // NOTE(review): not used by this program
int n, k;  // number of words, text length
// nxt: trie over the letters A..C, completed into a full transition table by build();
// fail: failure links; val[v]: number of words ending at v (after build(), plus those
// ending along its fail chain); tot: highest node id in use.
int nxt[N][3], fail[N], val[N],tot;
void insert(char* s)
{
	int now = 0;
	int n = strlen(s);
	for (int i = 0; i < n; i++)
	{
		if (!nxt[now][s[i] - 'A'])nxt[now][s[i] - 'A'] = ++tot;
		now = nxt[now][s[i] - 'A'];
	}
	val[now]++;
}
// Breadth-first construction of the Aho-Corasick automaton over A-C.
// Sets fail links, fills every missing edge with the fail node's transition,
// and accumulates val along fail links, so val[v] counts every word that ends
// at v or at a suffix of v's string.
void build()
{
	queue<int> bfs;
	for (int c = 0; c < 3; ++c)
		if (nxt[0][c])
			bfs.push(nxt[0][c]);
	while (!bfs.empty())
	{
		const int u = bfs.front();
		bfs.pop();
		for (int c = 0; c < 3; ++c)
		{
			const int v = nxt[u][c];
			const int viaFail = nxt[fail[u]][c];
			if (v)
			{
				fail[v] = viaFail;
				val[v] += val[viaFail];  // viaFail is shallower, already final
				bfs.push(v);
			}
			else
			{
				nxt[u][c] = viaFail;
			}
		}
	}
}
// Buffer for one input word, filled by scanf in main.
char s[N];
// dp[i][j]: best number of word occurrences over texts of length i (i <= k <= 1000)
// whose automaton walk ends at node j; main fills it with a large negative sentinel
// so unreachable states never win a max.
int dp[1005][N];
signed main()
{
	//ios::sync_with_stdio(false);
#ifndef ONLINE_JUDGE
	freopen("test.in", "r", stdin);
#endif // ONLINE_JUDGE
	read(n), read(k);
	for (int i = 1; i <= n; i++)
	{
		scanf("%s", s); insert(s);
	}
	build();
	memset(dp, -0x3f, sizeof(dp));
	dp[0][0] = 0;
	for (int i = 1; i <= k; i++)
	{
		for (int j = 0; j <= tot; j++)
		{
			for (int nx = 0; nx < 3; nx++)
			{
				dp[i][nxt[j][nx]] =max(dp[i][nxt[j][nx]], dp[i - 1][j] + val[nxt[j][nx]]);
			}
		}
	}
	int ans = 0;
	for (int j = 0; j <= tot; j++)ans = max(ans, dp[k][j]);
	printf("%lld\n", ans);
	return 0;
}

P3311 [SDOI2014] Count

Problem link: [SDOI2014] Count
Problem summary: given $m$ digit strings, count the positive integers not exceeding $n$ whose decimal representation contains none of the given strings as a substring, modulo $10^9+7$.
Constraints: $1 \le n < 10^{1201}$, $1 \le m \le 100$, $1 \le \sum_{i=1}^{m} |s_i| \le 1500$.
Solution: AC automaton + digit DP. A very good problem with a lot of details. The main thing to explain is the `dfs` function:
- `cur` is the number of digits still to be placed.
- `limit` indicates whether the digits chosen so far equal the corresponding prefix of $n$, so the next digit is bounded by $n$'s digit.
- `pos` is the current node in the automaton.
- `_0` indicates whether we are still inside the leading zeros.

Because `dfs` also counts the all-zero number 0, the final answer subtracts 1. See the code for the remaining details.
AC Code:

#include<bits/stdc++.h>
#define ld long double
#define ll long long
using namespace std;
// Fast signed-integer reader from stdin.
// Skips every character before the first digit; a '-' seen while skipping
// makes the result negative. Stops at the first non-digit after the number
// (that character is consumed). If input ends before any digit, x is set to 0
// instead of looping forever.
template<class T>
void read(T& x)
{
	T res = 0, f = 1;
	// getchar() returns an int: keeping it in an int lets us recognise EOF and
	// guarantees isdigit() only sees EOF or an unsigned-char value (passing a
	// negative char is undefined behaviour).
	int c = getchar();
	while (c != EOF && !isdigit(c)) {
		if (c == '-') f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c)) {
		res = (res << 3) + (res << 1) + (c - '0'); // res * 10 + digit
		c = getchar();
	}
	x = res * f;
}
// From here on every `int` below is really `long long` (contest shortcut).
#define int long long
// Bound for trie nodes and the input buffer s; sum |s_i| <= 1500 needs at most
// 1501 nodes and n has at most 1201 digits, so this is generous.
const ll N = 200000 + 10;
const int mod = 1e9 + 7;  // the count is reported modulo 1e9+7
// nxt: trie over digits 0-9, completed into a full transition table by build();
// fail: failure links; tot: highest node id in use; pd[v]: 1 if a forbidden string
// ends at v or anywhere along its fail chain.
int nxt[N][10], fail[N],tot, pd[N];
// Inserts one forbidden digit string into the trie and flags its end node.
void insert(char* s)
{
	int at = 0;
	const int len = strlen(s);
	for (int idx = 0; idx < len; ++idx)
	{
		const int digit = s[idx] - '0';
		if (nxt[at][digit] == 0)
			nxt[at][digit] = ++tot;
		at = nxt[at][digit];
	}
	pd[at] = 1;
}
void build()
{
	queue<int>pls;
	for (int i = 0; i < 10; i++)if (nxt[0][i])pls.push(nxt[0][i]);
	while (pls.size())
	{
		int f = pls.front(); pls.pop();
		for (int i = 0; i < 10; i++)
		{
			if (nxt[f][i])
			{
				fail[nxt[f][i]] = nxt[fail[f]][i];
				pd[nxt[f][i]] |= pd[nxt[fail[f]][i]];//Mark it to indicate that it cannot be reached
				pls.push(nxt[f][i]);
			}
			else
				nxt[f][i] = nxt[fail[f]][i];
		}
	}
}
// dp[cur][pos]: memoised dfs result for `cur` digits left at automaton node `pos`,
// stored only for states with no upper-bound limit and no leading zeros; main fills
// it with -1 meaning "not computed". w[0] = number of digits of n, w[1..w[0]] = the
// digits of n, least significant first. m = number of forbidden strings.
int dp[1205][1505],w[1205],m;
int dfs(int cur, int limit, int pos,int _0)
{
	if (!cur)return 1;
	if (!limit&&!_0 && ~dp[cur][pos])return dp[cur][pos];
	int tp = limit ? w[cur]:9;
	int ans = 0;
	for (int i = 0; i <= tp; i++)
	{
		if (!pd[nxt[pos][i]]||(i==0&&_0))//Can go
		{
			ans = (ans + dfs(cur - 1, limit && i == tp, (i == 0 && _0) ? 0 : nxt[pos][i], (i == 0 && _0))) % mod;
		}
	}

	if (!limit && !_0)dp[cur][pos] = ans;
	return (ans % mod + mod)%mod;
}
// Input buffer, used 1-indexed (s + 1): first for the digits of n, then for each pattern.
char s[N];
signed main()
{
	//ios::sync_with_stdio(false);
#ifndef ONLINE_JUDGE
	freopen("test.in", "r", stdin);
#endif // ONLINE_JUDGE
	memset(dp, -1, sizeof(dp));
	scanf("%s", s + 1);
	w[0] = strlen(s + 1);
	reverse(s + 1, s + 1 + w[0]);
	for (int i = 1; i <= w[0]; i++)w[i] = s[i] - '0';
	read(m);
	
	for (int i = 1; i <= m; i++)
	{
		scanf("%s", s + 1);
		insert(s + 1);
	}
	build();

	printf("%lld\n",dfs(w[0], 1, 0, 1)-1);
	return 0;
}

Posted by robogenus on Mon, 18 Apr 2022 15:28:43 +0930