題解:luogu_P13696 「CyOI」出包魔法師
非常好的一道數(shù)學題。
原題鏈接。
題目分析
我們要意識到的一點是,題目中要求的最優(yōu)策略,實際上是一個固定的報數(shù)的序列,即對于每一個數(shù)字 \(i\),你要報 \(a_i\) 次這個數(shù)字,并要使得能拿走 \(k\) 張卡牌的概率最大。
那么最終的答案就是
并且滿足 \(\sum_{i = 1}^n a_i = k\)
初始時,每個 \(a_i\) 都為 \(0\),我們將 \(a_i\) 加 \(1\) 對答案的貢獻相當于乘上了 \(\dfrac{l_i - a_i}{a_i + 1}\)(由 \(\dbinom{l_i}{a_i}\) 變?yōu)?\(\dbinom{l_i}{a_i + 1}\))。
這時就已經(jīng)有一個貪心的思路了,每次選擇一個 \(\dfrac{l_i - a_i}{a_i + 1}\) 最大的數(shù)字 \(i\),將其 \(a_i\) 加 \(1\)??梢杂靡粋€堆來維護。復雜度 \(O(k \log m)\)。
優(yōu)化
我們發(fā)現(xiàn)如果將 \(l_i\) 從小到大排序,越靠后的數(shù)選的一定越多(有點廢話)。我們二分 \(l_m\) 選了 \(x\) 個,對于 \(a_i\),我們想讓它的貢獻盡可能的大。那么就是
解得 \(a_i \le \dfrac{(l_i + 1)x}{l_m + 1}\)。說明 \(a_i\) 隨 \(x\) 增大而增大,通過 \(\sum a_i\) 與 \(k\) 的大小關系調(diào)整二分。
最后有可能會沒選滿 \(k\),這時一定滿足 \(m - k - 1 \le \sum a_i \le k\),此時再使用一個堆來貪心地維護就可以了。
加上線性預處理階乘及其逆元,時間復雜度 \(O(\max(l_i) + m \log \max(l_i))\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define int long long
constexpr int N = 1e6 + 9;
constexpr int V = 1e7 + 9;
constexpr int mod = 998244353;
int fac[V], ifac[V];
int l[N], a[N];
int m, k;
int ans = 1ll;
struct Node{
int l, a, id;
friend bool operator < (Node x, Node y){return (x.l - x.a) * (y.a + 1) < (y.l - y.a) * (x.a + 1);}
};
priority_queue<Node>q;
int fp(int a, int n){
int res = 1ll;
while(n){
if(n & 1) res = res * a % mod;
a = a * a % mod;
n >>= 1ll;
}
return res % mod;
}
int C(int n, int m){return fac[n] * ifac[m] % mod * ifac[n - m] % mod;}
bool check(int x){
int res = 0;
for(int i = 1; i <= m; i++) res += (l[i] + 1) * x / (l[m] + 1);
return res <= k;
}
void init(){
fac[0] = 1ll;
for(int i = 1; i < V; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[V - 1] = fp(fac[V - 1], mod - 2); ifac[0] = 1ll;
for(int i = V - 2; i >= 1; i--) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
}
signed main(){
init();
cin >> m >> k;
for(int i = 1; i <= m; i++) cin >> l[i];
sort(l + 1, l + 1 + m);
int L = 1, R = l[m], p;
while(L <= R){
int mid = (L + R) >> 1;
if(check(mid)){
L = mid + 1;
p = mid;
}
else R = mid - 1;
}
for(int i = 1; i <= m; i++){
a[i] = (l[i] + 1) * p / (l[m] + 1);
k -= a[i];
if(l[i] > a[i]) q.push((Node){l[i], a[i], i});
}
while(k--){
int i = q.top().id; q.pop();
a[i]++;
if(l[i] > a[i]) q.push((Node){l[i], a[i], i});
}
for(int i = 1; i <= m; i++){
ans = ans * C(l[i], a[i]) % mod;
}
cout << ans % mod;
return 0;
}

浙公網(wǎng)安備 33010602011771號