【グラフ】DFS・BFS・UnionFindで連結性判定を実装する【Python】

アルゴリズム
スポンサーリンク
スポンサーリンク

グラフの連結性を判定する

本記事では、AtCoder ABC C – Path Graph?を題材に、グラフの連結性判定の方法について解説します。

C - Path Graph?
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

ここでは、汎用的なアルゴリズムである深さ優先探索・幅優先探索と、UninFindというデータ構造を使った、3つの判定方法を解説します。また、Pythonでの実装例を紹介します。

スポンサーリンク

DFS(深さ優先探索)

まずは、DFS(深さ優先探索)を用いた方法です。

まず、グラフの各辺の両端に位置する頂点の番号が与えられるので、二次元リストに格納していきます。

次に、各頂点の訪問履歴(訪問済みor未訪問)を格納する、リストを作成します。

最後に、深さ優先探索関数 dfs を用意します。具体的な内容は以下の通りです。

  • dfs(訪問する頂点) のように、引数で頂点を指定して呼び出す。
  • まず、引数で指定された頂点を訪問済みにする。
  • その後、訪問した頂点と隣接する各頂点が未訪問であれば、訪問する。
    ※訪問する = dfs(隣接する頂点)を呼び出す

以上のことから、深さ優先探索では、再帰的な呼び出しを行いながら、未訪問の頂点を訪問済みにする操作を繰り返すことがお分かりいただけると思います。

ここで注意したいのが、再帰回数の上限値を設定することです。Pythonでは、再帰呼び出しの上限値は、デフォルトで1000回(環境によって違うかもしれませんが)になっているようなので、上限値を大きめに設定しておく必要があります。

DFSの計算量はO(N + M)程度なので、今回は10^6回以上で設定しておけば大丈夫かと思われます。

import sys

sys.setrecursionlimit(10 ** 8)

N, M = map(int, input().split())
# i番目の要素には頂点iと隣接する頂点を格納
graph = [[] for _ in range(N)]
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    graph[u].append(v)
    graph[v].append(u)

# 全要素が連結成分であるか
visited = [False] * N
def dfs(position):
    visited[position] = True # 訪問済みにする
    for adjacent_node in graph[position]:
        # 未訪問の場合
        if visited[adjacent_node] == False:
            dfs(adjacent_node)
dfs(0) # どこからスタートしてもいいので0から探索を開始する
print(all_united = all(visited))
スポンサーリンク

BFS(幅優先探索)

では、次はBFS(幅優先探索)を用いた方法です。

グラフを2次元リストとして、訪問履歴をBool値の1次元リストとして用意するところまでは、DFSと全く同じです。

また、幅優先探索では、キュー(Pythonではdeque)を用意する必要があります。キューは、「これから訪問する予定の頂点を、訪問する順に先頭から保持する」役割を持ったデータとなります。

ここで、幅優先探索を行う前に、初期状態として次の準備をしておきます。

  • キューに、あらかじめ頂点0(最も番号の小さい頂点)を追加する。
  • 頂点0は訪問済みとしておく。

そして、幅優先探索関数 bfs の内容は次のようになります。

  • bfs(訪問する頂点) のように、引数で頂点を指定して呼び出す。
  • キューの中身がなくなるまで、次の操作を繰り返す。
  • キューの先頭から頂点を取り出す。
  • 取り出した頂点と「隣接するすべての頂点」に対して、以下の操作を行う。
  • 取り出した頂点と「隣接する頂点」が未訪問であれば、隣接する頂点を訪問済みにする。
  • 隣接する頂点をキューの末尾に追加する。

このように、深さ優先探索では、キューから取り出した頂点と隣接する、全ての頂点を訪問済みにした後、キューに格納することを繰り返していきます。

from collections import deque

N, M = map(int, input().split())
# i番目の要素には頂点iと隣接する頂点を格納
graph = [[] for _ in range(N)]
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    graph[u].append(v)
    graph[v].append(u)

# 全要素が連結成分であるか
visited = [False] * N
que = deque()
que.append(0)
visited[0] = True
def bfs(position):
    while que:
        position = que.pop()
        for adjacent_node in graph[position]:
            if visited[adjacent_node] == False:
                visited[adjacent_node] = True
                que.append(adjacent_node)
bfs(0) # どこからスタートしてもいいので0から探索を開始する
print(all_united = all(visited))

UninFind

最後に、UnionFindというデータ構造を、Pythonのクラスを用いて実装する方法を紹介します。

ここでは、簡単な紹介にとどめておきますが、UnionFindの詳細に関しては、こちらのサイトが分かりやすいです。

始めてみる際にはとっつきにくいですが、使い方は以下の通りです。

  • union_find = UnionFind(N)のように、グラフの総頂点数を指定してインスタンス化する。
  • 各辺の両端の頂点 u, v を指定して、union_find.union(u, v)を呼び出して、データ構造に頂点u, vを追加する。
  • これでデータの準備は整ったので、後は必要なメソッドを呼び出す
    今回は、グラフの連結性判定、つまり、グループ数が1つになっていることを確認できれば良いので、group_count()を呼び出す。
from collections import defaultdict

class UnionFind():
    """
    Union Find木クラス

    Attributes
    --------------------
    n : int
        要素数
    patents : list
        指定した要素の親(1つ上の)要素を格納
        指定した要素が根の場合は,
        -(グループの要素数)  を格納
        => sizeメソッドに反映
    """
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """
        ノードxの根を見つける
        """
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        """
        木に新たな要素を併合(マージ)

        Parameters
        ---------------------
        x, y : int
            併合するノード
        """
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        """
        xの属する木のサイズ
        """
        return -self.parents[self.find(x)]

    def same(self, x, y):
        """
        x, yが同じ木に属するか判定
        """
        return self.find(x) == self.find(y)

    def members(self, x):
        """
        xの属する木に属する要素をリストで返す
        """        
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        """
        全ての根をリストで返す
        """ 
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        """
        グループの数を返す
        """ 
        return len(self.roots())

    def all_group_members(self):
        """
        全てのグループの要素情報を辞書で返す
        """         
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        """
        print(uf)で全てのグループの要素情報を簡単に出力する
        """ 
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())

N, M = map(int, input().split())
union_find = UnionFind(N)
for _ in range(M):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    union_find.union(u, v)

print(union_find.group_count() == 1)

コメント