📅  最后修改于: 2023-12-03 15:42:22.867000             🧑  作者: Mango
给定一棵树,求出从任意起点到任意终点的简单路径(此路径上没有重复的节点)上,深度最浅的节点的深度值。
这是一道比较经典的树上最深公共祖先(LCA)的问题,可以使用倍增算法或者树剖算法进行求解。我们假设有两个点 $u$ 和 $v$,它们的 $\text{LCA}$ 为 $x$,那么两个点到它们的 $\text{LCA}$ 的路径上最深的结点就是 $\max{h(u,x), h(v,x)}$。其中 $h(u,x)$ 表示点 $u$ 到点 $x$ 的深度。
倍增算法的时间复杂度是 $O(n \log n)$,这里假设树的根节点为 $r$。我们首先求出每个节点的深度和它们与根节点的距离,然后预处理出每个节点的 $2^d$ 级祖先,其中 $d$ 是一个常数。我们可以预处理出 $2^d$ 级祖先,然后进行查询操作。在查询时,假设要查询 $u$ 和 $v$ 两个点的深度值,那么可以将 $u$ 和 $v$ 沿着它们的 $2^d$ 级祖先分别跳到它们的 $\text{LCA}$,然后再计算它们的深度值。
具体实现过程可以参考这篇博客。
树剖算法的时间复杂度是 $O(n \log n)$,与倍增算法的时间复杂度相同。树剖算法的思路是给每个节点标记一个时间戳,然后将一条路径划分成若干条链,每条链都用线段树来维护,查询的时候把路径分成若干条链的根路径和剩下的部分就可以了。
具体实现过程可以参考这篇博客。
以下是使用倍增算法实现的代码片段:
int f[N][M], d[N], dis[N];
vector<int> G[N];
void dfs(int u, int fa)
{
d[u] = d[fa] + 1;
f[u][0] = fa;
for (int i = 1; i <= M; i++)
f[u][i] = f[f[u][i - 1]][i - 1];
for (auto v : G[u])
{
if (v != fa)
{
dis[v] = dis[u] + 1;
dfs(v, u);
}
}
}
int lca(int u, int v)
{
if (d[u] < d[v])
swap(u, v);
for (int i = M; i >= 0; i--)
if (d[f[u][i]] >= d[v])
u = f[u][i];
if (u == v)
return u;
for (int i = M; i >= 0; i--)
if (f[u][i] != f[v][i])
{
u = f[u][i];
v = f[v][i];
}
return f[u][0];
}
int main()
{
int n, m;
cin >> n >> m;
for (int i = 1; i <= n - 1; i++)
{
int u, v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
while (m--)
{
int u, v;
cin >> u >> v;
int p = lca(u, v);
cout << max(dis[u] - dis[p], dis[v] - dis[p]) << endl;
}
return 0;
}
以下是使用树剖算法实现的代码片段:
int fa[N], dep[N], siz[N], son[N], top[N], dfn[N], rnk[N], tot;
vector<int> G[N];
void dfs1(int u, int f)
{
dep[u] = dep[f] + 1;
fa[u] = f;
siz[u] = 1;
for (int v : G[u])
{
if (v == f)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
}
}
void dfs2(int u, int t)
{
top[u] = t;
dfn[u] = ++tot;
rnk[dfn[u]] = u;
if (!son[u])
return;
dfs2(son[u], t);
for (int v : G[u])
{
if (v == fa[u] || v == son[u])
continue;
dfs2(v, v);
}
}
int getlca(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v);
u = fa[top[u]];
}
if (dep[u] > dep[v])
swap(u, v);
return u;
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n - 1; i++)
{
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 1);
while (m--)
{
int u, v;
scanf("%d%d", &u, &v);
int p = getlca(u, v);
printf("%d\n", max(dep[u] - dep[p], dep[v] - dep[p]));
}
return 0;
}