AnimeshSinha1309 / algorithms-notebook

The team notebook to keep all template code and notes in.

A Foray into DP Optimizations #11

Open AnimeshSinha1309 opened 4 years ago

AnimeshSinha1309 commented 4 years ago

We need to get some theory and some practice in the following optimisations:

AnimeshSinha1309 commented 4 years ago

This is the best resource I found that explains the Aliens trick well (the official IOI 2016 solutions): https://ioinformatics.org/files/ioi2016solutions.pdf. The problem statement is available here: https://ioinformatics.org/files/ioi2016problem6.pdf.

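The idea is simple: turn a constraint into a cost. Instead of forcing the solution to use exactly K of something, charge a penalty for every use, solve the unconstrained problem, and binary search over the penalty until the optimum uses exactly K. This works because the optimal value as a function of K is convex (concave when maximising), so every K is an optimum for some penalty. In the IOI problem the inner DP is additionally sped up with the convex hull trick.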
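To make the idea concrete, here is a minimal sketch of the integer version of the trick on a simpler toy problem chosen for illustration (it is not from the IOI booklet): maximise the total of at most K disjoint non-empty subarrays. The names `penalised` and `maxSumAtMostK` are hypothetical. It assumes integer inputs and that the optimum is concave in K (true for this problem), so integer penalties are enough. Ties are broken towards fewer subarrays, and the search looks for the smallest penalty whose optimum uses at most K.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// (penalised value, number of subarrays opened)
using State = std::pair<long long, int>;

// Prefer the larger value; on ties prefer fewer subarrays.
static State better(State x, State y) {
    if (x.first != y.first) return x.first > y.first ? x : y;
    return x.second <= y.second ? x : y;
}

// Unconstrained DP where every subarray costs `lambda` (lambda >= 0).
State penalised(const std::vector<long long>& a, long long lambda) {
    const long long NEG = -(1LL << 60);
    State out{0, 0};   // current element is not inside a subarray
    State in{NEG, 0};  // current element is the last one of an open subarray
    for (long long x : a) {
        State extend{in.first + x, in.second};               // keep the subarray going
        State open{out.first + x - lambda, out.second + 1};  // start a new one, pay lambda
        State newIn = better(extend, open);
        State newOut = better(out, in);                      // skip x or close the subarray
        in = newIn;
        out = newOut;
    }
    return better(out, in);
}

// Maximum total of at most K disjoint non-empty subarrays via the Aliens trick.
long long maxSumAtMostK(const std::vector<long long>& a, int K) {
    assert(K >= 0);
    long long lo = 0, hi = 1;
    for (long long x : a)
        if (x > 0) hi += x;  // at this penalty no subarray is worth opening
    // Smallest lambda whose optimum uses at most K subarrays.
    while (lo < hi) {
        long long mid = lo + (hi - lo) / 2;
        if (penalised(a, mid).second <= K) hi = mid;
        else lo = mid + 1;
    }
    // Add back the penalty for exactly K uses.
    return penalised(a, lo).first + lo * K;
}
```

For example, on {3, -1, 2, -5, 4} this gives 4, 8 and 9 for K = 1, 2, 3. The tie-breaking is what makes the final line exact: with it, the smallest penalty whose optimum uses at most K subarrays always has K among its optimal counts, so adding lambda * K back recovers f(K) even when several counts tie.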

Here is the code from the tesla_protocol notebook; it leaves a lot unexplained.

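// Instead of dp[N][A][B], calculate dp[N][A] and charge a cost `mid` every time
// option Y is used, then count how many times it was used. Binary search the cost
// until the count is B. Valid because f(x) = dp[N][A][x] is concave:
// f(x+1) - f(x) <= f(x) - f(x-1).
#include <cstdio>

// Assumed bound and input format (not in the original snippet):
// first line N A B, then X[1..N], then Y[1..N].
const int MAXN = 2005;
int N, A, B;
double X[MAXN], Y[MAXN];  // gain of option X / option Y on item i
double dp[MAXN][MAXN];    // dp[i][j]: best penalised value, items 1..i, at most j uses of X
int p[MAXN][MAXN];        // p[i][j]: option picked at (i, j), used for backtracking
double mid;               // current cost charged per use of Y

void solve() {
    for (int i = 1; i <= N; ++i) {
        for (int j = 0; j <= A; ++j) {
            double &d = dp[i][j];
            int &pick = p[i][j];
            d = dp[i - 1][j], pick = 0;                   // 0: use neither
            if (j && d < dp[i - 1][j - 1] + X[i])         // 1: X only
                d = dp[i - 1][j - 1] + X[i], pick = 1;
            if (d < dp[i - 1][j] + Y[i] - mid)            // 2: Y only, pay the cost
                d = dp[i - 1][j] + Y[i] - mid, pick = 2;
            double both = X[i] + Y[i] - X[i] * Y[i];      // 3: both, pay the cost
            if (j && d < dp[i - 1][j - 1] + both - mid) {
                d = dp[i - 1][j - 1] + both - mid;
                pick = 3;
            }
        }
    }
}

int main() {
    if (scanf("%d %d %d", &N, &A, &B) != 3) return 1;
    for (int i = 1; i <= N; ++i) scanf("%lf", &X[i]);
    for (int i = 1; i <= N; ++i) scanf("%lf", &Y[i]);
    // Assuming X and Y are probabilities, one extra use of Y is worth
    // between 0 and 1, so the cost lies in [0, 1].
    double low = 0, high = 1;
    for (int iter = 0; iter < 50; ++iter) {
        mid = (low + high) / 2;
        solve();
        // Walk the picks back from (N, A) and count the uses of Y.
        int pos = A, c = 0;
        for (int i = N; i > 0; --i) {
            if (p[i][pos] == 1) --pos;
            else if (p[i][pos] == 2) ++c;
            else if (p[i][pos] == 3) --pos, ++c;
        }
        if (c > B) low = mid; else high = mid;  // too many uses of Y: raise the cost
    }
    mid = high;
    solve();
    // Add back the cost charged for the B uses of Y.
    printf("%0.5f\n", dp[N][A] + high * B);
}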
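A quick way to gain confidence in the binary-search version is to compare it, on small inputs, against the exact dp[N][A][B] that it avoids. The sketch below is a hypothetical checker (`exactDP` is not from the notebook). It assumes the same recurrence as the code above (using both options gains X + Y - X*Y), 0-indexed arrays, and "at most A uses of X, at most B uses of Y" semantics.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Exact O(N * A * B) DP; only meant for small test inputs.
double exactDP(int A, int B, const std::vector<double>& X, const std::vector<double>& Y) {
    assert(X.size() == Y.size());
    // best[j][k]: best value so far with at most j uses of X and at most k uses of Y.
    std::vector<std::vector<double>> best(A + 1, std::vector<double>(B + 1, 0.0));
    for (int i = 0; i < (int)X.size(); ++i) {
        // Iterate j and k downwards so every right-hand side still holds the
        // previous item's values (0/1 knapsack style, one table reused).
        for (int j = A; j >= 0; --j) {
            for (int k = B; k >= 0; --k) {
                double v = best[j][k];                                  // skip item i
                if (j) v = std::max(v, best[j - 1][k] + X[i]);          // X only
                if (k) v = std::max(v, best[j][k - 1] + Y[i]);          // Y only
                if (j && k)                                             // both
                    v = std::max(v, best[j - 1][k - 1] + X[i] + Y[i] - X[i] * Y[i]);
                best[j][k] = v;
            }
        }
    }
    return best[A][B];
}
```

Evaluating `exactDP(A, x, X, Y)` for consecutive x also lets you check the concavity condition f(x+1) - f(x) <= f(x) - f(x-1) numerically, and comparing it with the binary-search answer on random small cases catches tie-breaking and backtracking mistakes.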