stevenhalim / cpbook-code

CP4 Free Source Code Project (C++17, Java11, Python3 and OCaml)
2.03k stars 488 forks source link

Rabin-Karp implementation runs in O(n log n) due to using extended euclidean algorithm for modInverse #95

Open BrandonTang89 opened 2 years ago

BrandonTang89 commented 2 years ago

In the following section of /ch6/string_matching.cpp, modInverse is called for every check (we check O(n) candidate substrings of length m). The comment on hash_fast states that it is O(1), but since the extended Euclidean algorithm takes O(log M) for modulus M, the total time for string matching is O(n log M). https://github.com/stevenhalim/cpbook-code/blob/c3fb85a1acc5f31e15879741e4c826684243fddf/ch6/string_matching.cpp#L71-L97

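We can fix this by precomputing the inverse of every P[i], as in the following implementation. The precomputation takes O(n + log M) in total, where M is the modulus, because modInverse is called only once; each subsequent substring-hash query is then O(1).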

class RollingHash {
   public:
    vi P, H;   // P[i] = p^i mod m, H[i] is the hash of prefix length i
    vi P_inv;  // P_inv[i] = p^(-i) mod m
    const int n;
    string T;
    const ll p, M;

    RollingHash(string _s, int _p = 131, int _M = (int)1e9 + 7)
        : n(_s.size()), T(_s), p(_p), M(_M) {
        PrepareP();
        computeRollingHash();
    }
    void PrepareP() {  // precompute P and P_inv
        P.assign(n, 0);
        P[0] = 1;
        for (int i = 1; i < n; i++) P[i] = (P[i - 1] * p) % M;

        P_inv.assign(n, 0);
        P_inv[n - 1] = modInverse(P[n - 1], M);
        for (int i = n - 2; i >= 0; i--) P_inv[i] = (P_inv[i + 1] * p) % M;
    }

    void computeRollingHash() {  // precompute H
        H.assign(n, 0);
        for (int i = 0; i < n; i++) {
            if (i != 0) H[i] = H[i - 1];
            H[i] = (H[i] + ((ll)T[i] * P[i]) % M) % M;
        }
    }

    int getHash(int l, int r) {  // get hash of substring [l, r]
        if (l == 0) return H[r];
        int ans = ((H[r] - H[l - 1]) % M + M) % M;
        ans = ((ll)ans * P_inv[l]) % M;
        return ans;
    }
};

// Returns the start indices, in increasing order, of every length-|P| window
// of text T whose rolling hash equals the hash of pattern P.
//
// Candidates are accepted on hash equality alone (no character-by-character
// verification), so a hash collision would be reported as a match.
// An empty pattern, or a pattern longer than the text, yields no matches.
vi rabinKarp(string P, string T) {
    vi matches;
    int n = T.size(), m = P.size();

    // Bail out before building any hasher: with m == 0 the code below would
    // hash an empty string and call getHash(0, -1), both of which index out
    // of bounds; with m > n there is no window to test.
    if (m == 0 || m > n) return matches;

    RollingHash P_rolling(P);
    RollingHash T_rolling(T);

    int p_hash = P_rolling.getHash(0, m - 1);
    for (int i = 0; i + m <= n; i++) {
        if (p_hash == T_rolling.getHash(i, i + m - 1)) {  // match
            matches.push_back(i);
        }
    }
    return matches;
}

Note that the original code gets TLE (Time Limit Exceeded) on the Kattis problem stringmatching, while the modified version gets AC (Accepted).