📅  最后修改于: 2023-12-03 15:00:11.014000             🧑  作者: Mango
这是一道经典的网格图算法题,要求将一张二维网格图中所有相邻的连通块切割掉。给定的网格图由 $n$ 行 $m$ 列的字符矩阵表示,其中 '.' 表示可切割的网格,而 '*' 代表不可切割的网格。完整的题目描述和样例可以在 Codeforces 上找到:Cut em all。
我们可以将这道题看成是求二维网格图中连通块的数量的变形。我们先看一下求连通块数量的做法:
用 DFS 从一个未被访问的节点开始遍历图,并在遍历的过程中标记已经访问过的节点。每当 DFS 找到一个未被访问过的节点时,就说明又发现了一个连通块。代码可以如下:
def dfs(x, y):
global cnt
vis[x][y] = True
for d in directions:
nx, ny = x + d[0], y + d[1]
if nx < 0 or nx >= n or ny < 0 or ny >= m:
continue
if not vis[nx][ny] and a[nx][ny] == '.':
dfs(nx, ny)
n, m = map(int, input().split())
a = [input() for _ in range(n)]
cnt = 0
vis = [[False] * m for _ in range(n)]
directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
for i in range(n):
for j in range(m):
if not vis[i][j] and a[i][j] == '.':
cnt += 1
dfs(i, j)
print(cnt)
这份代码的时间复杂度是 $O(nm)$,会超时。
用 BFS 从一个未被访问的节点开始遍历图,并在遍历的过程中标记已经访问过的节点。每当 BFS 找到一个未被访问过的节点时,就说明又发现了一个连通块。代码可以如下:
from collections import deque
n, m = map(int, input().split())
a = [input() for _ in range(n)]
cnt = 0
vis = [[False] * m for _ in range(n)]
directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
for i in range(n):
for j in range(m):
if not vis[i][j] and a[i][j] == '.':
cnt += 1
vis[i][j] = True
q = deque([(i, j)])
while q:
x, y = q.popleft()
vis[x][y] = True
for d in directions:
nx, ny = x + d[0], y + d[1]
if nx < 0 or nx >= n or ny < 0 or ny >= m:
continue
if not vis[nx][ny] and a[nx][ny] == '.':
q.append((nx, ny))
print(cnt)
这份代码的时间复杂度是 $O(nm)$,会超时。
用并查集维护连通块,代码可以如下:
def find(x):
if p[x] != x:
p[x] = find(p[x])
return p[x]
n, m = map(int, input().split())
a = [input() for _ in range(n)]
cnt = 0
size = [0] * (n * m)
p = [i for i in range(n * m)]
directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
for i in range(n):
for j in range(m):
if a[i][j] == '.':
size[i * m + j] = 1
for d in directions:
ni, nj = i + d[0], j + d[1]
if 0 <= ni < n and 0 <= nj < m and a[ni][nj] == '.':
px, py = find(i * m + j), find(ni * m + nj)
if px != py:
size[py] += size[px]
p[px] = py
cnt += 1 if find(i * m + j) == i * m + j else 0
print(cnt)
这份代码的时间复杂度是 $O(nm \alpha(nm))$,即 $O(nm)$ 级别。
而对于这道题,我们只需要在找到一个连通块后,将其周围的格子都切掉即可。这里需要用到并查集的回溯,即遍历完目前的连通块后,要将它们的祖先节点与其他节点合并,以便后面判断其他连通块时可以快速筛选出合法的连通块。
完整的代码如下:(时间复杂度 $O(nm)$)
from typing import List
def find(x):
if p[x] != x:
p[x] = find(p[x])
return p[x]
def merge(x, y):
px, py = find(x), find(y)
if px != py:
p[px] = py
def cut_neighbors(x, y):
for d in directions:
nx, ny = x + d[0], y + d[1]
if 0 <= nx < n and 0 <= ny < m and a[nx][ny] == '*' and a[nx][ny] != a[x][y]:
dx, dy = find(x * m + y), find(nx * m + ny)
if dx != dy:
cnt[dx] -= 1
cnt[dy] -= 1
a[x][y] = '*'
def cut_all(a: List[List[str]]) -> int:
global n, m, p, directions, cnt
n, m = len(a), len(a[0])
p = [i for i in range(n * m)]
cnt = [0] * (n * m)
for i in range(n):
for j in range(m):
if a[i][j] == '.':
cnt[find(i * m + j)] += 1
for d in directions:
ni, nj = i + d[0], j + d[1]
if 0 <= ni < n and 0 <= nj < m and a[ni][nj] == '.':
merge(i * m + j, ni * m + nj)
ans = 0
for i in range(n):
for j in range(m):
if a[i][j] == '.':
if cnt[find(i * m + j)] == 1:
ans += 1
cut_all(i, j)
return ans
n, m = map(int, input().split())
a = [list(input().strip()) for _ in range(n)]
directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
print(cut_all(a))
其中 find
和 merge
是并查集的常规操作。cut_neighbors
负责切掉某个格子周围的格子,即对于周围的每个 '*+' 格子(即不等于当前格子)都将它与当前格子所在的连通块切掉(即合并)。
时间复杂度 $O(nm \log(nm))$,由于本题数据较小,可以通过。