ITエンジニアのブログ

IT企業でエンジニアやってる人間の日常について

最短経路問題の解法:ダイクストラ法の実装

C++OCamlダイクストラ法の実装を行いました。
ダイクストラ法では、夫々の辺に負でない長さが与えられたグラフ上で、二つの頂点間の最短距離とそれを与える経路を求められます。

まずグラフを用意します。
インターネット上でいい感じのグラフを見つけ出すことができなかったので、自分で作成しました。

f:id:tfull:20160203193432p:plain

例えばこのグラフにおいて、頂点 0 と頂点 14 の最短経路を求めたいとします。

ダイクストラアルゴリズムでは、まず始点の頂点 0 を確定し、その次に最も近い頂点 1 を固定します。次に確定された二つの頂点に繋がっている最も距離の短い点を選択し、というように、一つずつ最短経路を確定させていきます。

まず C++ で実装しました。

#include <iostream>
#include <vector>

int main(){
    const int infty = 1000000000;
    int n;
    int start, goal;
    int **edge;
    int *previous;
    int *distance;
    bool *vertex;

    std::cin >> n;
    std::cin >> start >> goal;

    edge = new int*[n];
    for(int i = 0; i < n; i++){
        edge[i] = new int[n];
        for(int j = 0; j < n; j++){
            edge[i][j] = -1;
        }
    }

    while(! std::cin.eof()){
        int i, j, v;
        std::cin >> i >> j >> v;
        edge[i][j] = v;
        edge[j][i] = v;
    }

    previous = new int[n];
    for(int i = 0; i < n; i++){
        previous[i] = -1;
    }

    distance = new int[n];
    for(int i = 0; i < n; i++){
        distance[i] = infty;
    }

    vertex = new bool[n];
    for(int i = 0; i < n; i++){
        vertex[i] = true;
    }

    distance[start] = 0;

    while(true){
        int index = -1;
        int value = infty;
        for(int i = 0; i < n; i++){
            if(vertex[i] && distance[i] < value){
                index = i;
                value = distance[i];
            }
        }
        if(index == -1){
            break;
        }
        vertex[index] = false;
        for(int i = 0; i < n; i++){
            if(i == index || edge[index][i] == -1){
                continue;
            }
            if(distance[i] > distance[index] + edge[index][i]){
                distance[i] = distance[index] + edge[index][i];
                previous[i] = index;
            }
        }
    }

    std::vector<int> route;
    for(int i = goal; i != -1; i = previous[i]){
        route.push_back(i);
    }
    for(int i = route.size() - 1; i >= 0; i--){
        std::cout << route[i];
        if(i > 0){
            std::cout << " -> ";
        }
    }
    std::cout << std::endl;
    if(route[route.size() - 1] == start){
        std::cout << "distance: " << distance[goal] << std::endl;
    }else{
        std::cout << "unreachable" << std::endl;
    }

    for(int i = 0; i < n; i++){
        delete [] edge[i];
    }
    delete edge;

    delete previous;
    delete distance;
    delete vertex;

    return 0;
}

本質となる部分はかなり短い記述で表現できるようになっています。
入力を次のようにします。最初の行が頂点数 N で、頂点 0 から頂点 N-1 と数字が割り振られているとします。次の行に始点と終点、その後に頂点 i と頂点 j の辺とそのコストとします。

15
0 14
0 1 6
0 2 3
0 3 8
1 7 5
1 10 11
1 4 3
2 4 7
2 5 4
2 6 8
3 6 7
3 8 9
4 9 5
5 9 2
5 11 12
6 11 7
6 8 5
7 12 10
7 10 5
8 14 9
9 10 3
9 13 7
10 12 6
10 13 5
11 14 6
12 13 8
13 14 5

すると結果は次のようになります。

0 -> 2 -> 5 -> 9 -> 13 -> 14
distance: 21

次にこれを OCaml でも実装しました。殆ど C++ の翻訳となってしまっていますが...

exception UserError of string

let infty = 100000000

let rec input_edges l =
    try
        let x = Scanf.scanf "%d %d %d\n" (fun x y z -> ((x, y), z)) in
        input_edges (x :: l)
    with End_of_file -> l

let input () =
    let n = Scanf.scanf "%d\n" (fun x -> x) in
    let (s, g) = Scanf.scanf "%d %d\n" (fun x y -> (x, y)) in
    let edges = input_edges [] in
    (n, s, g, edges)

let rec make_0n_list n =
    if n < 0 then [] else n :: make_0n_list (n - 1)

let rec make_list x n =
    if n < 1 then [] else x :: make_list x (n - 1)

let rec remove z = function
    | [] -> []
    | x :: xs -> if x = z then xs else x :: remove z xs

let rec set z n = function
    | [] -> raise (UserError "index out of bounds")
    | x :: xs -> if n = 0 then z :: xs else x :: set z (n - 1) xs

let rec get_short vs ds =
    let rec sub (i, p) vs =
        match vs with
            | [] -> (i, p)
            | v :: vs -> if List.nth ds v < p then sub (v, List.nth ds v) vs else sub (i, p) vs
    in
    sub (-1, infty) vs

let rec sub_loop i distance previous = function
    | [] -> (distance, previous)
    | ((e0, e1), v) :: es ->
        if e0 = i || e1 = i then
            let j = e0 + e1 - i in
            if List.nth distance j > List.nth distance i + v then
                sub_loop i (set (List.nth distance i + v) j distance) (set i j previous) es
            else
                sub_loop i distance previous es
        else
            sub_loop i distance previous es

let rec main_loop edges previous distance vertex =
    let (i, _) = get_short vertex distance in
    if i = -1 then
        (previous, distance)
    else
        let vertex = remove i vertex in
        let (distance, previous) = sub_loop i distance previous edges in
        main_loop edges previous distance vertex

let rec get_route i ps =
    let j = List.nth ps i in
    if j = -1 then [i] else get_route j ps @ [i]

let rec join s = function
    | [] -> ""
    | [x] -> x
    | x :: xs -> x ^ s ^ join s xs

let _ =
    let (n, start, goal, edges) = input () in
    let previous = make_list (-1) n in
    let distance = set 0 start (make_list infty n) in
    let vertex = make_0n_list (n - 1) in
    let (previous, distance) = main_loop edges previous distance vertex in
    let route = get_route goal previous in
    print_string (join " -> " (List.map string_of_int route));
    print_string "\n";
    print_string (if List.hd route = start then "distance: " ^ string_of_int (List.nth distance goal) else "unreachable");
    print_string "\n"