POJ 2778 DNA Sequence (AC automata + matrix fast power) solution

Keywords: PHP

Problem: given m pattern strings, count the strings of length n (n <= 2000000000) over the alphabet {A, C, G, T} that do not contain any of the patterns as a substring. The count is reported modulo 100000.

Train of thought: because the string must avoid a set of pattern strings, an Aho-Corasick (AC) automaton is the obvious tool. Because n is very large, a plain DP over the length of the string is too slow.

In graph theory, if we know the adjacency matrix $A$ of a graph, where $a_{ij} = 1$ means that there is an edge from node $i$ to node $j$, then the entry in row $i$, column $j$ of $A^n$ is the number of walks of exactly $n$ steps from node $i$ to node $j$ in this graph.

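So with the AC automaton we can build an adjacency matrix $A$ over all of its states (each state is a suffix of the text read so far), keeping only the transitions into states that do not complete a pattern. Then we compute $A^n$ with fast matrix exponentiation, and $\sum_{i=0}^{k-1} (A^n)_{0i}$, where $k$ is the number of states, counts all the ways to walk $n$ steps from the root to any feasible state, which is the answer.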

Code:

#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
// Integer aliases. ull is only used by `seed` below in the visible code.
typedef long long ll;
typedef unsigned long long ull;
// Capacity of the automaton's node array and side length of a Mat.
const int maxn = 100 + 5;
// M, seed and INF are not referenced anywhere else in the visible code.
const int M = 50 + 5;
const ull seed = 131;
const double INF = 1e20;
// Matrix products and the final answer are reduced modulo this value.
const int MOD = 100000;
// m: number of patterns, read by main().
// tn: number of automaton states; set by main() from ac.size and used by mul()
//     as the side length of the part of each matrix that is multiplied.
int m, tn;
// Length of the strings being counted; read by main() and used as the matrix
// exponent in Aho::query().
ll n;
// Square matrix with fixed capacity maxn x maxn. mul() only reads and writes the
// top-left tn x tn corner; the rest is left at whatever the caller stored there.
struct Mat{
    ll s[maxn][maxn];   // s[r][c] is the entry in row r, column c
};
// Returns the product of the top-left tn x tn corners of a and b, with every
// entry reduced modulo MOD. Entries of the result outside that corner are zero.
//
// The operands are only read, so they are taken by const reference; existing
// callers that pass non-const lvalues keep working unchanged.
//
// The loops run in i-k-j order so the innermost loop walks a row of b and a row
// of t contiguously, and rows of a that hold a zero at column k are skipped.
// For each (i, j) the terms are still added in ascending k, so the result is
// the same as with the textbook i-j-k order.
//
// Assumes the entries are small enough that a.s[i][k] * b.s[k][j] fits in a
// long long (true when both operands are already reduced modulo MOD).
Mat mul(const Mat &a, const Mat &b){
    Mat t;
    memset(t.s, 0, sizeof(t.s));
    for(int i = 0; i < tn; i++){
        for(int k = 0; k < tn; k++){
            const ll aik = a.s[i][k];
            if(aik == 0) continue;   // contributes nothing to row i
            for(int j = 0; j < tn; j++){
                t.s[i][j] = (t.s[i][j] + aik * b.s[k][j]) % MOD;
            }
        }
    }
    return t;
}
// Raises matrix a to the power b by binary exponentiation and returns the result.
// The accumulator starts as the identity over the full maxn x maxn capacity;
// each set bit of b multiplies the current square of a into it via mul().
// b is expected to be non-negative.
Mat ppow(Mat a, ll b){
    Mat acc;
    memset(acc.s, 0, sizeof(acc.s));
    for(int d = 0; d < maxn; ++d){
        acc.s[d][d] = 1;
    }
    for(; b; b >>= 1){
        if(b & 1){
            acc = mul(acc, a);
        }
        a = mul(a, a);   // a now holds the next power-of-two power
    }
    return acc;
}
// Maps a nucleotide letter to its child index in the automaton:
// 'A' -> 0, 'T' -> 1, 'C' -> 2, 'G' -> 3.
// Any other character yields -1. Previously control fell off the end of the
// function for such input, which is undefined behaviour.
int id(char a){
    switch(a){
        case 'A': return 0;
        case 'T': return 1;
        case 'C': return 2;
        case 'G': return 3;
        default:  return -1;
    }
}
// Aho-Corasick automaton over a 4-letter alphabet (letter indices come from id()).
// Usage as in main(): init(), insert() each pattern, build(), then query().
struct Aho{
    struct state{
        int next[4];    // child per letter; after build() the full transition (0 = root)
        int fail, cnt;  // fail: failure link; cnt: non-zero if reaching this state completes a pattern
    }node[maxn];
    int size;           // number of states currently in use
    queue<int> q;       // BFS work queue used by build()

    // Resets the automaton to a single root state with index 0.
    void init(){
        size = 0;
        newtrie();
        while(!q.empty()) q.pop();
    }

    // Clears the next free state and returns its index.
    int newtrie(){
        memset(node[size].next, 0, sizeof(node[size].next));
        node[size].cnt = node[size].fail = 0;
        return size++;
    }

    // Adds pattern s to the trie and marks its last state as forbidden.
    // If a character does not map to a letter index in 0..3 the pattern is
    // dropped: it can never occur in a string over the 4-letter alphabet, and
    // indexing next[] with such a value would be out of range. States created
    // for the prefix before that character stay in the trie but are unmarked.
    void insert(char *s){
        int len = strlen(s);
        int now = 0;
        for(int i = 0; i < len; i++){
            int c = id(s[i]);
            if(c < 0 || c > 3) return;
            if(node[now].next[c] == 0){
                node[now].next[c] = newtrie();
            }
            now = node[now].next[c];
        }
        node[now].cnt = 1;
    }

    // Computes failure links breadth-first, fills every missing child with the
    // transition of the failure state, and propagates the forbidden mark along
    // failure links.
    void build(){
        node[0].fail = -1;   // the root has no failure link
        q.push(0);

        while(!q.empty()){
            int u = q.front();
            q.pop();
            // u must be tested first: the root's fail is -1 and must never be
            // used as an index. A state whose failure state is forbidden is
            // forbidden too (its text ends with a pattern).
            if(u && node[node[u].fail].cnt) node[u].cnt = 1;
            for(int i = 0; i < 4; i++){
                // Transition of u's failure state on letter i. That state is
                // shallower than u, so BFS has already completed its next[]
                // table; for the root the fallback is the root itself.
                int viaFail = (u == 0) ? 0 : node[node[u].fail].next[i];
                int child = node[u].next[i];
                if(!child){
                    node[u].next[i] = viaFail;
                }
                else{
                    node[child].fail = viaFail;
                    q.push(child);
                }
            }
        }
    }

    // Builds the transition-count matrix restricted to transitions into
    // non-forbidden states, raises it to the power n, and prints the sum of row 0
    // over the non-forbidden states, modulo MOD.
    void query(){
        Mat a;
        memset(a.s, 0, sizeof(a.s));
        for(int i = 0; i < size; i++){
            for(int j = 0; j < 4; j++){
                if(node[node[i].next[j]].cnt == 0){
                    a.s[i][node[i].next[j]]++;
                }
            }
        }
        a = ppow(a, n);
        ll ans = 0;
        for(int i = 0; i < size; i++){
            if(node[i].cnt == 0) ans = (ans + a.s[0][i]) % MOD;
        }
        printf("%lld\n", ans);
    }

}ac;
char s[20];   // buffer for one pattern: at most 19 characters plus the terminator

// Reads test cases until the input ends or is malformed. Each case is a pattern
// count m and a length n followed by m patterns; the automaton is rebuilt for
// each case and the answer is printed by ac.query().
int main(){
    // Require both numbers: a partial or failed parse stops the loop instead of
    // running on stale values or spinning forever on unparseable input.
    while(scanf("%d%lld", &m, &n) == 2){
        ac.init();
        // m-- > 0 so that a negative count cannot start a runaway loop.
        while(m-- > 0){
            // The field width keeps an over-long token from overflowing s.
            if(scanf("%19s", s) != 1) break;
            ac.insert(s);
        }
        ac.build();
        tn = ac.size;
        ac.query();
    }
    return 0;
}

Posted by nmphpwebdev on Mon, 28 Oct 2019 10:20:36 -0700