ei1333の日記

ぺこい

SRM709 Med(500) Softmatch

TL で Aho-corasick というのが流れていて, 知らなかったので勉強して解いてみた.

ここらへんが練習になりそう.

Med(500) Softmatch

TopCoder Statistics - Problem Statement

問題概要

 {n(1 \le n \le 5)} 個の ‘0’, ‘1’, ‘2’, ‘3’ からなるパターン  {patterns_i(1 \le |patterns_i| \le 50)} と ‘a’, ‘A’, ‘b’, ‘B’ からなる文字列  {S(1 \le |S| \le 50)} が与えられる.

 {S} の 任意の ‘a’, ‘b’ を大文字 ‘A’, ‘B’ に変えることができる. それぞれのアルファベットは以下のようにマッチングされる.

  • ‘a’ は ‘0’ か ‘1’
  • ‘A’ は ‘2’ か ‘3’
  • ‘b’ は ‘0’ か ‘2’
  • ‘B’ は ‘1’ か ‘3’

このときマッチングの数の最大値を求めよ.

解法

Aho-corasick + DP.

Aho-corasick では Trie 木の機能に, 失敗時の遷移をいい感じに加えてオートマトンにしたものである(そういう理解で良いのかな?).

基本的に Aho-corasick では ‘0’ か ‘1’ みたいな複数の文字に対応させることはできない(Trieなので) が, set で Trie の現在にマッチするノードの番号をすべて持たせておくと, 全ての番号を試すことによって判定できる.

dp[idx][set] := Aho-corasick の状態集合が set で残り  {|S| - idx} 文字追加できるときのマッチングの数の最大値

とすれば, dpテーブルはなんとなく疎な気がするのでメモ化させれば間に合う. 計算量はわからず.

ソース

struct TrieNode
{
  int nxt[5];

  int exist; // 子ども以下に存在する文字列の数の合計
  vector< int > accept; // その文字列id

  TrieNode() : exist(0)
  {
    memset(nxt, -1, sizeof(nxt));
  }
};

struct Trie
{
  vector< TrieNode > nodes;
  int root;

  Trie() : root(0)
  {
    nodes.push_back(TrieNode());
  }
  
  void update_direct(int node, int id)
  {
    nodes[node].accept.push_back(id);
  }

  void update_child(int node)
  {
    ++nodes[node].exist;
  }
  
  void add(const string &str, int str_index, int node_index, int id)
  {
    if(str_index == str.size()) {
      update_direct(node_index, id);
    } else {
      const int c = str[str_index] - '0';
      if(nodes[node_index].nxt[c] == -1) {
        nodes[node_index].nxt[c] = (int) nodes.size();
        nodes.push_back(TrieNode());
      }
      add(str, str_index + 1, nodes[node_index].nxt[c], id);
      update_child(node_index);
    }
  }

  void add(const string &str, int id)
  {
    add(str, 0, 0, id);
  }

  void add(const string &str)
  {
    add(str, nodes[0].exist);
  }

  int size()
  {
    return (nodes[0].exist);
  }

  int nodesize()
  {
    return ((int) nodes.size());
  }
};

struct Aho_Corasick : Trie
{
  static const int FAIL = 4;
  vector< int > correct;

  Aho_Corasick() : Trie() {}

  void build()
  {
    correct.resize(nodes.size());
    for(int i = 0; i < nodes.size(); i++) {
      for(int j : nodes[i].accept) correct[i] |= 1 << j;
    }

    queue< int > que;
    for(int i = 0; i < 5; i++) {
      if(~nodes[0].nxt[i]) {
        nodes[nodes[0].nxt[i]].nxt[FAIL] = 0;
        que.emplace(nodes[0].nxt[i]);
      } else {
        nodes[0].nxt[i] = 0;
      }
    }
    while(!que.empty()) {
      TrieNode &now = nodes[que.front()];
      correct[que.front()] |= correct[now.nxt[FAIL]];
      que.pop();
      for(int i = 0; i < 4; i++) {
        if(now.nxt[i] == -1) continue;
        int fail = now.nxt[FAIL];
        while(nodes[fail].nxt[i] == -1) {
          fail = nodes[fail].nxt[FAIL];
        }
        nodes[now.nxt[i]].nxt[FAIL] = nodes[fail].nxt[i];
        que.emplace(now.nxt[i]);
      }
    }
  }

  pair< int, int > move(const string &str,int now = 0)
  {
    int count = 0;
    for(auto &c : str) {
      while(nodes[now].nxt[c - '0'] == -1) now = nodes[now].nxt[FAIL];
      now = nodes[now].nxt[c - '0'];
      count |= correct[now];
    }
    return {count, now};
  }
};

class Softmatch
{
public:
  int count(string S, vector<string> patterns)
  {
    Aho_Corasick aho;
    for(string& s : patterns) aho.add(s);
    aho.build();
    map< set< int >, int > dp[50];

    auto match = [&](set< int >& bit, char c, char d)
    {
      set< int > ss;
      int ret = 0;
      for(int p : bit) {
        auto get = aho.move(string(1, c), p); 
        ss.emplace(get.second);
        ret |= get.first;
      }
      for(int p : bit) {
        auto get = aho.move(string(1, d), p);
        ret |= get.first;
        ss.emplace(get.second);
      }
      return make_pair(__builtin_popcount(ret), ss);
    };
    
    function< int(int, set< int >) > rec = [&](int idx, set< int > bit) {
      if(idx == S.size()) return(0);
      if(dp[idx].count(bit)) return(dp[idx][bit]);
      char c = S[idx];
      int ret = 0;
      if(islower(c)) {
        auto get = match(bit, '0', c == 'a' ? '1' : '2');
        ret = max(ret, rec(idx + 1, get.second) + get.first);
      }
      auto get = match(bit, toupper(c) == 'A' ? '2' : '1', '3');
      return(dp[idx][bit] = max(ret, rec(idx + 1, get.second) + get.first));
    };

    set< int > state;
    state.insert(0);
    return(rec(0, state));
  }
};