原题链接:https://atcoder.jp/contests/dp/tasks/dp_j

J - Sushi

Time Limit: 2 sec / Memory Limit: 1024 MB

Problem Statement

There are $N$ dishes, numbered $1, 2, \ldots, N$. Initially, for each $i$ ($1 \le i \le N$), Dish $i$ has $a_i$ ($1 \le a_i \le 3$) pieces of sushi on it.

Taro will perform the following operation repeatedly until all the pieces of sushi are eaten:

  • Roll a die that shows the numbers $1, 2, \ldots, N$ with equal probabilities, and let $i$ be the outcome. If there are some pieces of sushi on Dish $i$, eat one of them; if there is none, do nothing.

Find the expected number of times the operation is performed before all the pieces of sushi are eaten.

Constraints

  • All values in input are integers.
  • $1 \le N \le 300$
  • $1 \le a_i \le 3$

Input

Input is given from Standard Input in the following format:

$N$

$a_1\ a_2\ \ldots\ a_N$

Output

Print the expected number of times the operation is performed before all the pieces of sushi are eaten. The output is considered correct when the relative difference is not greater than $10^{-9}$.

Sample Input 1

3
1 1 1

Sample Output 1

5.5

The expected number of operations before the first piece of sushi is eaten, is 1. After that, the expected number of operations before the second sushi is eaten, is 1.5. After that, the expected number of operations before the third sushi is eaten, is 3. Thus, the expected total number of operations is 1+1.5+3=5.5.

Sample Input 2

1
3

Sample Output 2

3

Outputs such as 3.00, 3.000000003 and 2.999999997 will also be accepted.

Sample Input 3

2
1 2

Sample Output 3

4.5

Sample Input 4

10
1 3 2 3 3 2 3 2 1 3

Sample Output 4

54.48064457488221

题目描述

有 $N$ 个盘子,第 $i$ 个盘子里有 $a_i$ 块寿司($1\le a_i \le 3$)。Taro 重复执行下述操作来吃寿司:从 $N$ 个盘子里等概率随机选一个盘子,如果这个盘子里有寿司,则吃掉 $1$ 块,否则什么也不做。求吃完所有寿司所需要的操作次数的期望值。

解决方案

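因为 $N$ 个盘子被选中的概率相同,盘子的顺序不影响结果,所以我们只需记录恰好剩 $x$ 块寿司的盘子数量,记为 $c_x$。每个盘子最多只有 $3$ 块寿司,因此只需记录 $c_1, c_2, c_3$,而已经吃空的盘子数量为 $N-c_1-c_2-c_3$。用 $dp[c_1][c_2][c_3]$ 表示从该状态出发、吃完所有寿司还需要的期望操作次数。执行一次操作时:以概率 $\frac{N-c_1-c_2-c_3}{N}$ 选中空盘,状态不变;以概率 $\frac{c_1}{N}$ 选中剩 $1$ 块的盘子,$c_1$ 减 $1$;以概率 $\frac{c_2}{N}$ 选中剩 $2$ 块的盘子,它变成剩 $1$ 块的盘子;以概率 $\frac{c_3}{N}$ 选中剩 $3$ 块的盘子,它变成剩 $2$ 块的盘子。再加上本次操作本身的 $1$ 次,得到: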

$$dp[c_1][c_2][c_3] = \frac{N-c_1-c_2-c_3}{N}\,dp[c_1][c_2][c_3] + \frac{c_1}{N}\,dp[c_1-1][c_2][c_3] + \frac{c_2}{N}\,dp[c_1+1][c_2-1][c_3] + \frac{c_3}{N}\,dp[c_1][c_2+1][c_3-1] + 1$$

由于等式两边都有 $dp[c_1][c_2][c_3]$,当 $c_1+c_2+c_3>0$ 时,我们可以对上述等式移项:

$$dp[c_1][c_2][c_3] - \frac{N-c_1-c_2-c_3}{N}\,dp[c_1][c_2][c_3] = \frac{c_1}{N}\,dp[c_1-1][c_2][c_3] + \frac{c_2}{N}\,dp[c_1+1][c_2-1][c_3] + \frac{c_3}{N}\,dp[c_1][c_2+1][c_3-1] + 1$$

$$\frac{c_1+c_2+c_3}{N}\,dp[c_1][c_2][c_3] = \frac{c_1}{N}\,dp[c_1-1][c_2][c_3] + \frac{c_2}{N}\,dp[c_1+1][c_2-1][c_3] + \frac{c_3}{N}\,dp[c_1][c_2+1][c_3-1] + 1$$

$$dp[c_1][c_2][c_3] = \left(\frac{c_1}{N}\,dp[c_1-1][c_2][c_3] + \frac{c_2}{N}\,dp[c_1+1][c_2-1][c_3] + \frac{c_3}{N}\,dp[c_1][c_2+1][c_3-1] + 1\right)\cdot\frac{N}{c_1+c_2+c_3}$$

$$dp[c_1][c_2][c_3] = \frac{c_1}{c_1+c_2+c_3}\,dp[c_1-1][c_2][c_3] + \frac{c_2}{c_1+c_2+c_3}\,dp[c_1+1][c_2-1][c_3] + \frac{c_3}{c_1+c_2+c_3}\,dp[c_1][c_2+1][c_3-1] + \frac{N}{c_1+c_2+c_3}$$

边界条件为 $dp[0][0][0] = 0$(没有寿司时无需任何操作),答案即为初始状态对应的 $dp$ 值。

时间复杂度

$O(N^3)$

代码

#include <bits/stdc++.h>

using namespace std;

// Stream helpers for std containers, intended for quick input and debug output.
// Output formats: pair "(a,b)", vector "[x,y,]", deque "deq[x,y,]",
// (multi/unordered) sets "{x,y,}", (unordered) maps "{k=>v,}" -- every element
// is followed by a comma, including the last one.
//
// Every overload is declared first and defined afterwards. Inside a template,
// an unqualified `os << v` only finds operators declared before that template's
// definition; argument-dependent lookup at instantiation does not search the
// global namespace for std types. Without these declarations, nested containers
// whose element operator is defined later (e.g. vector<set<int>>,
// vector<map<int,int>>) failed to compile.
template <typename T1, typename T2> std::istream &operator>>(std::istream &is, std::pair<T1, T2> &pa);
template <typename T> std::istream &operator>>(std::istream &is, std::vector<T> &vec);
template <typename T1, typename T2> std::ostream &operator<<(std::ostream &os, const std::pair<T1, T2> &pa);
template <typename T> std::ostream &operator<<(std::ostream &os, const std::vector<T> &vec);
template <typename T> std::ostream &operator<<(std::ostream &os, const std::deque<T> &vec);
template <typename T> std::ostream &operator<<(std::ostream &os, const std::set<T> &vec);
template <typename T> std::ostream &operator<<(std::ostream &os, const std::multiset<T> &vec);
template <typename T> std::ostream &operator<<(std::ostream &os, const std::unordered_set<T> &vec);
template <typename T> std::ostream &operator<<(std::ostream &os, const std::unordered_multiset<T> &vec);
template <typename TK, typename TV> std::ostream &operator<<(std::ostream &os, const std::unordered_map<TK, TV> &mp);
template <typename TK, typename TV> std::ostream &operator<<(std::ostream &os, const std::map<TK, TV> &mp);

// Reads both members. The pair must be taken by non-const reference so the
// members can be assigned (a `const pair&` parameter can never be read into).
template <typename T1, typename T2> std::istream &operator>>(std::istream &is, std::pair<T1, T2> &pa) { is >> pa.first >> pa.second; return is; }
// Reads exactly vec.size() elements; the vector must already be sized.
template <typename T> std::istream &operator>>(std::istream &is, std::vector<T> &vec) { for (auto &v : vec) is >> v; return is; }
// Element loops below bind by const reference to avoid copying each element.
template <typename T1, typename T2> std::ostream &operator<<(std::ostream &os, const std::pair<T1, T2> &pa) { os << "(" << pa.first << "," << pa.second << ")"; return os; }
template <typename T> std::ostream &operator<<(std::ostream &os, const std::vector<T> &vec) { os << "["; for (const auto &v : vec) os << v << ","; os << "]"; return os; }
template <typename T> std::ostream &operator<<(std::ostream &os, const std::deque<T> &vec) { os << "deq["; for (const auto &v : vec) os << v << ","; os << "]"; return os; }
template <typename T> std::ostream &operator<<(std::ostream &os, const std::set<T> &vec) { os << "{"; for (const auto &v : vec) os << v << ","; os << "}"; return os; }
template <typename T> std::ostream &operator<<(std::ostream &os, const std::multiset<T> &vec) { os << "{"; for (const auto &v : vec) os << v << ","; os << "}"; return os; }
template <typename T> std::ostream &operator<<(std::ostream &os, const std::unordered_set<T> &vec) { os << "{"; for (const auto &v : vec) os << v << ","; os << "}"; return os; }
template <typename T> std::ostream &operator<<(std::ostream &os, const std::unordered_multiset<T> &vec) { os << "{"; for (const auto &v : vec) os << v << ","; os << "}"; return os; }
template <typename TK, typename TV> std::ostream &operator<<(std::ostream &os, const std::unordered_map<TK, TV> &mp) { os << "{"; for (const auto &kv : mp) os << kv.first << "=>" << kv.second << ","; os << "}"; return os; }
template <typename TK, typename TV> std::ostream &operator<<(std::ostream &os, const std::map<TK, TV> &mp) { os << "{"; for (const auto &kv : mp) os << kv.first << "=>" << kv.second << ","; os << "}"; return os; }
// resize_array(v, d0, d1, ..., dk): resizes a nested std::vector level by
// level -- the outer vector to d0 elements, each of those to d1 elements, and
// so on. Newly appended elements are value-initialized by vector::resize.
template <typename T> void resize_array(std::vector<T> &vec, int len) { vec.resize(len); }
template <typename T, typename... Args>
void resize_array(std::vector<T> &vec, int len, Args... dims) {
    vec.resize(len);
    for (std::size_t idx = 0, count = vec.size(); idx < count; ++idx) {
        resize_array(vec[idx], dims...);  // recurse into the next nesting level
    }
}
// Component-wise addition of two pairs of the same type; each sum is converted
// back to the corresponding member type.
template <typename T1, typename T2>
std::pair<T1, T2> operator+(const std::pair<T1, T2> &l, const std::pair<T1, T2> &r) {
    return std::pair<T1, T2>(l.first + r.first, l.second + r.second);
}
// Component-wise subtraction (l - r) of two pairs of the same type.
template <typename T1, typename T2>
std::pair<T1, T2> operator-(const std::pair<T1, T2> &l, const std::pair<T1, T2> &r) {
    return std::pair<T1, T2>(l.first - r.first, l.second - r.second);
}
// Greatest common divisor of a and b using the iterative Euclidean algorithm.
// The result is always non-negative, and gcd(0, 0) == 0. The computation runs
// on unsigned magnitudes because C++ `%` keeps the dividend's sign: the old
// signed version returned e.g. -2 for gcd(4, -6), and it overflowed on
// LLONG_MIN % -1. The one result that does not fit in long long is 2^63
// (inputs LLONG_MIN with 0 or LLONG_MIN); that conversion is
// implementation-defined.
long long gcd(long long a, long long b) {
    unsigned long long x = a < 0 ? 0ULL - static_cast<unsigned long long>(a) : static_cast<unsigned long long>(a);
    unsigned long long y = b < 0 ? 0ULL - static_cast<unsigned long long>(b) : static_cast<unsigned long long>(b);
    while (y != 0) {
        const unsigned long long r = x % y;
        x = y;
        y = r;
    }
    return static_cast<long long>(x);
}
// Process-wide Mersenne Twister engine, seeded once from std::random_device,
// so sequences differ between runs.
std::mt19937 mrand(std::random_device{}());
// rnd(x): pseudo-random value in [0, x) for positive x, taken as the engine
// output modulo x (slightly biased when x does not divide 2^32).
// x must be positive; x == 0 would divide by zero.
int rnd(int x) {
    const auto bound = static_cast<std::mt19937::result_type>(x);
    return static_cast<int>(mrand() % bound);
}

// Competitive-programming shorthands. These are textual macros: they rewrite
// every later occurrence of the token (e.g. any identifier named `mp` or `pb`
// after this point).
// rep(i, a, n): declares int i and iterates i = a, a+1, ..., n-1.
#define rep(i, a, n) for (int i = a; i < (n); i++)
// per(i, a, n): declares int i and iterates i = n-1, n-2, ..., a (reverse of rep).
#define per(i, a, n) for (int i = (n)-1; i >= a; i--)
#define pb push_back
#define mp make_pair
// all(x): the begin/end iterator pair of container x, for passing to algorithms.
#define all(x) (x).begin(), (x).end()
#define fi first
#define se second
// sz(x): container size converted to int.
#define sz(x) ((int)(x).size())
// Short type aliases used throughout the file.
using vi = std::vector<int>;
using ll = long long;
using uint = unsigned int;
using ull = unsigned long long;
using pii = std::pair<int, int>;
using db = double;
// dbg(x): when DEBUG is defined to a non-zero value, prints
// "<expr> = <value> (L<line>) <file>" to stderr; otherwise it expands to
// nothing. An undefined DEBUG evaluates as 0 in #if.
// NOTE(review): the DEBUG expansion ends with its own ';', so
// `if (c) dbg(x); else ...` stops compiling when DEBUG is on. Wrap such uses in
// braces. The ';' is kept because callers may write `dbg(x)` with no semicolon.
#if DEBUG
#define dbg(x) cerr << #x << " = " << (x) << " (L" << __LINE__ << ") " << __FILE__ << endl;
#else
#define dbg(x)
#endif

// Solver for AtCoder DP Contest J "Sushi".
//
// Only the number of dishes holding 1, 2 and 3 pieces matters, so the state is
// (c1, c2, c3). Let E(c1, c2, c3) be the expected number of remaining
// operations. For s = c1 + c2 + c3 > 0:
//   E = n/s + c1/s * E(c1-1, c2, c3) + c2/s * E(c1+1, c2-1, c3)
//           + c3/s * E(c1, c2+1, c3-1),
// with E(0, 0, 0) = 0.
class Solution {
public:
    // Reads test cases "n a_1 ... a_n" from stdin until EOF. For each case it
    // prints the expected number of operations on its own line.
    void Solve() {
        int n = 0;
        while (std::cin >> n) {
            std::array<int, 4> cnt{};  // cnt[k] = number of dishes holding k pieces
            for (int i = 0; i < n; ++i) {
                int t = 0;
                std::cin >> t;
                // at() throws on a_i outside 0..3 instead of writing out of bounds.
                cnt.at(t)++;
            }
            // 12 fixed decimals give ample margin for the 1e-9 relative-error check.
            std::cout << std::fixed << std::setprecision(12)
                      << Expected(n, cnt[1], cnt[2], cnt[3]) << '\n';
        }
    }

    // Expected number of operations needed to eat everything, for n dishes of
    // which k1 / k2 / k3 hold 1 / 2 / 3 pieces (the rest are empty).
    // Returns 0.0 when there is no sushi.
    //
    // Bottom-up DP over increasing c3, then c2, then c1. Every transition reads
    // a state with the same c3 (smaller c2, or same c2 with smaller c1) or with
    // c3 - 1, so only two (c1, c2) layers are kept. This uses O(n^2) memory
    // instead of an (n+1)^3 table.
    static double Expected(int n, int k1, int k2, int k3) {
        const int total = k1 + k2 + k3;  // non-empty dishes; never increases
        if (total == 0) return 0.0;
        const int max2 = k2 + k3;        // dishes with >= 2 pieces; never increases
        const int width = total + 1;     // extent along c1

        // prev holds layer c3 - 1, cur holds layer c3; both are indexed by (c1, c2).
        std::vector<double> prev(static_cast<std::size_t>(width) * (max2 + 1), 0.0);
        std::vector<double> cur(prev.size(), 0.0);
        auto at = [width](std::vector<double> &layer, int c1, int c2) -> double & {
            return layer[static_cast<std::size_t>(c2) * width + c1];
        };

        for (int c3 = 0; c3 <= k3; ++c3) {
            for (int c2 = 0; c2 + c3 <= max2; ++c2) {
                for (int c1 = 0; c1 + c2 + c3 <= total; ++c1) {
                    const int s = c1 + c2 + c3;
                    if (s == 0) {
                        at(cur, 0, 0) = 0.0;  // base case: nothing left to eat
                        continue;
                    }
                    double e = n * 1.0 / s;
                    if (c1) e += c1 * 1.0 / s * at(cur, c1 - 1, c2);
                    if (c2) e += c2 * 1.0 / s * at(cur, c1 + 1, c2 - 1);
                    if (c3) e += c3 * 1.0 / s * at(prev, c1, c2 + 1);
                    at(cur, c1, c2) = e;
                }
            }
            std::swap(prev, cur);  // the finished layer becomes "previous"
        }
        return at(prev, k1, k2);  // prev now holds layer k3
    }
};

// Speeds up iostreams: stops syncing with C stdio and unties cin from cout.
// When `name` is non-empty, it also redirects stdin to "<name>.in" and stdout to
// "<name>.out".
// Returns true on success, or when no redirection was requested. Returns false,
// after printing a diagnostic to stderr, if a redirection failed. Previously
// freopen() failures were ignored, which left the stream closed so later
// reads or writes failed silently. The return value may be ignored as before.
bool set_io(const std::string &name = "") {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    if (name.empty()) return true;
    if (!std::freopen((name + ".in").c_str(), "r", stdin)) {
        std::fprintf(stderr, "set_io: cannot open %s.in for reading\n", name.c_str());
        return false;
    }
    if (!std::freopen((name + ".out").c_str(), "w", stdout)) {
        std::fprintf(stderr, "set_io: cannot open %s.out for writing\n", name.c_str());
        return false;
    }
    return true;
}

int main() {
set_io("");
Solution().Solve();

return 0;
}