根号分治——平衡规划思想的应用

wjyppm
发布于 2025-07-16 / 2 阅读
0
0

1. 介绍与引入

没有前言,懒得写了。

根号分治,本质是平衡规划思想(大纲 9 级),在预处理和询问复杂度中寻找平衡,我们通常用根号作为问题规模的分界线。我们确定一个界限 B,小于 B 的暴力预处理,大于的回答一次时间只需要 \dfrac{n}{B}\le \sqrt{n} ,那么整个题目就可以做到 O(n \sqrt{n})

根号平衡思想,是平衡规划思想中重要的内容,例如空间平衡,时间平衡,根号滚动数组,都可以用这种思想。

我们以一道例题引入:CF797E Array Queries

这种操作我们发现没有什么很好的性质来维护,因为 a_{p} 和变化的 p 是有关的。这两个关系是相互制约的,如果我们只关注一个肯定是不行的。怎么办?

首先我们先想暴力,我们有两种想法:

  • 预处理所有 p,k 的答案。
  • 暴力模拟。

第一个虽然可以 O(1) 回答但是预处理时间空间复杂度 O(nk) 无法承受,而暴力算法时间复杂度一次是 O(\dfrac{n}{k}),是无法承受的。

我们怎么平衡这一算法呢,通过基本不等式 k+\dfrac{n}{k}\ge 2\sqrt{b\times \dfrac{n}{k}}=2\sqrt{n}, 当 k=\sqrt{n} 是取等号。也就是我们当 k\le \sqrt{n},我们可以使用预处理的答案,空间是 O(n\sqrt{n}),而 k>\sqrt{n} 的时候我们暴力模拟即可。故时间空间复杂度为 O(n\sqrt{n})

#include<bits/stdc++.h>
using namespace std;
constexpr int MN=1e5+15,MB=300;
int n,m,a[MN],f[MB+15][MN];

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=1;i<=MB;i++){
        for(int j=n;j>=1;j--){
            f[i][j]=(j+a[j]+i>n)?1:f[i][j+a[j]+i]+1;
        }
    }
    cin>>m;
    for(int i=1;i<=m;i++){
        int p,k,ans=0;
        cin>>p>>k;
        if(k>=MB){
            while(p<=n){
                ans++;
                p+=a[p]+k;
            }
        }else ans=f[k][p];
        cout<<ans<<'\n';
    }
    return 0;
}

我们简单回顾一下这道题,我们通过将两种看似暴力的算法结合起来,形成了一个时间复杂度为 O(n\sqrt{n}) 的高效的算法,相互制约的关系我们很难用常规的数据结构进行处理,因为只关心一个的话另一个就会制约你的复杂度。这个时候我们利用平衡规划思想,我们通过制约关系设计出两种算法:O(nk)O(\dfrac{n^2}{k}) 的算法,但是我们通过基本不等式算出分界点,通过这个分界点来进行所谓的 “分治”,数据小的跑预处理,大的进行暴力。

这一类思想,就是根号分治的思想,平衡规划。而一般制约关系,或同时涉及两个集合的关系,如果没有特殊性质,基本不能 \text{polylog} 去做,但是我们通过根号分治就可以做。

接下来我们来看几道例题:

2. 例题

CF1039D You Are Given a Tree

先考虑 O(n^2) 的情况下怎么做,也就是说我们枚举 k,让后一次查询必须是 O(n)。有一个想法就是暴力贪心,从叶子向根合并,搜到长为 k 的链直接取,证明考虑反证法即可。

让后考虑如何优化,由样例手摸不难发现几个特性:

  • 答案随 k 的增大单调不升。
  • 答案不超过 \lfloor \dfrac{n}{k} \rfloor

答案不超过 \lfloor \dfrac{n}{k} \rfloor?那么我们能不能从这里下手呢?由基本不等式可以得到答案不超过 2\sqrt{n},不难想到 2\sqrt{n}< n,那么也就是说答案是重复的,进一步推广,当 k>\sqrt{n} 的时候答案取值是很少的,而 k\le \sqrt{n} 的答案取值是较多的。哎,一会多,一会少,平衡规划?出动!我们对 k\le \sqrt{n} 直接做是 O(n\sqrt{n})。考虑 k>\sqrt{n} 的答案是连续是连续的一段,并且答案具有优秀的单调不升的特性,我们可以通过二分找出下一个答案取值的区间,我们只需要跑 O(\sqrt{n}) 次就可以了,时间复杂度是 O(n\sqrt{n} \log n)

你说得对,但是我学过基本不等式,上面的操作都是假设块长为 \sqrt{n},我们考虑基本不等式,设分治阈值为 B,那么第一块是 nB,第二块是 \dfrac{n^2 \log n}{B},由基本不等式有 B=\sqrt{n\log n} 时取最优时间。

代码如下:

#include<bits/stdc++.h>
using namespace std;
constexpr int MN=1e6+15;
int n,bl,fa[MN],ans[MN],f[MN],dfn[MN],dtot;
vector<int> adj[MN];

void dfs(int u,int pre){
    fa[u]=pre;
    for(auto v:adj[u]){
        if(v==pre) continue;
        dfs(v,u);
    }
    dfn[++dtot]=u;
}

int solve(int k){
    int ret=0;
    for(int i=1;i<=n;i++) f[i]=1;
    for(int i=1;i<=n;i++){
        int u=dfn[i],pre=fa[u];
        if(pre&&f[u]!=-1&&f[pre]!=-1){
            if(f[u]+f[pre]>=k){
                ret++;
                f[pre]=-1;
            }else f[pre]=max(f[pre],f[u]+1);
        }
    }
    return ret;
}

int main(){
    cin>>n;
    bl=sqrt(n*__lg(n));
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1,0);
    ans[1]=n;
    for(int i=2;i<=bl;i++){
        ans[i]=solve(i);
    }
    for(int i=bl+1;i<=n;i++){
        int tmp=solve(i);
        int l=i,r=n,cnt=i;
        while(l+1<r){
            int mid=(l+r)>>1;
            if(solve(mid)==tmp){
                cnt=max(cnt,mid);
                l=mid;
            }else r=mid;
        }
        for(;i<=cnt;i++) ans[i]=tmp;
        i--;
    }
    for(int i=1;i<=n;i++) cout<<ans[i]<<'\n';
    return 0;
}

[IOI2009] Regions

完蛋啦,又是制约关系,同时涉及两个集合的关系,并且颜色点数与总颜色数相互制约如果我们都考虑显然是不行的。根据我们前面提到的,我们考虑一下根号分治如何去做。

有两个暴力的想法:

  • 分别枚举两个颜色中所有点,利用 DFN 判定是不是在子树内。
  • 将一个颜色所有点加入数据结构,让后枚举另一个颜色所有点,看有多少点在当前子树。

第一个想法时间复杂度 O(n^2),第二个空间复杂度时 O(n^2) 但是时间很好。并且我们前面提到了制约关系,考虑根号分治,我们确定一个阈值 B,颜色点数 >B 的我们称之为重颜色,而 \le B 的我们称之为轻颜色(对应重儿子和轻儿子 www),考虑分类讨论。

  • 重颜色作为祖先节点:考虑预处理答案,时间复杂度容易做到 O(n\sqrt{n \log n}),空间 O(n\sqrt{n})
  • 轻颜色作为祖先节点:枚举轻颜色所有点,考虑对于每一个颜色开 vector 按 DFN 将所有点排序,根据 DFN 顺序判断是否在子树即可,利用二分找边界即可。时间复杂度 O(n\sqrt{n \log n})

故总时间复杂度 O(n\sqrt{n \log n}),空间复杂度 O(n\sqrt{n})

至于为什么我用 DFN,答案是因为一开始我想的虚树,这和我说的加数据结构其实差不太多,因为虚树本身也算一种数据结构吗 www。当然每一次建虚树不如用 DFN 好写啦。

#include<bits/stdc++.h>
using namespace std;
constexpr int MN=2e5+15,ML=30,MK=2.5e4+15;
int n,r,q,ccnt[MN],fid[MN],id[MN],cf[MN],MB;
int ans[520][MK];
vector<int> adj[MN],col[MN],dcol[MN];

namespace Tree{
    int dfn[MN],siz[MN],dtot;

    void dfs(int u,int pre){
        dfn[u]=++dtot;
        siz[u]=1;
        for(auto v:adj[u]){
            if(v==pre) continue;
            dfs(v,u);
            siz[u]+=siz[v];
        }
    }

    int cmpdfn(int x,int y){
        return dfn[x]<dfn[y];
    }
}using namespace Tree;

bool cmp(int x,int y){
    return ccnt[x]>ccnt[y];
}

signed main(){
    ios::sync_with_stdio(0);
    cin>>n>>r>>q;
    MB=sqrt(n*__lg(n)*2);
    for(int i=1;i<=n;i++){
        int fa,color;
        if(i!=1){
            cin>>fa;
            adj[fa].push_back(i);
        }
        cin>>color;
        col[color].push_back(i);
        ccnt[color]++;
    }
    for(int i=1;i<=r;i++) id[i]=i;
    sort(id+1,id+1+r,cmp);
    dfs(1,0);
    for(int i=1;i<=r;i++){
        fid[id[i]]=i;
        for(auto p:col[i]) dcol[i].push_back(dfn[p]);
        sort(dcol[i].begin(),dcol[i].end());
    }
    for(int i=1;i<=r&&ccnt[id[i]]>=MB;i++){
        for(int i=1;i<=n+1;i++) cf[i]=0;
        for(auto p:col[id[i]]){
            cf[dfn[p]]++;
            cf[dfn[p]+siz[p]]--;
        }
        for(int i=1;i<=n+1;i++){
            cf[i]+=cf[i-1];
        }
        for(int j=1;j<=r;j++){
            for(auto p:col[j]){
                ans[i][j]+=cf[dfn[p]];
            }
        }
    }
    while(q--){
        int x,y;
        cin>>x>>y;
        if(ccnt[x]<MB){
            long long ret=0;
            for(auto p:col[x]){
                ret+=upper_bound(dcol[y].begin(),dcol[y].end(),dfn[p]+siz[p]-1)-lower_bound(dcol[y].begin(),dcol[y].end(),dfn[p]);
            }
            cout<<ret<<endl;
        }else cout<<ans[fid[x]][y]<<endl;
    }
    
    return 0;
}

[JRKSJ R2] 你的名字

这种复杂的取模操作,我们一般会利用根号分治来解决这一类问题。

模数不是给定的,这种情况下很难搞,因为我们直接维护模数不固定的数据时很难搞的。

首先,静态区间询问考虑莫队,注意到给出了一个 k>10^3 的包。考虑在 k 很大的情况下取模所构成的循环节长度很长,并且值域只有 10^5,我们可以通过利用 bitset 来暴力跑循环节,时间复杂度为 O(n\sqrt{m}+\dfrac{ma}{\omega}+\dfrac{ma}{k}),实现用 find_first 可以偷懒循环节。

但是 k 很小不能这么做,但是我们发现 k 很小的时候是一个取模数列的 RMQ 啊,可以暴力预处理 k 做四毛子(用传统 ST 表会炸杠)。

但是过不了,考虑若 k 的询问数量很小,我们还不如和 k\ge B 的询问一起处理,这样省下暴力计算的时间。我们可以通过确定一些平衡因子让其自适应数据,这样就能卡过了 www,具体如何操作可以看 meyi 的题解

#include<bits/stdc++.h>
using namespace std;
constexpr int MN=3e5+15,MB=1e5+15;
struct Query{
    int l,r,K,id;
};
int n,m,bl,a[MN],b[MN],ans[MN],cnt[MN];
bitset<MB> f;
vector<Query> qry[MN];

namespace ly
{
    namespace IO
    {
        #ifndef LOCAL
            constexpr auto maxn=1<<20;
            char in[maxn],out[maxn],*p1=in,*p2=in,*p3=out;
            #define getchar() (p1==p2&&(p2=(p1=in)+fread(in,1,maxn,stdin),p1==p2)?EOF:*p1++)
            #define flush() (fwrite(out,1,p3-out,stdout))
            #define putchar(x) (p3==out+maxn&&(flush(),p3=out),*p3++=(x))
            class Flush{public:~Flush(){flush();}}_;
        #endif
        namespace usr
        {
            template<typename type>
            inline type read(type &x)
            {
                x=0;bool flag(0);char ch=getchar();
                while(!isdigit(ch)) flag^=ch=='-',ch=getchar();
                while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
                return flag?x=-x:x;
            }
            template<typename type>
            inline void write(type x)
            {
                x<0?x=-x,putchar('-'):0;
                static short Stack[50],top(0);
                do Stack[++top]=x%10,x/=10;while(x);
                while(top) putchar(Stack[top--]|48);
            }
            inline char read(char &x){do x=getchar();while(isspace(x));return x;}
            inline char write(const char &x){return putchar(x);}
            inline void read(char *x){static char ch;read(ch);do *(x++)=ch;while(!isspace(ch=getchar())&&~ch);}
            template<typename type>inline void write(type *x){while(*x)putchar(*(x++));}
            inline void read(string &x){static char ch;read(ch),x.clear();do x+=ch;while(!isspace(ch=getchar())&&~ch);}
            inline void write(const string &x){for(int i=0,len=x.length();i<len;++i)putchar(x[i]);}
            template<typename type,typename...T>inline void read(type &x,T&...y){read(x),read(y...);}
            template<typename type,typename...T>
            inline void write(const type &x,const T&...y){write(x),putchar(' '),write(y...),sizeof...(y)^1?0:putchar('\n');}
            template<typename type>
            inline void put(const type &x,bool flag=1){write(x),flag?putchar('\n'):putchar(' ');}
        }
        #ifndef LOCAL
            #undef getchar
            #undef flush
            #undef putchar
        #endif
    }using namespace IO::usr;
}using namespace ly::IO::usr;

bool mdcmp(Query x,Query y){
    if(x.l/bl==y.l/bl){
        if((x.l/bl)&1) return x.r<y.r;
        return x.r>y.r;
    }
    return x.l/bl<y.l/bl;
}

signed main(){
    read(n,m);
    for(int i=1;i<=n;i++){
        read(a[i]);
    }
    for(int i=1;i<=m;i++){
        int l,r,k;
        read(l,r,k);
        qry[k].push_back({l,r,k,i});
    }
    for(int i=2;i<MN;i++){
        if(qry[i].empty()) continue;
        if(min(MB/i,MB>>6)*qry[i].size()<(n<<2)){
            qry[0].insert(qry[0].end(),qry[i].begin(),qry[i].end());
            continue;
        }
        bl=n/sqrt(qry[i].size()+1)+1;
        for(int j=1;j<=n;j++){
            b[j]=a[j]%i;
        }
        sort(qry[i].begin(),qry[i].end(),mdcmp);
        int l=1,r=0;
        for(auto p:qry[i]){
			while(l>p.l) (!cnt[b[--l]]++)&&(f[b[l]]=1);
			while(r<p.r) (!cnt[b[++r]]++)&&(f[b[r]]=1);
			while(l<p.l) (!--cnt[b[l]])&&(f[b[l]]=0),++l;
			while(r>p.r) (!--cnt[b[r]])&&(f[b[r]]=0),--r;
            for(int k=0;k<i;k++){
                if(f[k]){
                    ans[p.id]=k;
                    break;
                }
            }
        }
        f.reset();
        memset(cnt,0,sizeof(cnt));
    }
    if(!qry[0].empty()){
        bl=n/sqrt(qry[0].size()+1)+1;
        sort(qry[0].begin(),qry[0].end(),mdcmp);
        int l=1,r=0;
        for(auto p:qry[0]){
			while(l>p.l) (!cnt[a[--l]]++)&&(f[a[l]]=1);
			while(r<p.r) (!cnt[a[++r]]++)&&(f[a[r]]=1);
			while(l<p.l) (!--cnt[a[l]])&&(f[a[l]]=0),++l;
			while(r>p.r) (!--cnt[a[r]])&&(f[a[r]]=0),--r;
            ans[p.id]=1e9;
			for(int k=f._Find_first(); ans[p.id]&&k!=f.size(); k=(k/p.K+1)*p.K-1>=f.size()?f.size():f._Find_next((k/p.K+1)*p.K-1)) 
				(ans[p.id]>k%p.K)&&(ans[p.id]=k%p.K);
        }
    }
    for(int i=1;i<=m;i++) put(ans[i]);
    return 0;
}

[Ynoi2008] rplexq

我们要求的值就是 x 子树内 [l,r] 点个数的平法减去每个儿子子树内 [l,r] 点个数的平方,让后整体除二,即子树两两配对即可。

但是暴力做是 O(n^2) 的,无法接受,考虑如何优化,注意到瓶颈在枚举儿子的子树。我们从儿子下手,根号分治,对于每个节点按照儿子格式分成 >\sqrt{n}\le \sqrt{n} 两组。

小于 \sqrt{n} 的我们可以把区间拆成二维数点来动态加点进行维护,这样的点数是 O(n) 个,询问时 O(m\sqrt{n}),考虑复杂度平衡,我们可以利用分块前缀和的方式进行维护,单点修改时 O(\sqrt{n}),而查询 O(1)。时间复杂度 O((n+m)\sqrt{n}),但空间是 O(m\sqrt{n}) 的。注意到同一个点上每个询问的询问区间相同,我们 O(n) 的存下每个子树区间扫描线,扫描到一点 x 在将询问进行二维数点。

而大于 \sqrt{n} 这么做直接复杂度螺旋式爆炸上天,考虑涉及另一个算法,我们发现在一个区间的点可以类比于颜色,那么问题就是统计 [l,r] 编号内的颜色平方和,考虑对每一个点建立离散化莫队,时间复杂度 O(n\sqrt{m}),最终时间复杂度可以平衡到 O(n\sqrt{n}+n\sqrt{m}+q\sqrt{n})

但是可怕的是我没卡过,54 pt 代码如下:

#include<bits/stdc++.h>
#define ll long long
#define pir pair<int,int>
using namespace std;
constexpr int MN=5e5+100,MB=100,MBL=500;
struct Query{
    int l,r,id;
    ll op;
}tqry[MN];
int n,m,rt,R[MN],tmp[MN],dg[MN],pos[MN],bl;
ll ans1,ans2,ans3[MN],sum[MN],cnt[MN],ans[MN];
bool vis[MN];
vector<int> adj[MN];
vector<Query> qry[MN];

namespace Tree{
    int siz[MN],fa[MN],dfn[MN],dtot;
    pir a[MN];
    
    void dfs1(int u,int pre){
        siz[u]=1;
        fa[u]=pre;
        for(auto v:adj[u]){
            if(v==pre) continue;
            dfs1(v,u);
            siz[u]+=siz[v];
        }
    }

    void dfs2(int u,int pre){
        dfn[++dtot]=u;
        a[dtot]=pir(u,pre);
        for(auto v:adj[u]){
            if(v==pre) continue;
            dfs2(v,u);
        }
    }

}using namespace Tree;

bool cmpsiz(int x,int y){
    return siz[x]>siz[y];
}

bool cmpmd(Query x,Query y){
    if(pos[x.l]!=pos[y.l]) return pos[x.l]<pos[y.l];
    return (pos[x.l]&1)?x.r<y.r:x.r>y.r;
}

void add(int x,ll op){
    ans1+=1ll*1+cnt[x]*2*op;
    cnt[x]+=op;
    ans2+=op;
}

ll query(int x){
    return (x?(R[x]==x?sum[pos[x]]:sum[pos[x]-1]+cnt[x]):0);
}

void update(int x){
    for(int i=pos[x];i<=pos[n];i++){
        sum[i]++;
    }
    for(int i=x;i<=R[x];i++) cnt[i]++;
}

void solve1(int x){
    if(qry[x].empty()) return;
    int tmptot=0;
    for(auto v:adj[x]){
        if(v==fa[x]) continue;
        tmp[++tmptot]=v;
    }
    sort(tmp+1,tmp+1+tmptot,cmpsiz);
    dtot=0;
    for(int i=MB+1;i<=tmptot;i++){
        dfs2(tmp[i],tmp[i]);
        vis[tmp[i]]=1;
    }
    sort(a+1,a+1+dtot);
    sort(dfn+1,dfn+1+dtot);
    for(int i=0;i<qry[x].size();i++){
        int ql=lower_bound(dfn+1,dfn+1+dtot,qry[x][i].l)-dfn;
        int qr=upper_bound(dfn+1,dfn+1+dtot,qry[x][i].r)-dfn-1;
        tqry[i+1]={ql,qr,qry[x][i].id};
    }
    int l=1,r=0,bl=dtot/sqrt(qry[x].size())+1;
    for(int i=1;i<=dtot;i++){
        pos[i]=(i+bl-1)/bl;
    }
    sort(tqry+1,tqry+1+qry[x].size(),cmpmd);
    ans1=ans2=0;
    for(int i=1;i<=qry[x].size();i++){
        if(tqry[i].l>dtot||tqry[i].r<1) continue;
        while(l<tqry[i].l) add(a[l++].second,-1);
        while(l>tqry[i].l) add(a[--l].second,1);
        while(r<tqry[i].r) add(a[++r].second,1);
        while(r>tqry[i].r) add(a[r--].second,-1);
        ans[tqry[i].id]-=ans1;
        ans3[tqry[i].id]+=ans2;
    }
}

void solve2(int x){
    if(!vis[x]&&fa[x]&&!qry[fa[x]].empty()){
        for(int i=0;i<qry[fa[x]].size();i++){
            qry[fa[x]][i].op=query(qry[fa[x]][i].r)-query(qry[fa[x]][i].l-1);
        }
    }
    update(x);
    for(auto v:adj[x]){
        if(v==fa[x]) continue;
        solve2(v);
    }
    if(!vis[x]&&fa[x]&&!qry[fa[x]].empty()){
        for(int i=0;i<qry[fa[x]].size();i++){
            ll qwq=query(qry[fa[x]][i].r)-query(qry[fa[x]][i].l-1)-qry[fa[x]][i].op;
            ans[qry[fa[x]][i].id]-=qwq*qwq;
            ans3[qry[fa[x]][i].id]+=qwq;
        }
    }
    for(int i=0;i<qry[x].size();i++){
        if(qry[x][i].l<=x&&x<=qry[x][i].r){
            ans[qry[x][i].id]+=ans3[qry[x][i].id]*2;
        }
    }
}

signed main(){
    read(n,m,rt);
    for(int i=1;i<n;i++){
        int u,v;
        read(u,v);
        adj[u].push_back(v);
        adj[v].push_back(u);
        dg[u]++,dg[v]++;
    }
    for(int i=1;i<=m;i++){
        int l,r,x;
        read(l,r,x);
        qry[x].push_back({l,r,i,0});
    }
    dfs1(rt,0);
    for(int i=1;i<=n;i++){
        if(i!=rt) dg[i]--;
        if(dg[i]>MB){
            solve1(i);
        }
    }
    for(int i=1;i<=n;i++){
        pos[i]=(i+MBL-1)/MBL;
        R[i]=min(n,pos[i]*MBL);
    }
    memset(cnt,0,sizeof(cnt));
    solve2(rt);
    for(int i=1;i<=m;i++){
        put(((ans3[i]*ans3[i]+ans[i])>>1));
    }
    return 0;
}

P3591 [POI 2015] ODW

序列跳跃问题可以直接对后继的距离根号分治的。

一个显然的想法就是类似于倍增二分去模拟在树上走路(即倍增求树上 k 级祖先,也可以长链剖分做),但是这在 k 很大的情况下时可以的,但是 k 很小的情况是做不到的。具体来说,步数接近于 \dfrac{n}{k} 的范围附近。坏了又是 \dfrac{n}{k},我们考虑根号分治,设定阈值 B>B 当然就是我们树上 k 级祖先暴力跳,时间复杂度 O(n\sqrt{n})。当 k\le B 的时候我们可以考虑暴力处理 sum(i,j) 表示 i 往上每 j 级走一步的答案,查询可以树上差分即可。

时间复杂度 O(n\sqrt{n} \log n),预处理 O(n\sqrt{n})。利用长链剖分可以做到 O(n\log n+n\sqrt{n})

但是我写这篇文章的时候我还不会长链剖分?

#include<bits/stdc++.h>
using namespace std;
constexpr int MN=5e4+15,MB=250;
int n,a[MN],b[MN],c[MN],sum[MN],s[MB+5][MN],fa[32][MN],dep[MN];
vector<int> adj[MN];

void dfs1(int u,int pre){
    dep[u]=dep[pre]+1;
    fa[0][u]=pre;
    sum[u]=sum[pre]+a[u];
    for(int i=1;i<=30;i++){
        fa[i][u]=fa[i-1][fa[i-1][u]];
    }
    for(auto v:adj[u]){
        if(v==pre) continue;
        dfs1(v,u);
    }
}

void dfs2(int u,int pre){
    int p=pre;
    for(int i=2;i<=MB;i++){
        p=fa[0][p];
        if(p==0) break;
        s[i][u]=s[i][p]+a[u];
    }
    for(auto v:adj[u]){
        if(v==pre) continue;
        dfs2(v,u);
    }
}

int lca(int x,int y){
    if(dep[x]>dep[y]){
        swap(x,y);
    }
    for(int i=30;i>=0;i--){
        if(fa[i][y]&&dep[fa[i][y]]>=dep[x]) y=fa[i][y];
    }
    if(x==y) return x;
    for(int i=30;i>=0;i--){
        if(fa[i][x]!=fa[i][y]){
            x=fa[i][x],y=fa[i][y];
        }
    }
    return fa[0][x];
}

int getfa(int x,int k){
    for(int i=30;i>=0;i--){
        if((k>>i)&1) x=fa[i][x];
    }
    return x;
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i=1;i<=n;i++){
        cin>>b[i];
    }
    for(int i=1;i<n;i++) cin>>c[i];
    dfs1(1,0);
    dfs2(1,0);
    for(int i=1;i<n;i++){
        int u=b[i],v=b[i+1],k=c[i];
        int lcaa=lca(u,v);
        if(k==1){
            cout<<sum[u]+sum[v]-sum[lcaa]-sum[fa[0][lcaa]]<<'\n';
        }
        else if(k<=MB){
            int ans=s[k][u],dis=(dep[u]-dep[lcaa])%k;
            if(dis==0) dis=k;
            for(int i=30;i>=0;i--){
                if(fa[i][u]&&dep[fa[i][u]]-dep[lcaa]>=dis) u=fa[i][u];
            }
            ans+=a[u]-s[k][u];
            if(dep[u]+dep[v]-(dep[lcaa]<<1)>=k){
                dis=k-dep[u]+dep[lcaa];
                u=v;
                for(int i=30;i>=0;i--){
                    if(fa[i][u]&&dep[fa[i][u]]-dep[lcaa]>=dis) u=fa[i][u];
                }
                dis=(dep[v]-dep[u])%k;
                if(dis!=0) ans+=a[v];
                v=getfa(v,dis);
                ans+=s[k][v]-s[k][u]+a[u];
            }else ans+=a[v];
            cout<<ans<<'\n';
        }else{
            int ans=0;
            while(dep[u]-dep[lcaa]>k){
                ans+=a[u];
                u=getfa(u,k);
            }
            ans+=a[u];
            if(dep[u]+dep[v]-(dep[lcaa]<<1)>=k){
                int dis=k-dep[u]+dep[lcaa];
                u=v;
                for(int i=30;i>=0;i--){
                    if(fa[i][u]&&dep[fa[i][u]]-dep[lcaa]>=dis) u=fa[i][u];
                }
                dis=(dep[v]-dep[u])%k;
                if(dis!=0) ans+=a[v];
                v=getfa(v,dis);
                while(dep[v]-dep[u]>=k){
                    ans+=a[v];
                    v=getfa(v,k);
                }
                ans+=a[v];
            }else ans+=a[v];
            cout<<ans<<'\n';
        }
    }
    return 0;
}

[Ynoi2011] 初始化

如果你看了论文的话,这个就是第一题的双倍经验。

大多数的数据结构都适合连续区间的询问,但是不擅长这种间隔离散的询问,步数与项数相互制约关系,我们考虑根号分治,设定一个阈值 B,对于 >B 的想法当然是暴力处理啦,但是 \le B 很难受。

分析性质,我们每一次修改都是对整个序列进行修改,对于 x,y 相同的修改我们可以累加贡献,但是我们查询要对所有 x 的都进行一次查询,我们要做到单次 O(1)。考虑我们维护每个 x,y 的前缀后缀和,我们一次询问相当于把序列按 x 分块,我们借用 YLWang 的 P5309 题解的图片:

那么之间完整段会被所有含 x 的操作修改,而两端会被部分修改,询问同理,我们利用前面提到的前缀后缀即可快速维护即可。

#include<bits/stdc++.h>
#define pos(x) ((x-1)/BL+1)
using namespace std;
constexpr int MN=2e5+5,MOD=1e9+7,BL=128;
int n,m,bl,a[MN],sum[MN],L[MN],R[MN],pre[BL+15][BL+15],suf[BL+15][BL+15];

inline void upd(int &x){x+=x>>31&MOD;}

void init(){
    bl=(n-1)/BL+1;
    for(int i=1;i<=bl;i++){
        L[i]=(i-1)*BL+1;
        R[i]=i*BL;
    }
    R[bl]=n;
    for(int i=1;i<=bl;i++){
        sum[i]=0;
        for(int j=L[i];j<=R[i];j++){
            sum[i]+=a[j];
            upd(sum[i]-=MOD);
        }
    }
}

int query(int l,int r){
    int ql=pos(l),qr=pos(r),ret=0;
    if(ql==qr){
        for(int i=l;i<=r;i++){ret+=a[i];upd(ret-=MOD);}
        return ret;
    }
    for(int i=l;i<=R[ql];i++){ret+=a[i];upd(ret-=MOD);}
    for(int i=ql+1;i<qr;i++){ret+=sum[i];upd(ret-=MOD);}
    for(int i=L[qr];i<=r;i++){ret+=a[i];upd(ret-=MOD);}
    return ret;
}

void add(int x,int y,int z){
    z-=MOD;upd(z);
    if(x>=BL){
        for(int i=y;i<=n;i+=x){
            a[i]+=z;upd(a[i]-=MOD);
            sum[pos(i)]+=z;upd(sum[pos(i)]-=MOD);
        }
    }else{
        for(int i=x;i>=y;i--){pre[x][i]+=z;upd(pre[x][i]-=MOD);}
        for(int i=1;i<=y;i++){suf[x][i]+=z;upd(suf[x][i]-=MOD);}
    }
}

signed main(){
    read(n);read(m);
    for(int i=1;i<=n;i++)read(a[i]);
    init();
    while(m--){
        int op,x,y,z,l,r;
        read(op);
        if(op==1){
            read(x);read(y);read(z);
            add(x,y,z);
        }else{
            read(l);read(r);
            int ret=query(l,r);
            for(int i=1;i<BL;i++){
                int blkL=(l-1)/i+1,blkR=(r-1)/i+1;
                if(blkL==blkR){
                    ret+=pre[i][(r-1)%i+1];upd(ret-=MOD);
                    ret-=pre[i][(l-1)%i];upd(ret);
                }else{
                    ret+=(blkR-blkL-1LL)*pre[i][i]%MOD;upd(ret-=MOD);
                    ret+=pre[i][(r-1)%i+1];upd(ret-=MOD);
                    ret+=suf[i][(l-1)%i+1];upd(ret-=MOD);
                }
            }
            put(ret);
        }
    }
    return 0;
}

CF1056H Detect Robots

s=\sum\limits len

考虑我们答案到底是怎么算的,其实就是枚举一个串两个字符 A\to C \to \dots \to B,如果之前出现过 A \to D \to \dots \to B 的路径,检查 C 是否等于 D 即可,若不等于输出 Human。一个显然的想法就是枚举 A,B 轻松 O(s^3),但是我们可以只需要枚举一个,比如枚举 B 让后拓展即可,这样的时间复杂度 O(s^2)。如果 s 过大直接就爆炸了,考虑一个关键性质:数据范围 3 \times 10^5,考虑到长为 k 的路径最多有 \dfrac{n}{k} 个,考虑根号分治,按串长分治。小串对小串暴力即可。

大串我们可以暴力枚举,记 pos_{i} 表示 i 在大串中的位置,对于其他串从后往前扫,判断当前字母的 pos<mxpos_{i} 是否成立即可,时间复杂度 O(s\sqrt{s})

#include<bits/stdc++.h>
#define pir pair<int,int>
using namespace std;
constexpr int MN=1e6+15;
int T,n,m,B,pos[MN],vis[MN];
vector<int> a[MN];
vector<pir> v[MN];

void solve(){
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        v[i].clear();
        vis[i]=pos[i]=0;
    }
    for(int i=1;i<=m;i++){
        a[i].clear();
        int K;
        cin>>K;
        for(int j=1;j<=K;j++){
            int x;
            cin>>x;
            a[i].push_back(x);
        }
        if(a[i].size()<=B){
            for(int j=0;j<a[i].size();j++){
                for(int k=j+1;k<a[i].size();k++){
                    v[a[i][k]].push_back(pir(a[i][j],a[i][j+1]));
                }
            }
        }
    }
    for(int i=1;i<=m;i++){
        if(a[i].size()<=B) continue;
        for(int j=1;j<=n;j++) pos[j]=-1;
        for(int j=0;j<a[i].size();j++) pos[a[i][j]]=j;
        for(int j=i+1;j<=m;j++){
            int r=-1;
            for(int k=a[j].size()-1;k>=0;k--){
                if(pos[a[j][k]]==-1) continue;
                if(pos[a[j][k]]>r){
                    r=pos[a[j][k]];
                }
                else if(a[j][k+1]!=a[i][pos[a[j][k]]+1]){
                    cout<<"Human\n";
                    return;
                }
            }
        }
    }
    for(int i=1;i<=n;i++){
        for(auto p:v[i]){
            if(vis[p.first]&&vis[p.first]!=p.second){
                cout<<"Human\n";
                return;
            }
            vis[p.first]=p.second;
        }
        for(auto p:v[i]) vis[p.first]=0;
    }
    cout<<"Robot\n";
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    B=sqrt(300000)/2;
    cin>>T;
    while(T--){
        solve();
    }
    return 0;
}

CF587F Duff is Mad

暴力的想法就是重新建 AC 自动机,不得不承认这个想法及其糟糕。

考虑直接对所有串建立 AC 自动机,那么答案就是将 s\in [l,r] 上串对应 Fail 树上的权值加一后求权值和,考虑离线下来加根号分治,阈值为 B,对于 |S_k|>B 的考虑将每个串询问做差,顺序扫过即可,时间复杂度 O(n\sqrt{n})

对于 |S_{k}| \le B 考虑扫描线,扫到一个串就权值加一,让后暴力查询即可,其实这两个操作都是在 DFS 序上区间加单点查,树状数组即可,时间房租啊都 O(n \log m+QT\log m)

#include<bits/stdc++.h>
#define int long long
#define pir pair<int,int>
using namespace std;
constexpr int N=1e5+7;
int n,q,sumlen,MB,ans[N];
string s[N];
vector<int> adj[N];
vector<pir> L1[N],R1[N],L2[N],R2[N];

struct BIT{
    int t[N];

    int lowbit(int x){
        return x&-x;
    }

    void modify(int x,int k){
        while(x<N){
            t[x]+=k;
            x+=lowbit(x);
        }
    }

    int query(int x){
        int ret=0;
        while(x){
            ret+=t[x];
            x-=lowbit(x);
        }
        return ret;
    }
}bit;

namespace ACAuto{
    int trie[N][26],fail[N],fa[N],ed[N],tot=1;
    int sum[N],siz[N],dfn[N],dtot;

    void insert(string s,int id){
        int p=1;
        for(auto c:s){
            int k=c-'a';
            if(!trie[p][k]) trie[p][k]=++tot,fa[tot]=p;
            p=trie[p][k];
        }
        ed[id]=p;
    }

    void build(){
        queue<int> q;
        for(int i=0;i<26;i++){
            if(trie[1][i]) fail[trie[1][i]]=1,q.push(trie[1][i]);
            else trie[1][i]=1;
        }
        while(!q.empty()){
            int x=q.front();
            q.pop();
            for(int i=0;i<26;i++){
                if(trie[x][i])
                    fail[trie[x][i]]=trie[fail[x]][i],q.push(trie[x][i]);
                else trie[x][i]=trie[fail[x]][i];
            }
        }
        for(int i=2;i<=tot;i++) adj[fail[i]].push_back(i);
    }

    void dfs1(int u){
        for(auto v:adj[u]){
            dfs1(v);
            sum[u]+=sum[v];
        }
    }

    void dfs2(int u){
        siz[u]=1;
        dfn[u]=++dtot;
        for(auto v:adj[u]){
            dfs2(v);
            siz[u]+=siz[v];
        }
    }
}using namespace ACAuto;

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin>>n>>q;
    for(int i=1;i<=n;i++){
        cin>>s[i];
        sumlen+=s[i].length();
        insert(s[i],i);
    }
    build();
    MB=sumlen/sqrt(q*log2(sumlen));
    
    for(int i=1;i<=q;i++){
        int l,r,k;
        cin>>l>>r>>k;
        if(s[k].length()>MB){
            L1[k].emplace_back(l,i);
            R1[k].emplace_back(r,i);
        }else{
            L2[l].emplace_back(k,i);
            R2[r].emplace_back(k,i);
        }
    }

    for(int i=1;i<=n;i++){
        if(s[i].length()>MB){
            int p=ed[i];
            while(p!=1) sum[p]=1,p=fa[p];
            dfs1(1);
            sort(L1[i].begin(),L1[i].end());
            sort(R1[i].begin(),R1[i].end());
            reverse(L1[i].begin(),L1[i].end());
            reverse(R1[i].begin(),R1[i].end());
            int tmp=0;
            for(int j=1;j<=n;j++){
                while(L1[i].size()&&L1[i].back().first==j){
                    ans[L1[i].back().second]-=tmp;
                    L1[i].pop_back();
                }
                tmp+=sum[ed[j]];
                while(R1[i].size()&&R1[i].back().first==j){
                    ans[R1[i].back().second]+=tmp;
                    R1[i].pop_back();
                }
            }
            for(int i=2;i<=tot;i++) sum[i]=0;
        }
    }

    dfs2(1);
    for(int i=1;i<=n;i++){
        for(auto [k,id]:L2[i]){
            int p=ed[k];
            while(p!=1) ans[id]-=bit.query(dfn[p]),p=fa[p];
        }
        bit.modify(dfn[ed[i]],1);
        bit.modify(dfn[ed[i]]+siz[ed[i]],-1);
        for(auto [k,id]:R2[i]){
            int p=ed[k];
            while(p!=1) ans[id]+=bit.query(dfn[p]),p=fa[p];
        }
    }

    for(int i=1;i<=q;i++) cout<<ans[i]<<'\n';
    return 0;
}

3. 后言

根号分治,做那么多题其实就是根号平衡时空复杂度。注意分治后的情况下具有的性质,同时对次数分类讨论就可以了。

4. 参考


评论