当前位置 博文首页 > 文章内容

    UOJ#388. 【UNR #3】配对树 树链剖分+线段树

    作者: 栏目:未分类 时间:2020-07-15 9:00:16

    本站于2023年9月4日。收到“大连君*****咨询有限公司”通知
    说我们IIS7站长博客,有一篇博文用了他们的图片。
    要求我们给他们一张图片6000元。要不然法院告我们

    为避免不必要的麻烦,IIS7站长博客,全站内容图片下架、并积极应诉
    博文内容全部不再显示,请需要相关资讯的站长朋友到必应搜索。谢谢!

    另祝:版权碰瓷诈骗团伙,早日弃暗投明。

    相关新闻:借版权之名、行诈骗之实,周某因犯诈骗罪被判处有期徒刑十一年六个月

    叹!百花齐放的时代,渐行渐远!



    这道题卡常啊 !           

    出题人说 $O(n \log^2 n)$ 可过,但我写了个 $O(n \log^2 n)$ 的树剖卡了半天常数.     

    最暴力的做法:枚举区间,然后跑一个树形DP 来求最小匹配.     

    显然,因为要求匹配值最小,所以一定是能匹配就先匹配.   

    也就是说递归完 $x$ 的所有儿子后,$x$ 的每一个儿子最多只有 1 个点还没有匹配.      

    这个时间复杂度是 $O(n^3)$ 的.    

    然后我们对每一条边分别考虑:   

    令 $v[x]$ 表示点 $x$ 到其父亲的边权(以 1 为根),那么 $v[x]$ 能产生贡献,当且仅当一个区间中 $x$ 子树中有奇数个点.   

    这个很好理解,因为如果有奇数个点,就意味着 1 个点没有被匹配到,而需要向上延伸的 $x$ 的父亲,依此类推......       

    那么就枚举右端点,然后令 $f[x][0/1]$ 分别表示多少个长度为偶数的区间满足在 $x$ 的子树中有偶数/奇数个点.      

    由于要求区间长度是偶数,我们可以分别以 $1,2$ 为起点各跑一次,每次同时加入两个点来保证长度为偶数.      

    考虑加入 $x,y$ 后的影响:

    $x$ 到 $lca$ 与 $y$ 到 $lca$ (不包括 lca 这个点)的路径上 $f[x][0]=f[x][1]$,$f[x][1]=f[x][0]+1$             

    不在 $x,y$ 路径上的点 $f[x][1]$ 不变,$f[x][0] \leftarrow f[x][0]+1$.     

    这个暴力修改的话是 $O(n^2)$ 的,可以获得 $50$pts.  

    满分算法的话就是用树链剖分+线段树来维护上面的东西.   

    我们无外乎就是要支持:每个节点维护 $f[x][0],f[x][1]$,区间加,区间交换.  

    然后定义标记 $(rev,x,y)$ 表示是否要交换 $f[x][0],f[x][1]$ 的值,交换后对 $f[x][0]$,$f[x][1]$ 分别加上 $x,y$.     

    时间复杂度为 $O(n \log^2 n)$,但是会有点卡常.  

    这里说几个卡常技巧: 

    1. 读入优化  

    2. 开 long long 要比取模快.  

    3. 由于上述操作中每次加的数是 1 或 -1,所以这个标记可以直接开 int,然后区间和开 long long.    

    code:    

    #include <cstdio>
    #include <ctime>
    #include <cstring> 
    #include <algorithm>      
    #define N 100008  
    #define ll long long 
    #define mod 998244353
    #define lson now<<1  
    #define rson now<<1|1     
    #define setIO(s) freopen(s".in","r",stdin)  
    using namespace std;    
    int edges,n,m,tim;    
    int nd[N],f[N][2],fa[N];      
    int hd[N],to[N<<1],nex[N<<1],val[N<<1]; 
    int dep[N],a[N],size[N],top[N],son[N],dfn[N],bu[N];    
    ll ans;   
    struct data {
        int rev;  
        int vx,vy;  
        ll sx,sy,sum;    
        data(int rev=0,int vx=0,int vy=0):rev(rev),vx(vx),vy(vy){}  
    }s[N<<2];   
    inline void add(int u,int v,int c) {         
        nex[++edges]=hd[u];   
        hd[u]=edges,to[edges]=v,val[edges]=c;  
    }            
    void dfs(int x,int ff) {  
        size[x]=1;  
        fa[x]=ff,dep[x]=dep[ff]+1;        
        for(int i=hd[x];i;i=nex[i]) {              
            int y=to[i];  
            if(y==ff) continue;      
            nd[y]=val[i],dfs(y,x);    
            size[x]+=size[y];  
            if(size[y]>size[son[x]]) son[x]=y;    
        }
    }
    void dfs2(int x,int tp) {
        top[x]=tp;  
        dfn[x]=++tim;  
        bu[tim]=x;    
        if(son[x]) dfs2(son[x],tp); 
        for(int i=hd[x];i;i=nex[i]) 
            if(to[i]!=fa[x]&&to[i]!=son[x]) 
                dfs2(to[i],to[i]);   
    }
    inline int get_lca(int x,int y) { 
        while(top[x]!=top[y]) {
            dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]];  
        }
        return dep[x]<dep[y]?x:y;   
    }      
    inline void pushup(int now) {
        s[now].sx=(ll)(s[lson].sx+s[rson].sx);  
        s[now].sy=(ll)(s[lson].sy+s[rson].sy);    
    }
    inline void mark_rev(int now) {
        swap(s[now].sx,s[now].sy);        
        swap(s[now].vx,s[now].vy);   
        s[now].rev^=1;     
    }
    inline void mark_add(int now,int vx,int vy) {
        if(vx) (s[now].sx+=(ll)vx*s[now].sum);   
        if(vy) (s[now].sy+=(ll)vy*s[now].sum);   
        if(vx) (s[now].vx+=vx);  
        if(vy) (s[now].vy+=vy);    
    }
    inline void pushdown(int now) {
        if(s[now].rev) {
            s[now].rev=0; 
            mark_rev(lson); 
            mark_rev(rson); 
        }   
        if(s[now].vx||s[now].vy) {
            mark_add(lson,s[now].vx,s[now].vy);  
            mark_add(rson,s[now].vx,s[now].vy);  
            s[now].vx=s[now].vy=0;   
        }
    }
    void build(int l,int r,int now) {
        s[now]=data(); 
        s[now].sx=0; 
        s[now].sy=0;      
        if(l==r) {
            s[now].sum=nd[bu[l]];   
            return; 
        }
        int mid=(l+r)>>1;  
        build(l,mid,lson),build(mid+1,r,rson);   
        s[now].sum=(ll)(s[lson].sum+s[rson].sum)%mod;   
    }
    void REV(int l,int r,int now,int L,int R) {
        if(l>=L&&r<=R) {
            mark_rev(now);   
            return;  
        }
        pushdown(now); 
        int mid=(l+r)>>1;   
        if(L<=mid) REV(l,mid,lson,L,R);  
        if(R>mid)  REV(mid+1,r,rson,L,R);   
        pushup(now);   
    }
    void ADD(int l,int r,int now,int L,int R,int vx,int vy) {
        if(l>=L&&r<=R) {
            mark_add(now,vx,vy);  
            return; 
        }
        pushdown(now); 
        int mid=(l+r)>>1;   
        if(L<=mid)  ADD(l,mid,lson,L,R,vx,vy);  
        if(R>mid)   ADD(mid+1,r,rson,L,R,vx,vy);  
        pushup(now);  
    }     
    inline void upd(int x,int y) {         
        while(top[y]!=top[x]) {  
            ADD(1,n,1,dfn[top[y]],dfn[y],-1,0);  
            REV(1,n,1,dfn[top[y]],dfn[y]);     
            ADD(1,n,1,dfn[top[y]],dfn[y],0,1);         
            y=fa[top[y]];   
        }     
        if(y!=x) {
            ADD(1,n,1,dfn[x]+1,dfn[y],-1,0);  
            REV(1,n,1,dfn[x]+1,dfn[y]);   
            ADD(1,n,1,dfn[x]+1,dfn[y],0,1);  
        } 
    }
    void sol(int st) {
        int x,y,lca;         
        build(1,n,1);   
        for(int i=st;i<=m;i+=2) {       
            if(i+1>m) break;  
            x=a[i],y=a[i+1];        
            if(dep[x]>dep[y]) swap(x,y);        
            lca=get_lca(x,y);      
            mark_add(1,1,0);              
            upd(lca,x); 
            upd(lca,y);   
            (ans+=s[1].sy)%=mod;
        }      
    }         
    char *p1,*p2,buf[100000];   
    #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)  
    int rd()
    {
        int x=0; char c;   
        while(c<48) c=nc();  
        while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc();  
        return x;    
    }
    int main() {  
        // setIO("input");  
        n=rd(),m=rd();  
        int x,y,z;  
        for(int i=1;i<n;++i) { 
            x=rd(),y=rd(),z=rd();  
            if(z>=mod) z-=mod;  
            add(x,y,z),add(y,x,z);  
        }    
        dfs(1,0);    
        dfs2(1,1);                   
        for(int i=1;i<=m;++i) a[i]=rd();          
        sol(1),sol(2);  
        printf("%lld\n",ans);   
        return 0;   
    }