2021年度训练联盟热身训练赛第一场 I. Full Depth Morning Show

链接

https://ac.nowcoder.com/acm/contest/12606/I

题意

记 $d_{u,v}$ 为节点 $u$ 到节点 $v$ 的最短距离,$t_u$ 为节点 $u$ 的权值,对树上任意节点 $u$ 求 $\sum_{v=1}^{n}{[d_{u,v}*(t_u+t_v)]}$。

思路

换根 DP。

记 $res_u=\sum_{v=1}^{n}{[d_{u,v}*(t_u+t_v)]}$:

$res_{u}=\sum_{v=1}^{n}{[d_{u,v}(t_u+t_v)}]=t_u\sum_{v=1}^{n}{d_{u,v}}+\sum_{v=1}^{n}{(d_{u,v}*t_v)}$。

记 $a_u=\sum_{v=1}^{n}{d_{u,v}}$,$b_u=\sum_{v=1}^{n}{(d_{u,v}*t_v)}$,$sz_u$ 为节点 $u$ 的子树节点数,$sum_u$ 为节点 $u$ 的子树节点权值和,$u’$ 为节点 $u$ 的儿子节点:

$a_{u’}=a_u-d_{u,u’}sz_u+d_{u,u’}(n-sz_{u’})=a_u+d_{u,u’}(n-2sz_{u’})$,

$b_{u’}=b_u-d_{u,u’}sum_u+d_{u,u’}(\sum_{i=1}^{n}{t_i}-sum_{u’})=b_u+d_{u,u’}(\sum_{i=1}^{n}{t_i}-2sum_{u’})$,

$res_{u’}=t_{u’}*a_{u’}+b_{u’}$。

两次 DFS,以 节点 $1$ 为根,先计算出 $a_1$,$b_1$,再推出其余节点的 $res$,$O(n)$ 求解。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <bits/stdc++.h>
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(),(x).end()
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
using namespace std;
typedef double DB;
typedef long double LD;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef vector<PII> VPII;
// head
const int N=1e5+5;
int n;
LL t[N],d[N],sz[N],sum[N],a[N],b[N],res[N];
VPII g[N];
void dfs1(int u,int fa) {
sz[u]=1;
sum[u]=t[u];
for(auto x:g[u]) {
int v=x.FI,w=x.SE;
if(v==fa) continue;
d[v]=d[u]+w;
a[1]+=d[v];
b[1]+=d[v]*t[v];
dfs1(v,u);
sz[u]+=sz[v];
sum[u]+=sum[v];
}
}
void dfs2(int u,int fa) {
for(auto x:g[u]) {
int v=x.FI,w=x.SE;
if(v==fa) continue;
a[v]=a[u]+w*(n-2*sz[v]);
b[v]=b[u]+w*(sum[1]-2*sum[v]);
res[v]=a[v]*t[v]+b[v];
dfs2(v,u);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n;
for(int i=1;i<=n;i++) cin>>t[i];
for(int i=1;i<n;i++) {
int u,v,w;
cin>>u>>v>>w;
g[u].EB(v,w);
g[v].EB(u,w);
}
dfs1(1,0);
dfs2(1,0);
res[1]=a[1]*t[1]+b[1];
for(int i=1;i<=n;i++) cout<<res[i]<<'\n';
return 0;
}