Skip to content

AC自动机#

目录#


1. 模板#

#include <iostream>
#include <queue>
#include <string.h>
using namespace std;
const int maxN = 5e5 + 1; // 模式串长度 * 模式串数量
const int maxM = 1e6 + 1; // 目标串长度
const int chSize = 26; // 字符集大小

class AcAutomaton {
public:
    int trie[maxN][chSize];
    int vis[maxN], fail[maxN];
    int tot;
    // 初始化
    void init()
    {
        memset(vis, 0, sizeof vis);
        memset(trie, 0, sizeof trie);
        tot = 0;
    }
    // 插入
    void insert(char* str)
    {
        int len = strlen(str);
        int pos = 0;
        for (int i = 0; i < len; i++) {
            int c = str[i] - 'a';
            if (!trie[pos][c])
                trie[pos][c] = ++tot;
            pos = trie[pos][c];
        }
        vis[pos]++;
    }
    // DFS获取Fail数组
    void build()
    {
        queue<int> q;
        // 根结点下的元素入队
        for (int i = 0; i < chSize; i++) {
            if (trie[0][i]) {
                fail[trie[0][i]] = 0;
                q.push(trie[0][i]);
            }
        }
        // DFS
        while (!q.empty()) {
            // 获取队首元素
            int pos = q.front();
            // 输出队首元素
            // cout << "front: " << pos << endl;
            // 队首出队
            q.pop();
            // 遍历队首元素的子节点
            for (int i = 0; i < chSize; i++) {
                // 若元素ch存在
                if (trie[pos][i]) {
                    // 将fail数组的值置为fail[队首元素]的下一个ch对应的值
                    fail[trie[pos][i]] = trie[fail[pos]][i];
                    // 输出构建fail数组的过程
                    // cout << "True char: " << (char)('a' + i) << " fail: " << trie[fail[pos]][i] << endl;
                    // 元素入队
                    q.push(trie[pos][i]);
                } else {
                    // 若不存在,将trie的值置为fail[队首元素]的下一个ch对应的值
                    // 方便下次计算
                    trie[pos][i] = trie[fail[pos]][i];
                }
            }
        }
    }
    // 查询
    int query(char* str)
    {
        int len = strlen(str);
        int pos = 0, ans = 0;
        for (int i = 0; i < len; i++) {
            int c = str[i] - 'a';
            pos = trie[pos][c];
            for (int j = pos; j && vis[j] != -1; j = fail[j]) {
                ans += vis[j];
                vis[j] = -1;
            }
        }
        return ans;
    }
};

2. 题目#

AC自动机类题目博客专栏

Back to top