问题描述

$Hja$有一棵$N$个点的树,树上每个点有点权,每条边有颜色。

一条路径的权值是这条路径上所有点的点权和,一条合法的路径需要满足该路径上任意相邻的两条边颜色都不相同。

问这棵树上所有合法路径的权值和是多少。

输入格式

第一行一个数$N$。

接下来一行$N$个数代表每个点的权值。

接下来$N-1$行每行三个整数$s,e,c$,代表$s$到$e$之间有一条颜色为$c$的边。

输出格式

一行一个整数代表答案。

样例输入

6
6 2 3 7 1 4
1 2 1
1 3 2
1 4 3
2 5 1
2 6 2

样例输出

134

数据范围

对于$100\%$的数据,$1\leq N\leq 3\times 10^5,1\leq c\leq 10^9$。

菜鸡的咆哮

树形$DP$+二次扫描换根,挺好想,就是写起来各种手滑....

状态定义

$f[x]$表示在$x$为根的子树中,以$x$为起点的路径点权和。

$xx[x]$表示在$x$为根的子树中,以$x$为起点且爸爸与$x$连接合法的路径点权和。

$g[x]$表示在除$x$为根的子树外的所有节点中,以$x$为起点的路径点权和。

$path\_f/xx/g$分别表示上述条件下的路径条数。

最终的答案即为$\frac{\Sigma_1^n(f[x]+g[x])}{2}$(起点终点重复计算)

初始化

$f[x]=xx[x]=g[x]=0$

转移方程

$dfs1$中处理$xx[x]$与$f[x]$:

对于一个点$x$,我们定义$opt$为他与某个儿子连边的颜色,$fa\_edge$为与爸爸连边的颜色。

$f[x]+=xx[y]+path\_xx[y]*val[x]+val[x]+val[y]$

$path\_f[x]+=path\_xx[y]+1$

这里注意$x$与$y$构成了一条以$x$为起点的新路径,权值为$val[x]+val[y]$。

如果$opt!=fa\_edge$,$xx[y]$才能对$xx[x]$作出贡献,因为$fa->x->y$必须合法。

$xx[x]+=xx[y]+path\_xx[y]*val[x]+val[x]+val[y]$

$path\_xx[x]+=path\_xx[y]+1$

$dfs2$中处理$g[x]$:

e57a0ff704c6d937428599e6fbad79e5.md.png

假如我们从$x$递归到$y$,想要求出$y$的$g$值。

以$y$为起点出发,合法路径可以由三种途径产生:

①蓝点,到达$y$的子树中,即$f[y]$。

②绿点,紫点,经过$x$到达$x$其他儿子(的子树)中(两点与$x$连边的颜色不同)。

③黄点,经过$x$到达祖先($y$,$fa[x]$与$x$连边的颜色不同)。

①②通过$xx$数组转移,③通过$g$数组转移。

我们定义$tot\_val/path/node$为以上①②情况的权值和,路径数和,通向点数和。

同理$sum\_val/path/node[opt]$表示对应颜色的各项和。(所以颜色要离散化)。

然后......

$g[y]=(tot\_val-sum\_val[opt])+(tot\_path-sum\_path[opt]+tot\_node-sum\_node[opt])*(val[x]+val[y])+(val[x]+val[y])$

注意$Y->X,Y->X->目标点$这两种比较特殊的路径。

$path\_g[y]=tot\_path-sum\_path[opt]+tot\_node-sum\_node[opt]+1$

如果$opt!=fa\_edge$,那还要加上上文的情况③:

$g[y]+=g[x]+path\_g[x]*val[y]$

$path\_g[y]+=path\_g[x]$

$awsl......$

代码

#include<bits/stdc++.h>
#define ll long long
#define maxn 300005
using namespace std;
inline int read(){
    char ch=getchar();int s=0,w=1;
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){s=(s<<1)+(s<<3)+ch-48;ch=getchar();}
    return s*w;
}
int n;
ll val[maxn];
struct data1{
    int to;
    int p;
    int nextt;
}line[2*maxn];
int tail;
int first[maxn];
void add(int x,int y,int w){
    tail++;
    line[tail].to=y;
    line[tail].p=w;
    line[tail].nextt=first[x];
    first[x]=tail;
}
ll xx[maxn],f[maxn],g[maxn];
int path_xx[maxn],path_f[maxn],path_g[maxn];
void dfs1(int x,int fa,int fa_edge){
    for(int i=first[x];i;i=line[i].nextt){
        int y=line[i].to;
        int opt=line[i].p;
        if(y==fa) continue;
        dfs1(y,x,opt);
        f[x]+=xx[y]+path_xx[y]*val[x]+val[x]+val[y];
        path_f[x]+=path_xx[y]+1;
        if(opt==fa_edge) continue;
        xx[x]+=xx[y]+path_xx[y]*val[x]+val[x]+val[y];
        path_xx[x]+=path_xx[y]+1;
    }
}
ll tot_val,tot_path,tot_node;
int sum_node[maxn],sum_path[maxn];
ll sum_val[maxn];
int vis[maxn];
void dfs2(int x,int fa,int fa_edge){
    tot_val=0;tot_path=0;tot_node=0;
    for(int i=first[x];i;i=line[i].nextt){
        int y=line[i].to;
        if(y==fa) continue;
        int opt=line[i].p;
        if(vis[opt]==1){
            sum_node[opt]=sum_val[opt]=sum_path[opt]=0;
            vis[opt]=0;
        }
        tot_node++; sum_node[opt]++;
        sum_path[opt]+=path_xx[y]; tot_path+=path_xx[y];
        sum_val[opt]+=xx[y]+val[y]; tot_val+=xx[y]+val[y];
    }
    for(int i=first[x];i;i=line[i].nextt){
        int y=line[i].to;
        if(y==fa) continue;
        int opt=line[i].p;
        vis[opt]=1;
    }
    for(int i=first[x];i;i=line[i].nextt){
        int y=line[i].to;
        int opt=line[i].p;
        if(y==fa) continue;
        g[y]=(tot_val-sum_val[opt])+(tot_path-sum_path[opt]+tot_node-sum_node[opt])*(val[x]+val[y])+(val[x]+val[y]);
        path_g[y]=tot_path-sum_path[opt]+tot_node-sum_node[opt]+1;
        if(opt!=fa_edge){
            g[y]+=g[x]+path_g[x]*val[y];
            path_g[y]+=path_g[x];
        }
    }
    for(int i=first[x];i;i=line[i].nextt){
        int y=line[i].to;
        int opt=line[i].p;
        if(y==fa) continue;
        dfs2(y,x,opt);
    }
}
struct data2{
    int a,b,c;
}tree[maxn];
bool cmp(data2 x,data2 y){
    return x.c<y.c;
}
int main(){
    freopen("b.in","r",stdin);
    freopen("b.out","w",stdout);
    n=read();
    for(int i=1;i<=n;i++){
        val[i]=read();
    }
    for(int i=1;i<n;i++){
        tree[i].a=read();
        tree[i].b=read();
        tree[i].c=read();
    }
    sort(tree+1,tree+n,cmp);
    int pos=0;
    for(int i=1;i<n;i++){
        if(tree[i].c!=tree[i-1].c) pos++;
        add(tree[i].a,tree[i].b,pos);
        add(tree[i].b,tree[i].a,pos);
    }
    /*ll ans=0;
    for(int i=1;i<=n;i++){
        memset(f,0,sizeof(f));
        memset(path_f,0,sizeof(path_f));
        memset(xx,0,sizeof(xx));
        memset(path_xx,0,sizeof(path_xx));
        dfs1(i,0,-1);    
        ans+=f[i]; 
    }*/
    dfs1(1,0,-1);
    dfs2(1,0,-1);
    ll ans=0;
    for(int i=1;i<=n;i++){
        ans+=g[i]+f[i];
    }
    printf("%lld",ans/2);
    return 0;
}