动态DP(DDP)学习笔记

动态DP

动态DP就是将 \(DP\) 的状态作为一个向量,\(DP\) 的转移写成一个矩阵,因为矩阵乘法的结合律,我们可以用数据结构维护矩阵的积,然后就能够支持单点修改区间查询了。

洛谷P4719 【模板】"动态 DP"&动态树分治

Description

给定一棵 \(n\) 个点的树,点带点权。

有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

Solution

首先考虑一个朴素 \(DP\) ,设 \(dp_{u,0/1}\) 表示只考虑 \(u\) 的子树,\(u\) 没有被选/被选时的最大权独立集,有转移方程:

\[dp_{u,0}=\sum_{v\in son_u}\max(dp_{v,0},dp_{v,1})\\ dp_{u,1}=\sum_{v\in son_u}dp_{v,0} \]

现在进行树链剖分,转移时只考虑重儿子的转移,轻儿子暴力合并上来。

\[g_{u,0}=\sum_{v\in son_u,v\not= hson_u}\max(dp_{v,0},dp_{v,1})\\ g_{u,1}=\sum_{v\in son_u,v\not= hson_u}dp_{v,0} \]

那么

\[dp_{u,0}=g_{u,0}+max(dp_{hson_u,0},dp_{hson_u,1})\\ dp_{u,1}=g_{u,1}+dp_{hson_u,0} \]

这个转移可以写成矩阵形式:

\[\begin{bmatrix} f_{hson_u,0}\\ f_{hson_u,1} \end{bmatrix} \times \begin{bmatrix} g_{u,0}&g_{u,1}\\ g_{u,1}&-\infty \end{bmatrix} = \begin{bmatrix} f_{u,0}\\ f_{u,1} \end{bmatrix} \]

注意这里的矩阵乘法使用的是我们自定义的运算,即:

\[A\times B=C:C_{i,j}=\max_k(A_{i,k}+B_{k,j}) \]

在具体实现过程中,我们先预处理出 \(f,g\),然后进行树链剖分,线段树维护矩阵\(\begin{bmatrix} g_{u,0}&g_{u,1}\\ g_{u,1}&-\infty \end{bmatrix}\)的乘积,查询时直接将路径上每一个链上的部分乘起来即可。

修改时,从下到上修改,\(g\) 矩阵只有轻边的顶点会被修改,在这些位置暴力修改即可。

总复杂度 \(\mathcal O(n\log^2 n)\)。

Code

#include<bits/stdc++.h>
using namespace std;
const int K=2;
const int N=1e5+10;
int n,m,a[N];
struct edge{
	int v,nxt;
}e[N<<2];
int first[N],cnt,siz[N],hson[N],top[N],pos[N],f[N][2],g[N][2],tot,fa[N],bot[N];
inline void add(int u,int v){e[++cnt]=(edge){v,first[u]};first[u]=cnt;}
//--------------Matrix----------------
struct mat{
	int c[K][K];
	mat(){memset(c,-0x3f,sizeof(c));}
	inline int& operator ()(int x,int y){return c[x][y];}
};
inline mat operator *(mat a,mat b){
	mat ret;
	for(int i=0;i<K;++i)
		for(int j=0;j<K;++j)
			for(int k=0;k<K;++k) ret(i,j)=max(ret(i,j),a(i,k)+b(k,j));
	return ret;
}

//-------------Segment Tree-----------
int num[N];
namespace SGT{
	#define lc (p<<1)
	#define rc (p<<1|1)
	#define mid ((l+r)>>1)
	mat tr[N<<2];
	inline void build(int p,int l,int r){
		if(l==r){
			int u=num[l];
			if(!hson[u]){tr[p](0,0)=f[u][0];tr[p](1,0)=g[u][1];return ;}
			tr[p](0,0)=tr[p](0,1)=g[u][0];
			tr[p](1,0)=g[u][1];
			return ;
		}
		build(lc,l,mid);build(rc,mid+1,r);
		tr[p]=tr[lc]*tr[rc];
	}
	inline void update(int p,int x,int l,int r){
		if(l==r){
			int u=num[x];
			if(!hson[u]){tr[p](0,0)=f[u][0];tr[p](1,0)=g[u][1];return ;}
			tr[p](0,0)=tr[p](0,1)=g[u][0];
			tr[p](1,0)=g[u][1];
			return ;
		}
		(x<=mid)?update(lc,x,l,mid):update(rc,x,mid+1,r);
		tr[p]=tr[lc]*tr[rc]; 
	}
	inline mat query(int p,int ql,int qr,int l,int r){
		if(ql<=l&&r<=qr) return tr[p];
		if(ql>mid) return query(rc,ql,qr,mid+1,r);
		else if(qr<=mid) return query(lc,ql,qr,l,mid);
		else return query(lc,ql,qr,l,mid)*query(rc,ql,qr,mid+1,r);
	}
}
using namespace SGT;

//------------Tree Chain partition----
inline void dfs1(int u){
	f[u][0]=0;f[u][1]=a[u];
	siz[u]=1;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v^fa[u]){
			fa[v]=u;dfs1(v);siz[u]+=siz[v];
			if(siz[v]>siz[hson[u]]) hson[u]=v;
			f[u][0]+=max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
		}
	}
	g[u][0]=f[u][0]-max(f[hson[u]][0],f[hson[u]][1]);g[u][1]=f[u][1]-f[hson[u]][0];
}
inline void dfs2(int u,int tp){
	top[u]=tp;pos[u]=++tot;num[tot]=u;
	if(hson[u]) dfs2(hson[u],tp);
	else bot[tp]=u;
	for(int i=first[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v!=fa[u]&&v!=hson[u]) dfs2(v,v);
	}
}
inline pair<int,int> que(int x){
	int b=bot[top[x]];
	mat m=query(1,pos[x],pos[b],1,n);
	return make_pair(m(0,0),m(1,0));
}
inline void modify(int x,int y){
	g[x][1]-=a[x];a[x]=y;g[x][1]+=a[x];
	update(1,pos[x],1,n);x=top[x];
	while(fa[x]){
		g[fa[x]][0]-=max(f[x][0],f[x][1]);g[fa[x]][1]-=f[x][0];
		pair<int,int> ff=que(x);
		f[x][0]=ff.first;f[x][1]=ff.second;
		g[fa[x]][0]+=max(f[x][0],f[x][1]);g[fa[x]][1]+=f[x][0];
		update(1,pos[fa[x]],1,n);
		x=top[fa[x]];
	}
}

//-------------main-------------
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;++i) scanf("%d",&a[i]);
	for(int i=1,u,v;i<n;++i){
		scanf("%d%d",&u,&v);
		add(u,v);add(v,u);
	}
	dfs1(1);dfs2(1,1);
	build(1,1,n);
	for(int i=1,x,y;i<=m;++i){
		scanf("%d%d",&x,&y);
		modify(x,y);
		pair<int,int> p=que(1);
		printf("%d\n",max(p.first,p.second));
	}
	return 0;
}

CF750E New Year and Old Subsequence

Description

给定一长为 \(n\) 的字符串,\(q\) 次区间询问其至少删除多少个字符才能让其包
含子序列 2017 但不包含子序列 2016。\(n, q ≤ 2 \times 10^5\)。

Solution

首先列出朴素 \(DP\) ,设 \(dp_{u,0\sim 4}\) 表示考虑了字符串的前 \(u\) 位,当前保留的字符串的末尾是 \(\varnothing,2,20,201,2017\) 时最少需要删去多少字符。

于是大力列出转移方程:

\[dp_{i,0}=dp_{i-1,0}+[s_i=2]\\ dp_{i,1}=\min(dp_{i-1,1}+[s_i=0],dp_{i-1,0}[s_i=2])\\ dp_{i,2}=\min(dp_{i-1,2}+[s_i=1],dp_{i-1,1}[s_i=0])\\ dp_{i,3}=\min(dp_{i-1,2}+[s_i=6/7],dp_{i-1,2}[s_i=1])\\ dp_{i,4}=\min(dp_{i-1,4}+[s_i=6],dp_{i-1,3}[s_i=7]) \]

于是基于此列出一个基于 \((min,+)\) 的矩阵乘法,直接线段树维护即可。

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
const int inf=0x3f3f3f3f;
struct matrix{
	int c[5][5];
	inline void init(){for(int i=0;i<5;++i)for(int j=0;j<5;++j)c[i][j]=inf;}
};
matrix operator *(matrix A,matrix B){
	matrix ret;
	for(int i=0;i<5;++i){
		for(int j=0;j<5;++j){
			ret.c[i][j]=inf;
			for(int k=0;k<5;++k) ret.c[i][j]=min(ret.c[i][j],A.c[i][k]+B.c[k][j]);
		}
	}
	return ret;
}
int n,q;
char s[N];
namespace SGT{
	const int M=N<<2;
	matrix tr[M];
	#define lc (p<<1)
	#define rc (p<<1|1)
	#define mid ((l+r)>>1)
	inline void init(int p,int l,int r){
		if(l==r){
			tr[p].init();
			tr[p].c[0][0]=tr[p].c[1][1]=tr[p].c[2][2]=tr[p].c[3][3]=tr[p].c[4][4]=0;
			if(s[l]=='2') tr[p].c[0][0]=1,tr[p].c[0][1]=0;
			else if(s[l]=='0') tr[p].c[1][1]=1,tr[p].c[1][2]=0;
			else if(s[l]=='1') tr[p].c[2][2]=1,tr[p].c[2][3]=0;
			else if(s[l]=='7') tr[p].c[3][3]=1,tr[p].c[3][4]=0;
			else if(s[l]=='6') tr[p].c[3][3]=tr[p].c[4][4]=1;
			return ;
		}
		init(lc,l,mid);init(rc,mid+1,r);
		tr[p]=tr[lc]*tr[rc];
	}
	inline matrix query(int p,int l,int r,int ql,int qr){
		if(ql<=l&&r<=qr) return tr[p];
		if(ql>mid) return query(rc,mid+1,r,ql,qr);
		if(qr<=mid) return query(lc,l,mid,ql,qr);
		return query(lc,l,mid,ql,qr)*query(rc,mid+1,r,ql,qr);
	}
}
int main(){
	scanf("%d%d",&n,&q);
	scanf("%s",s+1);
	SGT::init(1,1,n);
	for(int i=1,l,r;i<=q;++i){
		scanf("%d%d",&l,&r);
		int x=SGT::query(1,1,n,l,r).c[0][4];
		printf("%d\n",x==inf?-1:x);
	}
	return 0;
}

洛谷P4428 [BJOI2018]二进制

Description

给定一个长为 \(n\) 的 \(01\) 串,\(q\) 次询问 区间 \([l,r]\) 中有多少个位置不同的连续子串满足可以在重新排列后变成一个 \(3\) 的倍数。支持单点修改。\(n,q\le 10^5\)。

Solution

首先,在四进制下,一个数的各位数字之和是 \(3\) 的倍数也意味着这个数是 \(3\) 的倍数,因此,只要二进制下一个子串能够在每一位的数两两配对后让组成的新数之和变为 \(3\) 的倍数即可。

首先,如果有偶数个 \(1\),那么直接所有 \(1\) 配对形成 \(11_{(2)}=3\),总和一定是 \(3\) 的倍数。

有奇数个 \(1\) 时,如果只有一个 \(1\),无论如何都不行。否则,注意到\((01)_{2}+(10)_2=3\),因此只要存在 \(2\) 个 \(0\) 与 \(2\) 个 \(1\) 进行配对,其他的 \(1\) 两两配对,总和也是 \(3\) 的倍数。

因此无法满足条件的只有一下两种情况:

  • 只有 \(1\) 个 \(1\)
  • 有奇数个 \(1\),且 \(0\) 的数量不超过 \(1\)

这两者之间有重复情况,那就是为 \(1\) 或 \(01\) 或 \(10\) 的子串,这样的子串个数是容易用树状数组求出的。

对于只有 \(1\) 个 \(1\) 的情况,定义 \(f_{i,0/1}\) 表示以 \(i\) 作为右端点,满足 \(1\) 的数量为 \(0/1\) 的左端点有多少个,容易写出转移矩阵,用线段树维护。

对于有奇数个 \(1\) 的情况,定义 \(g_{i,0/1.0/1}\),表示以 \(i\) 为右端点,满足 \(1\) 的数量为偶数/奇数,\(0\) 的数量为 \(0/1\) 的左端点有多少个,同样写出转移矩阵,用线段树维护。

还有一个问题,我们求的实际上是一段 \(dp\) 值的前缀和,因此在定义一个 \(s_i\) 表示以 \([l,i]\) 作为右端点时的合法子串数量,第一部分\(s_i=s_{i-1}+f_{i,1}\),第二部分 \(s_i=s_{i-1}+g_{i,1,0/1}\),然后将 \(f\) 与 \(g\) 的转移带进去,最终依然能写成矩阵的形式,用线段树一起维护即可。

Code

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
const int M=N<<2;
typedef long long ll;
const int inf=0x3f3f3f3f;
struct matrix{
	ll c[8][8];
	int n,m;
	inline void init(int _n,int _m){
		n=_n;m=_m;
		for(int i=0;i<n;++i) for(int j=0;j<m;++j) c[i][j]=0;
	}
};
matrix operator *(matrix A,matrix B){
	matrix ret;ret.init(A.n,B.m);
	for(int i=0;i<A.n;++i){
		for(int j=0;j<B.m;++j){
			for(int k=0;k<A.m;++k) ret.c[i][j]+=A.c[i][k]*B.c[k][j];
		}
	}
	return ret;
}
int n,q,s[N];
struct SGT{
	matrix tr[M];
	#define lc (p<<1)
	#define rc (p<<1|1)
	#define mid ((l+r)>>1)
	inline void init(int p,int tp,int l){
		if(!tp){
			tr[p].init(4,4);
			if(s[l]==0) tr[p].c[0][0]=tr[p].c[0][2]=tr[p].c[1][1]=1,tr[p].c[3][1]=1;
			else tr[p].c[1][0]=tr[p].c[1][2]=1,tr[p].c[3][0]=tr[p].c[3][2]=1;
			tr[p].c[3][3]=tr[p].c[2][2]=1;
		}
		else{
			tr[p].init(6,6);
			if(s[l]==0) tr[p].c[1][0]=tr[p].c[3][2]=tr[p].c[3][4]=1,tr[p].c[5][0]=1;
			else tr[p].c[0][2]=tr[p].c[1][3]=tr[p].c[2][0]=tr[p].c[3][1]=tr[p].c[0][4]=1,tr[p].c[5][2]=tr[p].c[5][3]=tr[p].c[5][4]=1;
			tr[p].c[5][5]=tr[p].c[4][4]=1;
		}
	}
	inline void init(int p,int l,int r,int tp){
		if(l==r){
			init(p,tp,l);
			return ;
		}
		init(lc,l,mid,tp);init(rc,mid+1,r,tp);
		tr[p]=tr[lc]*tr[rc];
	}
	inline void update(int p,int l,int r,int x,int tp){
		if(l==r){
			init(p,tp,l);
			return ;
		}
		if(x<=mid) update(lc,l,mid,x,tp);
		else update(rc,mid+1,r,x,tp);
		tr[p]=tr[lc]*tr[rc];
	}
	inline matrix query(int p,int l,int r,int ql,int qr){
		if(ql<=l&&r<=qr) return tr[p];
		if(ql>mid) return query(rc,mid+1,r,ql,qr);
		if(qr<=mid) return query(lc,l,mid,ql,qr);
		return query(lc,l,mid,ql,qr)*query(rc,mid+1,r,ql,qr);
	}
}A,B;
struct BIT{
	int c[N];
	inline int lowbit(int x){return x&(-x);}
	inline void update(int x,int v){
		for(;x<=n;x+=lowbit(x)) c[x]+=v;
	}
	inline int query(int x){
		int ans=0;
		for(;x;x-=lowbit(x)) ans+=c[x];
		return ans;
	}
}s1,s2,s3;
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;++i){
		scanf("%d",&s[i]);
		if(s[i]) s1.update(i,1);
		if(s[i]==1&&s[i-1]==0&&i!=1) s2.update(i,1);
		if(s[i]==0&&s[i-1]==1) s3.update(i,1);
	}
	scanf("%d",&q);
	A.init(1,1,n,0);B.init(1,1,n,1);
	for(int i=1,op,l,r;i<=q;++i){
		scanf("%d",&op);
		if(op==1){
			int x;scanf("%d",&x);
			if(s[x]) s1.update(x,-1);
			if(s[x]==1&&s[x-1]==0&&x!=1) s2.update(x,-1);
			if(s[x]==0&&s[x-1]==1) s3.update(x,-1);
			if(s[x+1]==1&&s[x]==0&&x<n) s2.update(x+1,-1);
			if(s[x+1]==0&&s[x]==1&&x<n) s3.update(x+1,-1);
			s[x]^=1;
			if(s[x]) s1.update(x,1);
			if(s[x]==1&&s[x-1]==0&&x!=1) s2.update(x,1);
			if(s[x]==0&&s[x-1]==1) s3.update(x,1);
			if(s[x+1]==1&&s[x]==0&&x<n) s2.update(x+1,1);
			if(s[x+1]==0&&s[x]==1&&x<n) s3.update(x+1,1);
			
			A.update(1,1,n,x,0);
			B.update(1,1,n,x,1);
		}
		if(op==2){
			scanf("%d%d",&l,&r);
			matrix m=A.query(1,1,n,l,r);
			int len=r-l+1;
			ll ans=m.c[3][2];//I
			matrix t=B.query(1,1,n,l,r);
			ll now=ans;
			ans+=t.c[5][4];//II
			ans-=s1.query(r)-s1.query(l-1);
			ans-=s2.query(r)-s2.query(l);
			ans-=s3.query(r)-s3.query(l);
			printf("%lld\n",1ll*len*(len+1)/2-ans);
		}
	}
	return 0;
}
上一篇:String s = new String(“java“) 到底创建了几个对象


下一篇:String#intern结果对比源码测试