📜  算法测验| SP竞赛1 |问题6(1)

📅  最后修改于: 2023-12-03 15:41:09.856000             🧑  作者: Mango

SP竞赛1 算法测验 - 问题6

本文介绍 SP竞赛1 中的算法测验问题6,属于难度较高的问题。

题目描述

给定一个由n个节点(n<=10^5)组成的树,每个节点有一个正整数的权值和颜色。预处理一个操作:查询树上从节点u到节点v的路径中颜色为c的节点的权值的和,其中u和v为任意给定的节点,c是任意给定的颜色。预处理时间和单次查询时间均不得超过O(nlogn)。

解题思路

本题需要预处理一些信息,以便快速查询两个节点u和v之间,颜色为c的节点的权值之和。

考虑使用树链剖分(Heavy-Light Decomposition),这是一种常用的树分治算法。

具体来说,我们将一棵树分为若干重链(heavy)和非重链(light)。对于每个节点,我们记录它所在的重链、以及它在重链上的深度(dis)和它的轻儿子(lightson)。另外,从它所在重链的链头到该节点的路径上,我们维护一个线段树,以支持路径和的修改和查询。这些信息的预处理和单词查询的时间均为O(nlogn)。

现在考虑查询操作。假设要查询子树内颜色为c的节点的权值之和,我们可以用树状数组来完成。我们对每个节点维护一个树状数组,表示它的子树内各个颜色的节点出现次数的和。在查询时,我们搜索u和v之间的路径,并将路径上每个节点的颜色出现次数相加,最后查询树状数组即可得到答案。

具体网络流程如下:

//对节点u所在重链上的线段树,查询从dis[u]到dis[v]这一段,颜色为c的节点权值之和
int query_chain(int u,int v,int c){
    int ans=0;
    while(top[u]!=top[v]){
        if(dis[top[u]]<dis[top[v]]) swap(u,v);
        ans+=query_range(1,1,n,dis[top[u]],dis[u],c);
        u=fa[top[u]];//从u往上跳一级
    }
    if(dis[u]>dis[v]) swap(u,v);
    ans+=query_range(1,1,n,dis[u],dis[v],c);//查询剩余部分
    return ans;
}

//对节点u和v之间的路径,查询颜色为c的节点的权值之和
int query_path(int u,int v,int c){
    int ans=0;
    while(top[u]!=top[v]){
        if(dis[top[u]]<dis[top[v]]) swap(u,v);
        ans+=query_chain(u,top[u],c);
        u=fa[top[u]];//从u往上跳一级
    }
    if(dis[u]>dis[v]) swap(u,v);
    ans+=query_chain(u,v,c);//查询剩余部分
    return ans;
}
代码实现

以下是使用C++实现的代码。其中,$sum[u][c]$表示颜色c在以u为根节点的子树内的出现次数。$cnt[u]$表示以u为根节点的子树内的节点个数。

#include<bits/stdc++.h>
#define N 100010
using namespace std;
int n,m,rt,cnt;//cnt表示节点的数量
int sum[N][2],f[N][20],top[N],son[N],siz[N],cnt1,cnt2;
int w[N],c[N],head[N],vis[N],light[N],a[N];
struct node{
    int v,w,nxt;
}edge[N<<2];
void add(int u,int v,int w){
    edge[++cnt1].v=v;
    edge[cnt1].w=w;
    edge[cnt1].nxt=head[u];
    head[u]=cnt1;
}
void dfs1(int u,int fa){
    f[u][0]=fa;//层数为0,即自己的父亲的深度为0
    siz[u]=1;
    int mxson=-1;
    for(int i=head[u];i;i=edge[i].nxt){
        int v=edge[i].v;
        if(v==fa) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>mxson){
            mxson=siz[v];
            son[u]=v;
        }
    }
}
void dfs2(int u,int tp){
    top[u]=tp;
    vis[u]=(++cnt2);
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(int i=head[u];i;i=edge[i].nxt){
        int v=edge[i].v;
        if(v==f[u][0]||v==son[u]) continue;
        dfs2(v,v);
    }
}
void update(int u){
    for(int i=0;i<=1;i++){
        sum[u][i]=sum[son[u]][i];
    }
    sum[u][c[u]]=sum[son[u]][c[u]]+w[u];
    for(int i=head[u];i;i=edge[i].nxt){
        int v=edge[i].v;
        if(v==f[u][0]||v==son[u]) continue;
        update(v);
        for(int j=0;j<=1;j++){
            sum[u][j]+=sum[v][j];
        }
    }
}
void modify_chain(int l,int r,int c,int x){
    if(l==r){
        if(c==1){
            a[l]++;
            if(a[l]==2){
                sum[x][1]+=w[l];
            }
        }else{
            a[l]--;
            if(a[l]==1){
                sum[x][0]+=w[l];
            }
        }
    }
    else{
        int mid=(l+r)>>1;
        if(dis[u]<=mid) modify_range(l,mid,ls,c,x);
        else modify_range(mid+1,r,rs,c,x);
        for(int i=0;i<=1;i++){
            sum[x][i]=sum[ls][i]+sum[rs][i];
        }
    }
}
void modify_path(int u,int v,int c,int k){
    while(top[u]!=top[v]){
        if(dis[top[u]]<dis[top[v]]) swap(u,v);
        modify_chain(vis[top[u]],vis[u],c,k);//修改top[u]到u这一段上的信息
        u=f[top[u]][0];//往上跳一级
    }
    if(dis[u]>dis[v]) swap(u,v);//u节点一定比v节点浅
    modify_chain(vis[u],vis[v],c,k);//修改u到v这一段上的信息
}
int query_chain(int l,int r,int c,int x){//查询u到v这一段,颜色为c的节点的权值之和
    if(l==r) return a[l]?w[l]:0;//如果该节点存在,返回对应的权值,否则返回0
    else{
        int mid=(l+r)>>1;
        if(dis[u]<=mid) return modify_range(l,mid,ls,c,x);
        else return modify_range(mid+1,r,rs,c,x);
    }
}
int query_path(int u,int v,int c){
    int ans=0;
    while(top[u]!=top[v]){
        if(dis[top[u]]<dis[top[v]]) swap(u,v);
        ans+=query_chain(vis[top[u]],vis[u],c,top[u]);//查询top[u]到u这一段上的信息
        u=f[top[u]][0];//往上跳一级
    }
    if(dis[u]>dis[v]) swap(u,v);//u节点一定比v节点浅
    ans+=query_chain(vis[u],vis[v],c,top[u]);//查询u到v这一段上的信息
    return ans;
}
int main(){
    cin>>n;
    for(int i=1;i<=n;i++) cin>>w[i];
    for(int i=2;i<=n;i++){
        int u,v;
        cin>>u>>v;
        add(u,v,0);
        add(v,u,0);
        f[v][0]=u;
    }
    dfs1(1,0);//求出每个节点的父亲、重儿子
    dfs2(1,1);//求出每个节点所在的重链、轻儿子的深度、每个节点的编号
    for(int i=1;i<=n;i++){
        c[i]=rand()%2;//生成随机数,作为节点的颜色
    }
    update(1);//预处理每个节点的子树内,两种颜色分别出现的次数
    cin>>m;//下面进行m次查询
    while(m--){
        int u,v,c;
        cin>>u>>v>>c;
        int ans=query_path(u,v,c);//查询u到v之间颜色为c的节点的权值之和
        cout<<ans<<endl;
    }
    return 0;
}

以上是本题的C++代码实现,其中不少变量名称和函数名称与普通的代码风格略有不同。具体实现中,还需要树状数组等数据结构的支持。