Friday, February 1, 2013

Maxim and Restaurant (Part 2)

Previous part: Maxim and Restaurant (Part 1)

In the previous post I discussed an $O(n^3 \cdot p)$ solution, where p is the table length. Today I'm going to show how to reduce the complexity to $O(n^2 \cdot p)$. The key idea is to look at the problem from a different angle, so that we don't need to run the DP once for each blocker.

For n persons in the queue there are n! possible orders (permutations). Let p1 be the number of permutations in which exactly one person is allowed in, p2 the number in which exactly two persons are allowed in, and so on.

So the expected number of persons allowed in is:
$e = \frac{p1 \cdot 1 + p2 \cdot 2 + \dots + pn \cdot n}{n!}$

Since each pi is counted exactly i times, this can be rewritten as:
$e = \frac{(p1 + p2 + ... + pn) + (p2 + p3 + ... + pn) + ... + (pn)}{n!}$

Here, p1 + p2 + ... + pn is the number of permutations in which at least 1 person is allowed in; let's call it m1. Similarly, p2 + p3 + ... + pn is the number of permutations in which at least 2 persons are allowed in, which we call m2, and so on.
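As a small illustration (an example chosen here, not taken from the problem statement), take n = 3 persons of sizes 1, 2 and 3 and a table of length 3. Of the 3! = 6 orders, (1, 2, 3) and (2, 1, 3) let two persons in, while in the other four orders the second person no longer fits, so only one person gets in. The regrouping of the numerator then checks out term by term:

$p1 = 4,\quad p2 = 2,\quad p3 = 0$
$p1 \cdot 1 + p2 \cdot 2 + p3 \cdot 3 = 4 + 4 + 0 = 8 = (4 + 2 + 0) + (2 + 0) + (0)$
$m1 = 6,\quad m2 = 2,\quad m3 = 0,\quad e = \frac{8}{3!} = \frac{4}{3}$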

So our equation becomes:
$e = \frac{m1 + m2 + \dots + mn}{n!}$
$\Rightarrow e = \frac{m1}{n!} + \frac{m2}{n!} + \dots + \frac{mn}{n!}$
$\Rightarrow e = f1 + f2 + \dots + fn$

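Here, fi = mi / n! is the probability that at least i persons are allowed in. Since every person's size is positive, at least i persons get in exactly when the first i persons in the queue have total size at most p, the table length.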
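For small n, fi can be checked directly. In a uniformly random order the first i persons form a uniformly random i-element subset, so fi equals the number of i-element subsets whose sizes sum to at most p (the table length), divided by C(n, i). The sketch below is a hypothetical helper (the class `TailProbabilities` and its methods are not from the original post); it enumerates subsets as bitmasks, so it only suits small n, and it assumes all sizes are positive. For sizes 1, 2, 3 and p = 3 it gives f1 = 1, f2 = 1/3, f3 = 0, so e = 4/3.

```java
// Hypothetical helper (not from the original post): computes
// f_i = P(at least i persons are allowed in) by enumerating subsets.
// Assumes every size is positive, so "at least i allowed" is the same as
// "the first i persons in the queue have total size <= p".
final class TailProbabilities {

    // Returns f with f[i] for i = 1..n (f[0] is left as 0).
    // Enumerates all 2^n subsets, so keep n small (roughly n <= 20).
    static double[] tail(int[] a, int p) {
        int n = a.length;
        long[] fitting = new long[n + 1]; // fitting[i] = # i-element subsets with total <= p
        for (int mask = 1; mask < (1 << n); mask++) {
            int sum = 0;
            for (int j = 0; j < n; j++) {
                if ((mask & (1 << j)) != 0) {
                    sum += a[j];
                }
            }
            if (sum <= p) {
                fitting[Integer.bitCount(mask)]++;
            }
        }
        double[] f = new double[n + 1];
        for (int i = 1; i <= n; i++) {
            // The first i persons of a uniformly random order form a uniformly
            // random i-element subset, hence the division by C(n, i).
            f[i] = fitting[i] / binomial(n, i);
        }
        return f;
    }

    // e = f1 + f2 + ... + fn
    static double expected(int[] a, int p) {
        double[] f = tail(a, p);
        double e = 0;
        for (int i = 1; i < f.length; i++) {
            e += f[i];
        }
        return e;
    }

    // C(n, k) as a double, built incrementally to avoid large factorials.
    private static double binomial(int n, int k) {
        double r = 1;
        for (int i = 1; i <= k; i++) {
            r = r * (n - k + i) / i;
        }
        return r;
    }
}
```

Enumerating subsets instead of permutations is a design choice for the check: it uses the fact that only the set of the first i persons matters, which is the same observation the DP relies on.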

In the last post, our dynamic programming had states dp[i][j], where dp[i][j] was the probability that the first i persons in the queue have total size j. For a given i, adding up dp[i][j] over all 0 <= j <= p gives the probability that at least i persons are allowed in, which is fi. So if we drop the blocker and run the dynamic programming only once, we can compute f1, f2, ..., fn directly, and their sum is the answer.
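The DP can carry probabilities instead of counts because of a simple identity. The probability that the first c persons of a random order are exactly a given set S of c persons is c!(n − c)!/n! = 1/C(n, c), and going from c to c + 1 persons multiplies this by (c + 1)/(n − c). Writing a_x for the size of a newly added person x (notation introduced here), the transition is:

$P(\text{first } c \text{ persons are exactly } S) = \frac{c!\,(n-c)!}{n!} = \frac{1}{\binom{n}{c}}$
$\frac{1}{\binom{n}{c+1}} = \frac{(c+1)!\,(n-c-1)!}{n!} = \frac{c+1}{n-c} \cdot \frac{c!\,(n-c)!}{n!} = \frac{c+1}{n-c} \cdot \frac{1}{\binom{n}{c}}$
$dp[c+1][s + a_x] \mathrel{+}= dp[c][s] \cdot \frac{c+1}{n-c}$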

Here is the simplified code:

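import java.util.Scanner;

public class MaximAndRestaurant {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] a = new int[n]; // a[i] = size of person i
        int total = 0;
        for (int i = 0; i < n; ++i) {
            a[i] = in.nextInt();
            total += a[i];
        }
        int p = in.nextInt(); // table length
        // Everyone fits regardless of the order.
        if (total <= p) {
            System.out.println(n);
            return;
        }

        // dp[c][s] = probability that the first c persons of a random order
        // have total size exactly s (only sums s <= p are tracked).
        double[][] dp = new double[n + 1][p + 1];
        dp[0][0] = 1.0;
        for (int i = 0; i < n; ++i) {
            int cur = a[i];
            // Counts go downward so person i is added at most once (0/1 knapsack).
            for (int oldCount = n - 1; oldCount >= 0; --oldCount) {
                for (int oldSum = 0; oldSum + cur <= p; ++oldSum) {
                    // 1 / C(n, c + 1) = 1 / C(n, c) * (c + 1) / (n - c)
                    dp[oldCount + 1][oldSum + cur] += dp[oldCount][oldSum] / (n - oldCount) * (oldCount + 1);
                }
            }
        }

        // e = f1 + f2 + ... + fn, where fc = sum of dp[c][s] over s <= p.
        double res = 0.;
        for (int count = 1; count <= n; ++count) {
            for (int size = 0; size <= p; ++size) {
                res += dp[count][size];
            }
        }
        System.out.println(res);
    }
}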
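A quick way to gain confidence in the DP is to compare it with brute force on small random inputs. The sketch below is a hypothetical test harness (the class `StressTest` and its methods are not part of the original post): `dp(...)` repeats the DP above as a pure function, `brute(...)` averages the number of admitted persons over all n! orders, and `maxError(...)` reports the largest difference over random cases with n <= 7, sizes in 1..10 and table length in 1..30.

```java
import java.util.Random;

// Hypothetical stress test (not from the original post): compares the
// O(n^2 * p) DP with brute force over all n! orders on small inputs.
final class StressTest {

    // The DP from the post as a pure function. Without the "total <= p"
    // shortcut it still returns n in that case, since every fc is 1.
    static double dp(int[] a, int p) {
        int n = a.length;
        double[][] dp = new double[n + 1][p + 1];
        dp[0][0] = 1.0;
        for (int cur : a) {
            for (int c = n - 1; c >= 0; --c) {
                for (int s = 0; s + cur <= p; ++s) {
                    dp[c + 1][s + cur] += dp[c][s] / (n - c) * (c + 1);
                }
            }
        }
        double res = 0;
        for (int c = 1; c <= n; ++c) {
            for (int s = 0; s <= p; ++s) {
                res += dp[c][s];
            }
        }
        return res;
    }

    // Average number of admitted persons over all n! orders.
    static double brute(int[] a, int p) {
        long[] acc = new long[2]; // acc[0] = total admitted, acc[1] = number of orders
        permute(a.clone(), 0, p, acc);
        return (double) acc[0] / acc[1];
    }

    // Swap-based enumeration of all orders of arr[k..].
    private static void permute(int[] arr, int k, int p, long[] acc) {
        if (k == arr.length) {
            int sum = 0, count = 0;
            for (int x : arr) {
                if (sum + x > p) break; // this person and everyone behind stay out
                sum += x;
                count++;
            }
            acc[0] += count;
            acc[1]++;
            return;
        }
        for (int i = k; i < arr.length; i++) {
            int t = arr[k]; arr[k] = arr[i]; arr[i] = t;
            permute(arr, k + 1, p, acc);
            t = arr[k]; arr[k] = arr[i]; arr[i] = t;
        }
    }

    // Largest |dp - brute| over random small cases.
    static double maxError(int trials, long seed) {
        Random rnd = new Random(seed);
        double worst = 0;
        for (int t = 0; t < trials; t++) {
            int n = 1 + rnd.nextInt(7); // n <= 7 keeps n! small
            int[] a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = 1 + rnd.nextInt(10); // positive sizes
            }
            int p = 1 + rnd.nextInt(30);
            worst = Math.max(worst, Math.abs(dp(a, p) - brute(a, p)));
        }
        return worst;
    }
}
```

For example, `StressTest.maxError(200, 42L)` should be close to zero (only floating-point noise) if the DP is right.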