概念
DDP,可以理解為轉移會發生改變的動態規劃。
當然這個改變是題目中給的,包括系數,轉移位置的改變。顯然暴力枚舉這些改變是不現實的,我們要把改變體現到其他地方。
最經典的,體現到矩陣上。
我們把轉移寫成矩陣,那么改變轉移就是改變轉移矩陣。
具體的改變會落實到具體的題目上。
廣義矩陣乘法
因為轉移的多樣性,矩陣乘法不一定需要用一般乘法的乘完相加。在滿足結合律的情況下,可以是乘完取 \(\min\),加完取 \(\max\) 等。
如 CF750E,要刪除最少,轉移中需要取 \(\min\),所以寫成矩陣時,重載乘法就用到了加完取 \(\min\),同時因為其有結合律,其仍舊可以像一般矩陣乘法進行上樹等操作。
線段樹維護
矩陣滿足結合律,可以用線段樹維護。
面對每一位轉移不同的題目或者只需統計區間答案的題目時,使用線段樹維護區間轉移矩陣的積是很必要的。
主要是代碼實現的難度。
struct mat
{
int mat[6][6];
}a,c;
mat operator *(mat a,mat b)
{
mat c;
memset(c.mat,63,sizeof(c.mat));
for(int k=0;k<5;k++)
{
for(int i=0;i<5;i++)
{
for(int z=0;z<5;z++)
{
c.mat[i][z]=min(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
mat mul(mat a,mat b)
{
mat c;
memset(c.mat,63,sizeof(c.mat));
for(int k=0;k<5;k++)
{
for(int i=0;i<1;i++)
{
for(int z=0;z<5;z++)
{
c.mat[i][z]=min(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
int n,m,q,rt,w[200001];
mat sum[800001],inn;
void add(int o,int l,int r,int x,mat y)
{
if(l==r)
{
sum[o]=y;
return;
}
int mid=r+l>>1;
if(x<=mid) add((o<<1),l,mid,x,y);
else add((o<<1)+1,mid+1,r,x,y);
sum[o]=sum[(o<<1)]*sum[(o<<1)+1];
}
mat get(int o,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return sum[o];
int mid=l+r>>1;
if(mid>=y)
{
return get(o<<1,l,mid,x,y);
}
if(x>mid)
{
return get((o<<1)+1,mid+1,r,x,y);
}
return get(o<<1,l,mid,x,y)*get((o<<1)+1,mid+1,r,x,y);
}
解決樹上DDP問題
使用樹鏈剖分把樹斷為鏈,重鏈內是序列問題可以自己解決。而重鏈之間的轉移成為難點。
我們稱一個重鏈頂與他的父親組成一個卡口。改變一個點的值后,所有他到父親的卡口值會改變。體現輕重鏈,我們設 \(g_u\) 為只與 \(u\) 親兒子有關的轉移,\(f_{uw}\) 為 \(u\) 的重兒子的 \(DP\) 值,我們必須把 \(f_u\) 轉移寫成只與 \(g_u\) 和 \(f_{uw}\) 有關的式子。
為什么呢?
保證時間復雜度,因為每個重鏈內是序列問題,它是不用改變的,而到了卡口,\(g\) 值會變。若和其他 \(f\) 有關,那么改變一個點的值將導致他到根的所有 \(f\) 值改變,因為他們的轉移都依賴于此。
模板題
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
struct mat
{
int mat[2][2];
}gg[100001];
mat operator *(mat a,mat b)
{
mat c;
for(int i=0;i<2;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=-100000000;
}
}
for(int k=0;k<2;k++)
{
for(int i=0;i<2;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=max(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
mat mul(mat a,mat b)
{
mat c;
for(int i=0;i<2;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=-100000000;
}
}
for(int k=0;k<2;k++)
{
for(int i=0;i<1;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=max(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
int n,m,q,rt,w[200001];
mat sum[800001];
int fat[100001],siz[100001],dep[100001],hson[100001],top[100001],cnt,dfn[100001],dis[100001],f[100001][2],downd[100001];
vector<int> g[1000001];
void add(int o,int l,int r,int x,mat y)
{
if(l==r)
{
sum[o]=y;
return;
}
int mid=r+l>>1;
if(x<=mid) add((o<<1),l,mid,x,y);
else add((o<<1)+1,mid+1,r,x,y);
sum[o]=sum[(o<<1)]*sum[(o<<1)+1];
}
mat get(int o,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return sum[o];
int mid=l+r>>1;
if(mid>=y)
{
return get(o<<1,l,mid,x,y);
}
if(x>mid)
{
return get((o<<1)+1,mid+1,r,x,y);
}
return get(o<<1,l,mid,x,y)*get((o<<1)+1,mid+1,r,x,y);
}
void getdfsh(int u,int fa)
{
fat[u]=fa;
dep[u]=dep[fa]+1;
int lll=0;
f[u][1]=w[u];
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa) continue;
getdfsh(v,u);
if(siz[v]>lll)
{
hson[u]=v;
lll=siz[v];
}
siz[u]+=siz[v];
f[u][1]+=f[v][0];
f[u][0]+=max(f[v][0],f[v][1]);
}
siz[u]++;
}
void gettd(int u,int fa)
{
gg[u].mat[1][0]=w[u];
gg[u].mat[1][1]=-100000000;
dfn[u]=++cnt;
dis[u]=cnt;
if(hson[fat[u]]==u)
{
top[u]=top[fa];
downd[top[u]]=dfn[u];
}
else
{
top[u]=u;
downd[top[u]]=dfn[u];
}
if(hson[u]!=0) gettd(hson[u],u);
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa||v==hson[u]) continue;
gettd(v,u);
gg[u].mat[0][0]+=max(f[v][0],f[v][1]);
gg[u].mat[1][0]+=f[v][0];
}
gg[u].mat[0][1]=gg[u].mat[0][0];
}
void getdis(int u, int fa) {
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if (v==fa) continue;
getdis(v,u);
dis[u]=max(dis[u],dis[v]);
}
}
void update(int x,int val)
{
gg[x].mat[1][0]+=val-w[x];
w[x]=val;
while(x)
{
mat las=get(1,1,n,dfn[top[x]],downd[top[x]]);
add(1,1,n,dfn[x],gg[x]);
mat now=get(1,1,n,dfn[top[x]],downd[top[x]]);
x=fat[top[x]];
gg[x].mat[0][0]+=max(now.mat[0][0],now.mat[1][0])-max(las.mat[0][0],las.mat[1][0]);
gg[x].mat[0][1]=gg[x].mat[0][0];
gg[x].mat[1][0]+=now.mat[0][0]-las.mat[0][0];
}
}
signed main()
{
scanf("%d",&n);
scanf("%d",&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&w[i]);
}
for(int i=1,u,v;i<n;i++)
{
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
getdfsh(1,0);
gettd(1,0);
getdis(1,0);
for(int i=1;i<=n;i++)
{
add(1,1,n,dfn[i],gg[i]);
}
for(int i=1;i<=m;i++)
{
int x,val;
scanf("%d%d",&x,&val);
update(x,val);
mat ans=get(1,1,n,1,downd[1]);
printf("%d\n",max(ans.mat[0][0],ans.mat[1][0]));
}
}
浙公網安備 33010602011771號