題解:CF2146D2 Max Sum OR (Hard Version)
好題。
思路
首先貪心的考慮,對于兩個數 \(x\) 和 \(y\),必然是當 \(x\) 或 \(y\) 后二進制里全為 \(1\) 是最優的,在下文中我們稱其為互補。知道這個結論Easy Version就可以直接從 \(r\) 開始往 \(0\) 枚舉,對于每個數都取互補的數即可。
void solve() {
int l, r;
ll res = 0;
cin >> l >> r;
vector<int> ans(r + 1);
vector<bool> flag(r + 1);
for (int i = r; i >= l; i--) {
if (flag[i]) continue;
int cnt = 0;
for (int j = 0; j <= log2(i); j++) {
if (!((1 << j) & i)) {
cnt += (1 << j);
}
}
res += (cnt | i);
ans[i] = cnt;
ans[cnt] = i;
flag[i] = flag[cnt] = true;
}
cout << res * 2 << "\n";
for (int i = 0; i <= r; i++) {
cout << ans[i] << " \n"[i == r];
}
}
回到本題,因為 \(l\ge 0\),所以上面的貪心方法需要改進。容易得到,二進制中的相同的高位都是可以去掉的,所以我們只需要考慮剩下的位即可。考慮手玩幾組數據,當 \(l=2\) 且 \(r=7\) 時如下:
觀察發現此時仍然可以得到兩組互補的數,即:\((2, 5)\)、\((3, 4)\)。此時考慮先把這兩組的答案計算上,這樣互補的數中的每一位都用上了,所以一定不劣。那剩下的數怎么辦呢?
注意到標紅部分屬于二進制中相同的高位,于是乎把它們去掉,然后便得到了:
于是又得到了一對互補的數。
根據上述手玩的過程,我們歸納出一個方法:
對于一段區間 \([l, r]\),在去除二進制中相同的高位后,必然可以被劃分成兩個區間 \([l, pos]\) 與 \((pos, r]\),其中第一個區間中的數在二進制表示下的最高位為 \(0\),第二個區間中的數在二進制表示下的最高位為 \(1\)。此時,不妨令 \(pos - l + 1 \ge r - pos\)。在這種情況下,\((pos, r]\) 這段區間中的數一定可以與第一個區間中進行互補匹配,于是將匹配的數計算進答案,然后令 \(r\leftarrow 2\times pos - r\),將區間右邊界縮小。若 \(pos - l + 1 < r - pos\),則令 \(l\leftarrow 2\times pos - l + 2\)。
\(r\leftarrow 2\times pos - r\) 和 \(l\leftarrow 2\times pos - l + 2\) 是怎么來的?
還是上面那組樣例:
此時的 \(pos=3\),容易發現互補的匹配是從 \(pos\) 開始對稱的,即:
所以我們可以得到在計算完所有完美匹配后的 \(l\) 和 \(r\) 應該縮小到哪里。此時就有:
- 若 \(pos - l + 1 \ge r - pos\),則 \(r\leftarrow pos-(r-pos)=2\times pos-r\)
- 若 \(pos - l + 1 < r - pos\),則 \(l\leftarrow pos + pos - l + 1 + 1 = 2\times pos - l + 2\)
最后的問題便不斷的縮小,終止邊界為 \(l\ge r\)。注意,若 \(l=r\) 則還要計算一個數或上它本身的情況。
在將相同的高位刪除時需要帶一個 \(\log\),所以總時間復雜度為:\(O(n\log V)\),\(V\) 為取值上限。
代碼
部分地方的式子和文中有一些差別,但不影響理解(AC 記錄)。
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define i128 __int128
#define inf (1ll << 62)
#define PII pair<int, int>
#define pb push_back
#define fi first
#define se second
using namespace std;
void solve() {
ll l, r, d = 0, res = 0;
cin >> l >> r;
ll x = r, y = l;
vector<int> ans(r - l + 1);
while (r > l) {
// cerr << l << " " << r << "\n";
ll l1 = l + d, r1 = r + d;
int num = 0;
for (int i = 30; i >= 0; i--) {
int ok = -1;
bool vis = false;
for (int j = l; j <= r; j++) {
if (ok == -1) {
ok = ((1 << i) & j);
} else if (ok != ((1 << i) & j)) {
vis = true;
for (int k = l; k <= r; k++) {
if ((1 << i) & k) num++;
}
break;
}
}
if (vis) break;
d += ok;
}
l = l1 - d;
r = r1 - d;
if (num >= (r - l) / 2 + 1) {
int pos = r - num, pos1 = l + (pos - l) * 2 + 1;
res += (pos1 - l + 1) * d;
for (int i = l, j = pos1; i <= j; i++, j--) {
ans[i + d - y] = j + d;
ans[j + d - y] = i + d;
res += (i | j);
if (i != j) res += (i | j);
}
l = pos1 + 1;
} else {
int pos = r - num + 1, pos1 = r - (r - pos) * 2 - 1;
res += (r - pos1 + 1) * d;
for (int i = r, j = pos1; i >= j; i--, j++) {
ans[i + d - y] = j + d;
ans[j + d - y] = i + d;
res += (i | j);
if (i != j) res += (i | j);
}
r = pos1 - 1;
}
}
if (r == l) {
res += r + d;
ans[r + d - y] = r + d;
}
cout << res << "\n";
for (int i = 0; i <= x - y; i++) {
cout << ans[i] << " \n"[i == x - y];
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t = 1;
cin >> t;
while (t--) {
solve();
}
return 0;
}

浙公網安備 33010602011771號