ITエンジニアのブログ

IT企業でエンジニアやってる人間の日常について

bit DP で TSP を解いてみた。

bit DP どころか DP すら碌に書いたことが無いレベルだが、 ARC016 の C 問題で bit DP を使うことになるようなので練習しました。学んだことを記しておきます。

bit DP とは

例えば TSP (巡回セールスマン問題)で、拠点 ABCD を回ることを考える場合、既に巡回した拠点の集合が次のそれぞれの場合に対して、最小距離を保持しておきます。

  • ---- [何も巡回していない場合の最小コスト]
  • A--- [A のみ巡回した場合の最小コスト]
  • -B-- [B のみ巡回した場合の最小コスト]
  • AB-- [A と B のみ巡回した場合の最小コスト]
  • --C- [以下同様]
  • A-C-
  • -BC-
  • ABC-
  • ---D
  • ...

つまり、 2**4=16 通りの状態に対して最小距離を保持しておく。この状態の列挙が bit っぽいから bit DP というらしい。

大体のアルゴリズムとしては、頂点を 0 として、

void start(){
  for(int i = 1; i < n; i++){
    dp[(1 << i) + 1] = cost[0][i];
    tsp(i, 1 << i);
  }
}

void tsp(今の場所 v, 巡回済みの集合 bits){
  if(巡回終了){ return; }
  for(int i = 0; i < n; i++){
    if(i == v || bits が i を含む){
        continue;
    }
    
    if(dp[bits + (1 << i)] > cost[i][v] + dp[bits]){
      dp[bits + (1 << i)] > cost[i][v] + dp[bits];
      tsp(i, bits + (1 << i));
    }
  }
}

つまり、最小コストから一つ辺を伸ばして、伸ばした先が最小コストを更新すれば再帰の繰り返しで、巡回が終了するまで繰り返せば OK という感じになります。

実装

なかなか、最後に出発点に戻るという制約を付加するとかなりややこしいことになったので、全ての拠点を巡回するだけのプログラムになりました。 Python3

# coding: utf-8

Infinity = 1000000

class TSP:
    def __init__(self, n, costs):
        self.n = n
        self.costs = costs
        self.dp = [(Infinity, None)] * (2 ** n)

    def run(self):
        m = (1 << self.n) - 1

        for i in range(self.n):
            if i == 0:
                continue
            self.dp[(1 << i) + 1] = (self.costs[0][i], 0)
            self.rec(i, (1 << i) + 1)

    def rec(self, v, bits):
        if bits == (1 << self.n) - 1:
            return

        for i in range(self.n):
            if i == v:
                continue
            if (bits & (1 << i)) > 0:
                continue
            if self.dp[bits + (1 << i)][0] > self.dp[bits][0] + self.costs[v][i]:
                self.dp[bits + (1 << i)] = (self.dp[bits][0] + self.costs[v][i], v)
                self.rec(i, bits + (1 << i))

    def route(self):
        bits = (1 << self.n) - 1
        (_, prev) = self.dp[(1 << self.n) - 1]
        rt = []

        while prev != 0:
            rt.append(prev)
            bits -= 1 << prev
            (_, prev) = self.dp[bits]

        rt.reverse()

        return [0] + rt + [x for x in list(range(self.n)) if x not in rt + [0]]

def main():
    start = input().strip()
    vertices = [start]
    edges = []

    while True:
        line = gets()

        if line == None:
            break

        xs = line.split()
        s = xs[0]
        c = int(xs[1])
        g = xs[2]

        if s not in vertices:
            vertices.append(s)
        if g not in vertices:
            vertices.append(g)

        edges.append((s, c, g))

    v_map = {}
    n = len(vertices)

    for i in range(n):
        v_map[vertices[i]] = i

    costs = [[None for j in range(n)] for i in range(n)]

    for (s, c, g) in edges:
        costs[v_map[s]][v_map[g]] = c
        costs[v_map[g]][v_map[s]] = c

    tsp = TSP(len(vertices), costs)
    tsp.run()

    route = tsp.route()

    print("->".join([vertices[r] for r in route]))

def gets():
    try:
        return input()
    except:
        return None

if __name__ == '__main__':
    main()

コストと共に、どこから来たのか、一つ前の履歴も残すようにしました。出力の部分で、最初と最後だけは追加しなければなりませんが。

次の入力に対し、

A
A 4 B
A 3 C
A 2 D
A 7 E
A 4 F
B 5 C
B 4 D
B 8 E
B 7 F
C 4 D
C 5 E
C 6 F
D 7 E
D 7 F
E 6 F

次のように出力します。

A->D->B->C->E->F