題解:AT_abc422_g Balls and Boxes
【模板】生成函數(shù)
Problem 1
無標(biāo)號數(shù)球盒問題,果斷想到普通生成函數(shù)。
對于第一個盒子,如果盒子個數(shù)是 \(A\) 的倍數(shù),就產(chǎn)生 \(1\) 的貢獻。容易構(gòu)造
\[f_A(x) = \sum _{k \isin \N} x^{kA}
\]
同理有
\[f_B(x) = \sum _{k \isin \N} x^{kB}
\]
\[f_C(x) = \sum _{k \isin \N} x^{kC}
\]
我們要求的是劃分 \(n\) 個球的方案數(shù),也就是
\[[x^n]f_A(x)f_B(x)f_C(x)
\]
NTT 即可。
Problem 2
有標(biāo)號數(shù)球盒問題,果斷想到指數(shù)生成函數(shù)。
同上,為了消除標(biāo)號的貢獻,我們要在每一項的系數(shù)上除以一個排列數(shù)。容易構(gòu)造
\[f_A(x) = \sum _{k \isin \N} \frac{x^{kA}}{(kA)!}
\]
\[f_B(x) = \sum _{k \isin \N} \frac{x^{kB}}{(kB)!}
\]
\[f_C(x) = \sum _{k \isin \N} \frac{x^{kC}}{(kC)!}
\]
記得把標(biāo)號乘回來。答案是
\[[\frac{x^n}{n!}]f_A(x)f_B(x)f_C(x)
\]
NTT 即可。
復(fù)雜度 \(O(n \log n)\),代碼沒人會想看的。
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
#define N 2000006
using namespace std;
namespace POLY{ //by dyc2022
const int MOD=998244353,G=3,invg=(MOD+1)/3,img=86583718;
int r[N];
inline int qpow(int x,int y)
{
if(y==0)return 1;
if(y==1)return x%MOD;
int ret=qpow(x,y>>1);
return ret*ret%MOD*qpow(x,y&1)%MOD;
}
inline void NTT(int len,int *a,int opt)
{
for(int i=0;i<len;i++)if(i<r[i])
swap(a[i],a[r[i]]);
for(int i=1;i<len;i<<=1)
{
int tmp=i<<1,Wn=qpow(opt==1?G:invg,(MOD-1)/tmp);
for(int j=0;j<len;j+=tmp)
{
int w=1,x,y;
for(int k=0;k<i;k++,w=w*Wn%MOD)
x=a[j+k],y=w*a[i+j+k]%MOD,a[j+k]=(x+y)%MOD,a[i+j+k]=(x+MOD-y)%MOD;
}
}
if(opt!=1)
{
int invn=qpow(len,MOD-2);
for(int i=0;i<len;i++)a[i]=a[i]*invn%MOD;
}
}
inline void times(int n,int m,int *a,int *b,int *ans)
{
int len=1,lg=0;
while(len<=n+m)len<<=1,lg++;
for(int i=0;i<len;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<lg-1);
NTT(len,a,1),NTT(len,b,1);
for(int i=0;i<=len;i++)a[i]=a[i]*b[i]%MOD;
NTT(len,a,-1);
for(int i=0;i<=n+m;i++)ans[i]=a[i];
}
int c[N];
inline void getinv(int n,int *a,int *b)
{
if(n==1)return b[0]=qpow(a[0],MOD-2),(void)0;
getinv(n+1>>1,a,b);
int len=1,lg=0;
while(len<(n<<1))len<<=1,lg++;
for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));
for(int i=0;i<n;i++)c[i]=a[i];
for(int i=n;i<len;i++)c[i]=0;
NTT(len,c,1),NTT(len,b,1);
for(int i=0;i<len;i++)
b[i]=(2+MOD-c[i]*b[i]%MOD)%MOD*b[i]%MOD;
NTT(len,b,-1);
for(int i=n;i<len;i++)b[i]=0;
memset(c,0,sizeof(c));
}
int f[N],g[N],fr[N],gr[N],invgr[N],dr[N],td[N],tmp[N];
inline void divide(int n,int m,int *tf,int *tg,int *d)
{
for(int i=0;i<n+m<<1;i++)
f[i]=g[i]=fr[i]=gr[i]=invgr[i]=dr[i]=tmp[i]=0;
for(int i=0;i<n;i++)f[i]=tf[i];
for(int i=0;i<m;i++)g[i]=tg[i];
for(int i=0;i<n;i++)fr[n-i-1]=f[i];
for(int i=0;i<m;i++)gr[m-i-1]=g[i];
getinv(n-m+1,gr,invgr),times(n,n-m+1,fr,invgr,dr);
for(int i=n-m;~i;i--)d[n-m-i]=dr[i];
}
inline void modulo(int n,int m,int *tf,int *tg,int *r)
{
for(int i=0;i<n+m<<1;i++)
f[i]=g[i]=fr[i]=gr[i]=invgr[i]=dr[i]=tmp[i]=td[i]=0;
for(int i=0;i<n;i++)f[i]=tf[i];
for(int i=0;i<m;i++)g[i]=tg[i];
for(int i=0;i<n;i++)fr[n-i-1]=f[i];
for(int i=0;i<m;i++)gr[m-i-1]=g[i];
getinv(n-m+1,gr,invgr),times(n,n-m+1,fr,invgr,dr);
for(int i=n-m;~i;i--)td[n-m-i]=dr[i];
times(m,n-m+1,g,td,tmp);
for(int i=0;i<m-1;i++)
r[i]=(f[i]+MOD-tmp[i])%MOD;
}
int *p[N<<2],length[N<<2],ta[N],tb[N];
inline void init(int len,int lg){for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));}
inline void getp(int u,int l,int r,int *a)
{
length[u]=r-l+1,p[u]=new int[length[u]+1];
if(l==r)
return p[u][0]=(MOD-a[l]),p[u][1]=1,(void)0;
int mid=l+r>>1;
getp(u<<1,l,mid,a),getp(u<<1|1,mid+1,r,a);
int len=1,lg=0;
while(len<(length[u]+1<<1))len<<=1,lg++;
init(len,lg);
for(int i=0;i<=length[u<<1];i++)ta[i]=p[u<<1][i];
for(int i=length[u<<1]+1;i<len;i++)ta[i]=0;
for(int i=0;i<=length[u<<1|1];i++)tb[i]=p[u<<1|1][i];
for(int i=length[u<<1|1]+1;i<len;i++)tb[i]=0;
NTT(len,ta,1),NTT(len,tb,1);
for(int i=0;i<len;i++)ta[i]=ta[i]*tb[i]%MOD;
NTT(len,ta,-1);
for(int i=0;i<=length[u];i++)p[u][i]=ta[i];
}
inline void solve(int u,int l,int r,int *a,int *f,int *ans)
{
if(length[u]<=500)
{
int m=length[u]-1;
for(int i=l;i<=r;i++)
for(int j=m;~j;j--)ans[i]=(ans[i]*a[i]+f[j])%MOD;
return;
}
if(l==r)return ans[l]=*f,(void)0;
int mid=l+r>>1,md[length[u]+2]={0};
modulo(length[u],length[u<<1]+1,f,p[u<<1],md);
solve(u<<1,l,mid,a,md,ans);
modulo(length[u],length[u<<1|1]+1,f,p[u<<1|1],md);
solve(u<<1|1,mid+1,r,a,md,ans);
}
inline void evaluation(int n,int m,int *a,int *f,int *ans)
{
getp(1,1,m,a);
if(n>m)modulo(n,m+1,f,p[1],f);
solve(1,1,m,a,f,ans);
}
inline void getdev(int n,int *f,int *g)
{
for(int i=1;i<n;i++)g[i-1]=i*f[i]%MOD;
g[n-1]=0;
}
inline void getinvdev(int n,int *f,int *g)
{
for(int i=1;i<n;i++)
g[i]=f[i-1]*qpow(i,MOD-2)%MOD;
g[0]=0;
}
inline void getln(int n,int *f,int *g)
{
memset(ta,0,sizeof(ta));
memset(tb,0,sizeof(tb));
getdev(n,f,ta),getinv(n,f,tb);
int lg=0,len=1;
while(len<(n<<1))lg++,len<<=1;
init(len,lg),NTT(len,ta,1),NTT(len,tb,1);
for(int i=0;i<len;i++)ta[i]=ta[i]*tb[i]%MOD;
NTT(len,ta,0),getinvdev(n,ta,g);
}
int t[N];
inline void getexp(int n,int *f,int *g)
{
if(n==1)return g[0]=1,(void)0;
getexp(n+1>>1,f,g),getln(n,g,t);
int len=1,lg=0;
while(len<=(n<<1))len<<=1,lg++;
for(int i=1;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)<<lg-1);
memset(ta,0,sizeof(ta));
for(int i=0;i<n;i++)ta[i]=f[i];
for(int i=n;i<len;i++)t[i]=ta[i]=0;
for(int i=0;i<len;i++)ta[i]=(ta[i]-t[i]+MOD)%MOD;
ta[0]++,NTT(len,g,1),NTT(len,ta,1);
for(int i=0;i<len;i++)g[i]=g[i]*ta[i]%MOD;
NTT(len,g,0);
for(int i=n;i<len;i++)g[i]=0;
for(int i=0;i<len;i++)ta[i]=0;
}
int pwtmp[N];
inline void getpow(int n,int *f,int *g,int k)
{
memset(pwtmp,0,sizeof(pwtmp)),getln(n,f,pwtmp);
for(int i=0;i<n;i++)pwtmp[i]=pwtmp[i]*k%MOD;
getexp(n,pwtmp,g);
}
int tx1[N],tx2[N],tx3[N];
inline void getsin(int n,int *f,int *g)
{
memset(tx1,0,sizeof(tx1));
memset(tx2,0,sizeof(tx2));
memset(tx3,0,sizeof(tx3));
for(int i=0;i<n;i++)tx1[i]=f[i]*img%MOD;
getexp(n,tx1,tx2),getinv(n,tx2,tx3);
for(int i=0;i<n;i++)g[i]=(tx2[i]-tx3[i]+MOD)%MOD*qpow(img<<1,MOD-2)%MOD;
}
inline void getcos(int n,int *f,int *g)
{
memset(tx1,0,sizeof(tx1));
memset(tx2,0,sizeof(tx2));
memset(tx3,0,sizeof(tx3));
for(int i=0;i<n;i++)tx1[i]=f[i]*img%MOD;
getexp(n,tx1,tx2),getinv(n,tx2,tx3);
for(int i=0;i<n;i++)g[i]=(tx2[i]+tx3[i])%MOD*(MOD+1>>1)%MOD;
}
}
using POLY::MOD;
using POLY::qpow;
int n,a,b,c,fac[N],ifac[N];
int f1[N],f2[N],f3[N],g1[N],g2[N];
void init()
{
fac[0]=1;
for(int i=1;i<N;i++)fac[i]=fac[i-1]*i%MOD;
ifac[N-1]=qpow(fac[N-1],MOD-2);
for(int i=N-2;~i;i--)ifac[i]=ifac[i+1]*(i+1)%MOD;
}
void solve1()
{
for(int i=0;i*a<=n;i++)f1[i*a]=1;
for(int i=0;i*b<=n;i++)f2[i*b]=1;
for(int i=0;i*c<=n;i++)f3[i*c]=1;
POLY::times(n+1,n+1,f1,f2,g1);
POLY::times(n+1,n+1,g1,f3,g2);
printf("%lld\n",g2[n]);
}
void solve2()
{
init();
for(int i=0;i<N;i++)f1[i]=f2[i]=f3[i]=g1[i]=g2[i]=0;
for(int i=0;i*a<=n;i++)f1[i*a]=ifac[i*a];
for(int i=0;i*b<=n;i++)f2[i*b]=ifac[i*b];
for(int i=0;i*c<=n;i++)f3[i*c]=ifac[i*c];
POLY::times(n+1,n+1,f1,f2,g1);
POLY::times(n+1,n+1,g1,f3,g2);
printf("%lld\n",g2[n]*fac[n]%MOD);
}
main()
{
scanf("%lld%lld%lld%lld",&n,&a,&b,&c);
solve1(),solve2();
return 0;
}

浙公網(wǎng)安備 33010602011771號