CF833B The Bakery 題解
題目傳送門
題目大意
將一個長度為 \(n\) 的序列分為 \(k\) 段,使得總價值最大。
一段區間的價值表示為區間內不同數字的個數。
\(n≤35000,k≤50\)
輸入輸出樣例 #1
輸入 #1
4 1
1 2 2 1
輸出 #1
2
輸入輸出樣例 #2
輸入 #2
7 2
1 3 3 1 4 4 4
輸出 #2
5
輸入輸出樣例 #3
輸入 #3
8 3
7 7 8 7 7 8 1 7
輸出 #3
6
說明/提示
In the first example Slastyona has only one box. She has to put all cakes in it, so that there are two types of cakes in the box, so the value is equal to $ 2 $ .
In the second example it is profitable to put the first two cakes in the first box, and all the rest in the second. There are two distinct types in the first box, and three in the second box then, so the total value is $ 5 $ .
思路
基礎動規
一眼動規(區間dp),設\(\text{dp}_{i,j}\)表示用 \(i\) 個盒子裝前 \(j\) 個蛋糕所能達到的最大價值。區間dp典型思路枚舉決策點(此處決策為最后一個盒子裝哪些蛋糕)。
枚舉最后一個盒子為盒子 \(i\) ,裝的是 \(a_{k+1},a_{k+2},\cdots,a_j\),此刪除此盒子后的狀態為則\(\text{dp}_{i-1,k}\)。
狀態轉移需加上最后一個盒子盒子 \(i\) 的代價,即 \(\text{w}(k + 1,j)\)。(\(\text w\)函數是用來統計給定區間\([l,r]\)里不同數字的個數的函數,即此區間的價值)
\(\text{dp}_{i,j}=\max_{k=i-1}^{j}{\text{dp}_{i-1,k}+\text{w}(k+1,j)}\)
然而! 我們注意到了數據范圍,\(n\in [1,3.5\times10^4]\) ,如果使用暴力循環動規的話時間復雜度為 外層循環和w函數的乘積,即\(n\times k\times n\times n=\text{O}(n^3k)\),顯然爆炸。
優化
顯然時間復雜度O(nk)的循環是 有必要的,那么時間復雜度的第三維就只能是\(\log n\)了
考慮將 \(\text{dp}_{i-1,k}+\text{w}(k+1,j)\) 捆綁,視為一項,這樣就可以用最簡單的線段樹模版(最值線段樹)維護,再\(\text{O}(\log n)\)計算w函數,程序的時間復雜度就降到了\(\text O(nk\log n)\),可以通過。
\(\text O(n^2)\rightarrow\text O(\log n)\)
這里我將盡力使用通俗易懂(或者說精確)的語言解釋各個題解中對于“數字貢獻”的描述,并且盡量提供微量具體的代碼片段,以幫助理解。
我們注意到,計算\(\text w\)函數的過程中,一個數的“影響范圍”顯然是連續的,可以用線段樹整體log維護,而不用逐個O(1)(累計O(n))的復雜度去累加
而這點具體表示為,若一個數 \(a_i\) ,與其兩側相鄰的數為 \(a_x,a_y\) ,則 \(a_i\) 能夠影響到\(\text{w}(L,R)\) 需滿足
我們在枚舉dp第二層(蛋糕個數為j時),可以不斷更新 \(a_x\) 中的下標 \(x\),記為 \(\text{pre}_{a_i}\),將 \(j\) 視為 上述 \(R\) ,這樣就不需要考慮在 \(a_i\) 右側的 \(a_y\)了。枚舉到 \(a_j\) 時,自然想到使用線段樹累加所有的\(\text{w}(l,j),其中l\in\{\text{pre}_{a_i}+1,j\}\)。做完這一步后,立刻更新\(\text{dp}_{i,j}\)為線段樹中\(1,j\)的最值,代表了\(\max_{k=i-1}^{j}{\text{dp}_{i-1,k}+\text{w}(k+1,j)}\),與上式匹配,互相印證,說明我們的方向是正確的。
代碼
#include <bits/stdc++.h>
using namespace std;
const int N = 35005, K = 55;
int n, k, a[N];
int dp[N]; // dp[i][j]: maximum value obtained by ordering first j cakes into i boxes
int pre[N]; // pre[i]: 元素i上次出現的位置
namespace ST
{
#define ls i << 1
#define rs i << 1 | 1
// 注意此處線段樹是為了區間max而服務的,因此并不是說主要思路就是線段樹
struct SegTree
{
struct NODE
{
int l, r; // 節點i表示的區間左右端點,不是左右兒子
int sum; // [l,r]區間max(不用管懶標記)
int lz; // sum的懶標記(寫慣了懶得改max了
} tr[N << 2];
void build(int i, int L, int R)
{
tr[i].l = L;
tr[i].r = R;
tr[i].lz = 0;
if (L == R) {
tr[i].sum = dp[L - 1];
return ;
}
int mid = (L + R) >> 1;
build(ls, L, mid);
build(rs, mid + 1, R);
tr[i].sum = max(tr[ls].sum, tr[rs].sum);
}
void pushdown(int i)
{
if (!tr[i].lz) return ;
tr[ls].sum += tr[i].lz;
tr[ls].lz += tr[i].lz;
tr[rs].sum += tr[i].lz;
tr[rs].lz += tr[i].lz;
tr[i].lz = 0;
}
void add(int i, int L, int R, int x)
{
if (R < tr[i].l || tr[i].r < L) return ;
if (L <= tr[i].l && tr[i].r <= R) {
tr[i].sum += x;
tr[i].lz += x;
return ;
}
pushdown(i);
if (tr[ls].r >= L) add(ls, L, R, x);
if (tr[rs].l <= R) add(rs, L, R, x);
tr[i].sum = max(tr[ls].sum, tr[rs].sum);
}
int query(int i, int L, int R)
{
if (R < tr[i].l || tr[i].r < L) return 0;
if (L <= tr[i].l && tr[i].r <= R) return tr[i].sum;
pushdown(i);
int res = 0;
if (tr[ls].r >= L) res = max(res, query(ls, L, R));
if (tr[rs].l <= R) res = max(res, query(rs, L, R));
return res;
}
};
#undef ls
#undef rs
}
using namespace ST;
SegTree st;
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; ++i) scanf("%d", a + i);
/* 草稿:
// 初始化一個盒子的情況:
// for (int i = 1; i <= n; ++i) {
// dp[1][i] = c(1, i);
// }
// for (int i = 2; i <= k; ++i) { // 枚舉i個盒子
// for (int j = i; j <= n; ++j) { // 枚舉前j個元素
// // 考慮現有階段由少裝一個盒子的階段得到
// // 枚舉最后一個盒子裝的第一個元素,注意這里最后一個盒子至少裝了i-1個元素
// for (int st = i - 1; st < j; ++st) {
// // 轉移方程意義:由上一個狀態得到,再加上價值
// dp[i][j] = max(dp[i][j], dp[i - 1][st] + c(st + 1, j));
// }
// }
// }
*/
for (int i = 1; i <= k; ++i) {
fill(pre, pre + n + 1, 0);
st.build(1, 1, n);
for (int j = 1; j <= n; ++j) {
// 一個數的貢獻影響到它上一次出現的位置到這個下標
// printf("a[%d]: %d, the last time it appeared was %d\n", j, a[j], pre[a[j]]);
st.add(1, pre[a[j]] + 1, j, 1);
pre[a[j]] = j; // 更新此值最后出現的位置
dp[j] = st.query(1, 1, j);
}
// for (int j = 1; j <= n; ++j) {
// printf("w(%d, %d)=%d\n", j, i, st.query(1, j, j));
// }
}
printf("%d\n", dp[n]);
return 0;
}

浙公網安備 33010602011771號