2021牛客暑期多校训练营7 F. xay loves trees

链接

https://ac.nowcoder.com/acm/contest/11258/F

题意

有两棵 $n$ 个节点的树,编号均为 $1\sim n$,找出这 $n$ 个节点的最大子集,满足:

  1. 在第一棵树上,集合内任意两点其中一点是另一点的祖先;
  2. 在第二棵树上,集合内不存在两点其中一点是另一点的祖先。

思路一

易发现,这个集合只可能是第一棵树上的一条从上至下的链,而我们要判断的是这些点在第二棵树上的子树有没有交集。

求出在第二棵树上的 DFS 序和子树大小,用线段树维护子树信息。

在第一棵树上 DFS,访问到点 $u$ 时,对其在第二棵树上的子树区间 $+1$。

  • 若 $[1,n]$ 最大值为 $1$,说明没有交集,更新答案;
  • 否则缩短链,让上端减小,直到区间最大值为 $1$。

可以用双端队列维护,而在对点 $u$ 访问结束时,需要回溯,弹出队尾,重新塞入之前弹出的队首。

可以发现每次弹出的队首可能不止一个,考虑优化。

若当前答案合法的最大长度为 $len$,那么若不合法时,我们只要保证队列长度 $\le len$,即当$>len$ 时,弹出一次队首即可。

代码一

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// 树上划窗 & 线段树
// O(nlog(n))
#include <bits/stdc++.h>
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(),(x).end()
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
using namespace std;
typedef double DB;
typedef long double LD;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef vector<PII> VPII;
// head
const int N=3e5+5;
int res,cnt,dfn[N],sz[N],dep[N],ls[N<<2],rs[N<<2],mx[N<<2],laz[N<<2];
VI g1[N],g2[N];
deque<int> q;
void init(int n) {
cnt=res=0;
for(int i=1;i<=n;i++) g1[i].clear(),g2[i].clear();
while(SZ(q)) q.pop_back();
}
void dfs1(int u,int fa) {
dfn[u]=++cnt;
sz[u]=1;
for(auto v:g2[u]) if(v!=fa) {
dfs1(v,u);
sz[u]+=sz[v];
}
}
void pushdown(int p) {
if(laz[p]) {
mx[p<<1]+=laz[p],mx[p<<1|1]+=laz[p];
laz[p<<1]+=laz[p],laz[p<<1|1]+=laz[p];
laz[p]=0;
}
}
void build(int p,int l,int r) {
ls[p]=l,rs[p]=r;
mx[p]=laz[p]=0;
if(l==r) return;
int mid=l+r>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
}
void update(int p,int l,int r,int v) {
if(ls[p]>=l&&rs[p]<=r) {
mx[p]+=v;
laz[p]+=v;
return;
}
pushdown(p);
int mid=ls[p]+rs[p]>>1;
if(l<=mid) update(p<<1,l,r,v);
if(r>mid) update(p<<1|1,l,r,v);
mx[p]=max(mx[p<<1],mx[p<<1|1]);
}
void dfs2(int u,int fa) {
dep[u]=dep[fa]+1;
update(1,dfn[u],dfn[u]+sz[u]-1,1);
q.PB(u);
int t=-1;
if(mx[1]==1) res=max(res,dep[u]-dep[q.front()]+1);
else if(dep[u]-dep[q.front()]+1>res) {
t=q.front();
update(1,dfn[t],dfn[t]+sz[t]-1,-1);
q.pop_front();
}
for(auto v:g1[u]) if(v!=fa) {
dfs2(v,u);
}
if(t!=-1) {
q.push_front(t);
update(1,dfn[t],dfn[t]+sz[t]-1,1);
}
update(1,dfn[u],dfn[u]+sz[u]-1,-1);
q.pop_back();
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;cin>>T;
while(T--) {
int n;cin>>n;
init(n);
for(int i=1;i<n;i++) {
int u,v;cin>>u>>v;
g1[u].PB(v);
g1[v].PB(u);
}
for(int i=1;i<n;i++) {
int u,v;cin>>u>>v;
g2[u].PB(v);
g2[v].PB(u);
}
dfs1(1,0);
build(1,1,n);
dfs2(1,0);
cout<<res<<'\n';
}
return 0;
}

思路二

在第一棵树上 DFS,遍历到 $u$ 时,在其父亲的基础上更新 $u$ 在第二棵树上的子树区间为当前在第一棵树上的深度,用主席树维护。

并且访问父亲的主席树此区间的最大深度,记为 $mx$,即 $u$ 到根节点之间与 $u$ 在第二棵树上有交集的深度最大的点。

假设以父亲为最大深度的链的长度为 $len$,那么以 $u$ 为最大深度的链长为 $\min(len+1,dep_u-mx)$。

代码二

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// 主席树
// O(nlog(n))
#include <bits/stdc++.h>
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(),(x).end()
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
using namespace std;
typedef double DB;
typedef long double LD;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef vector<PII> VPII;
// head
const int N=3e5+5;
int n,res,cnt,dfn[N],sz[N],dep[N];
int t,rt[N],ls[N<<6],rs[N<<6],mx[N<<6],laz[N<<6];
VI g1[N],g2[N];
void init(int n) {
for(int i=1;i<=n;i++) g1[i].clear(),g2[i].clear();
for(int i=1;i<=t;i++) ls[i]=rs[i]=mx[i]=laz[i]=0;
cnt=t=0;
}
void dfs1(int u,int fa) {
dfn[u]=++cnt;
sz[u]=1;
for(auto v:g2[u]) if(v!=fa) {
dfs1(v,u);
sz[u]+=sz[v];
}
}
void pushdown(int p) {
if(laz[p]) {
if(!ls[p]) ls[p]=++t;
if(!rs[p]) rs[p]=++t;
mx[ls[p]]=mx[rs[p]]=laz[ls[p]]=laz[rs[p]]=laz[p];
laz[p]=0;
}
}
int update(int &p,int q,int L,int R,int l,int r,int v) {
p=++t;
ls[p]=ls[q],rs[p]=rs[q],mx[p]=v;
if(L>=l&&R<=r) {
int len=v-mx[q];
ls[p]=rs[p]=0;
laz[p]=v;
return len;
}
pushdown(q);
int mid=L+R>>1;
if(r<=mid) return update(ls[p],ls[q],L,mid,l,r,v);
if(l>mid) return update(rs[p],rs[q],mid+1,R,l,r,v);
return min(update(ls[p],ls[q],L,mid,l,mid,v),update(rs[p],rs[q],mid+1,R,mid+1,r,v));
}
void dfs2(int u,int fa,int len) {
dep[u]=dep[fa]+1;
len=min(len+1,update(rt[u],rt[fa],1,n,dfn[u],dfn[u]+sz[u]-1,dep[u]));
res=max(res,len);
for(auto v:g1[u]) if(v!=fa) {
dfs2(v,u,len);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;cin>>T;
while(T--) {
cin>>n;
init(n);
for(int i=1;i<n;i++) {
int u,v;cin>>u>>v;
g1[u].PB(v);
g1[v].PB(u);
}
for(int i=1;i<n;i++) {
int u,v;cin>>u>>v;
g2[u].PB(v);
g2[v].PB(u);
}
dfs1(1,0);
res=0;
dfs2(1,0,0);
cout<<res<<'\n';
}
return 0;
}