atcorder 295 E
題目鏈接:https://atcoder.jp/contests/abc295/tasks/abc295_e
題意:
給定一個長為N的數(shù)字序列,序列中每個數(shù)字都在[0, M]這個區(qū)間中。按順序做兩步操作:
第一步,對于數(shù)字序列中每個數(shù)字0,獨立的并且等概率的從區(qū)間[1, M]中選擇一個數(shù), 把這個0代替成選出來的數(shù)。
第二步,把這個數(shù)字序列按照升序排列。
問第K位得到數(shù)字的期望是什么,輸出答案mod998244353
Simple input
3 5 2
2 0 4
Simple output
3
Solution:
第K位的期望是:\(E_k = \displaystyle\sum^{M}_{i = 1}{i * p_i}\)
這樣不是很好求,可以將其轉(zhuǎn)化成\(E_k = \displaystyle\sum^{M}_{i = 1}{i * (b_i - b_{i + 1})}\)
其中\(b_i\)指的是第k位大于等于i的概率。
那么\(E_k = 1 * (b_1 - b_2) + 2 * (b_2 - b_3) + \dots + m - 1 * (b_{m - 1} - b_{m}) + b_m = b_1 + b_2 + \dots + b_m\)
接下來就要求出\(b_i\)
分類討論:
1.:如果序列中大于等于i的數(shù)的個數(shù)是大于等于n - k + 1的那么\(b_i\)就是1,因為這種情況在不改變0的情況下已經(jīng)使得在升序排序后第k位恒大于等于i。
2.:如果序列中大于等于i的個數(shù)不足n - k + 1,那么就需要把一些0變成大于等于i的數(shù)。
對于這種情況,我們事先統(tǒng)計好其中大于等于i的數(shù)的個數(shù),設(shè)為cnt,并且統(tǒng)計出來其中0的個數(shù),設(shè)為num。
如果num + cnt < n - k + 1, 意思就是說,把所有0都轉(zhuǎn)換成大于等于i的數(shù),仍然無法使得第k位的數(shù)大于等于i,因此\(b_i\) = 0。
而如果num + cnt >= n - k + 1, 我們就可以從num個0中選出來需要進(jìn)行轉(zhuǎn)化的0。
對于每一個0,它轉(zhuǎn)換成大于等于i的數(shù)的概率是\(P = \frac{m - i + 1}{m}\)
那么,這種情況\(b_i = \displaystyle\sum^{num}_{i = n - k + 1 - cnt}\left({i\choose num} * P^i * (1 - P)^{num - i}\right)\)
就是加法原理和乘法原理。
Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define int LL
const int mod = 998244353;
const int N = 5010;
int c[N][N];
int qmi(int a, int b, int c) {
LL res = 1;
while(b) {
if(b & 1) res = res * a % c;
a = (LL) a * a % c;
b >>= 1;
}
return res;
}
signed main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);
for(int i = 0; i <= 5005; i ++) {
for(int j = 0; j <= i; j ++) {
if(!j) c[i][j] = 1;
else c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
}
int n, m, k; cin >> n >> m >> k;
vector<int> seq(n + 1);
vector<int> stc(2010, 0);
int num = 0;//統(tǒng)計序列中零的個數(shù)
for(int i = 1; i <= n; i ++) {
cin >> seq[i];
if(!seq[i]) num ++;
stc[seq[i]] ++;
}
int need = n - k + 1;
int ans = 0;
for(int i = 1; i <= m; i ++) {
int cnt = 0;//統(tǒng)計序列中有多少個數(shù)字大于等于i
for(int j = i; j <= m; j ++) cnt += stc[j];
if(cnt >= need) ans = (ans + 1) % mod;
else {
if(cnt + num < need) continue;
else {
int p = (m - i + 1) * qmi(m, mod - 2, mod) % mod;
for(int k = need - cnt; k <= num; k ++) {
ans = (ans + c[num][k] * qmi(p, k, mod) % mod * qmi((1 - p + mod) % mod, num - k, mod) % mod) % mod;
}
}
}
}
cout << ans;
}

浙公網(wǎng)安備 33010602011771號