Elixir: Dijkstra's Algorithm with Priority Queue


Implementing algorithms from Wikipedia pseudo code in Elixir can be frustrating for two reasons:

  1. Pseudo code from classic textbooks is often imperative pseudo code, not functional.
  2. Elixir lacks built-in data types such as ordered lists.

I stumbled over this while working on the Advent of Code 2022 day 12 challenge. This post shows how to implement the algorithm in Elixir without additional libraries. Bonus: visualisation using VegaLite.

Example: Dijkstra's Algorithm

Let's see how those problems can be tackled for Dijkstra's algorithm, which finds the shortest path in a graph. In short, we'll solve the problems as follows:

  1. Use a functional reduce to "emulate" changing state in the imperative pseudo code.
  2. Build a tiny custom priority queue based on the gb_sets ordered set data type from the Erlang standard library.

Wikipedia lists this pseudo code for the priority queue based solution:

function Dijkstra(Graph, source):
  dist[source] ← 0                     // Initialization
  create vertex priority queue Q

  for each vertex v in Graph.Vertices:
    if v ≠ source
      dist[v] ← INFINITY                 // Unknown distance from source to v
      prev[v] ← UNDEFINED                // Predecessor of v

    Q.add_with_priority(v, dist[v])

  while Q is not empty:                  // The main loop
    u ← Q.extract_min()                  // Remove and return best vertex
    for each neighbor v of u:            // Go through all v neighbors of u
      alt ← dist[u] + Graph.Edges(u, v)
      if alt < dist[v]:
        dist[v] ← alt
        prev[v] ← u
        Q.decrease_priority(v, alt)

  return dist, prev

Erlang standard modules are easy to overlook in Elixir, because they are not listed in hexdocs.pm. That doesn't mean they can't be used, but this is a slight barrier. Also, many Erlang modules don't lend themselves to Elixir pipelines, because they don't follow the Elixir standard of using the module data type as first parameter. In gb_sets the set is always the last parameter.

Update: Erlang/OTP has since moved its docs to ExDoc, hosted on erlang.org/doc. So while they are still not on hexdocs.pm, the documentation is now far more accessible for Elixir devs.

Implementing a Priority Queue

With that in mind, we can implement our own priority queue based on ordered gb_sets. Complete with nice terminal output by implementing the Inspect protocol.

defmodule PrioQ do
  @moduledoc """
  A tiny priority queue backed by an ordered `:gb_sets` set.

  Entries are stored as `{priority, element}` tuples. Because the set keeps
  its members in term order, the entry with the smallest priority is always
  the first one.
  """

  defstruct [:set]

  @doc "Returns an empty queue."
  def new(), do: wrap(:gb_sets.empty())

  @doc "Builds a queue from a list of `{priority, element}` tuples."
  def new([]), do: new()
  def new([{_prio, _elem} | _rest] = entries), do: wrap(:gb_sets.from_list(entries))

  @doc "Returns a queue that additionally contains `elem` with priority `prio`."
  def add_with_priority(%__MODULE__{set: set} = q, elem, prio) do
    %{q | set: :gb_sets.add({prio, elem}, set)}
  end

  @doc "Returns the number of entries in the queue."
  def size(%__MODULE__{set: set}), do: :gb_sets.size(set)

  @doc """
  Removes the entry with the smallest priority.

  Returns `{{prio, elem}, remaining_queue}`, or `:empty` when the queue has
  no entries.
  """
  def extract_min(%__MODULE__{set: set} = q) do
    if :gb_sets.is_empty(set) do
      :empty
    else
      {{prio, elem}, remaining} = :gb_sets.take_smallest(set)
      {{prio, elem}, %{q | set: remaining}}
    end
  end

  # Wraps a raw gb_set in the queue struct.
  defp wrap(set), do: %__MODULE__{set: set}

  defimpl Inspect do
    import Inspect.Algebra

    # Renders the queue as `#PrioQ.new([...])` with its entries in priority order.
    def inspect(%PrioQ{set: set}, opts) do
      entries = to_doc(:gb_sets.to_list(set), opts)
      concat(["#PrioQ.new(", entries, ")"])
    end
  end
end

The set contains {priority, element} tuples. Erlang orders tuples based on their elements from left to right, so the tuple with the lowest priority comes first. A tuple with nil priority comes last, as nil is considered larger than all integer priorities.

From Pseudo Code to Code

This implementation translates the pseudo code to Elixir nearly one-to-one. It uses the alternatives mentioned on Wikipedia below the pseudo code to simplify things:

Instead of filling the priority queue with all nodes in the initialization phase, it is also possible to initialize it to contain only source; then, inside the if alt < dist[v] block, the decrease_priority() becomes an add_with_priority() operation if the node is not already in the queue.

Yet another alternative is to add nodes unconditionally to the priority queue and to instead check after extraction that no shorter connection was found yet. This can be done by additionally extracting the associated priority p from the queue and only processing further if p == dist[u] inside the while Q is not empty loop.

defmodule Dijkstra do
  @moduledoc """
  Dijkstra's shortest-path algorithm on top of the `PrioQ` priority queue.

  This is the "add nodes unconditionally, skip outdated entries after
  extraction" variant of the Wikipedia pseudo code. The numbered comments in
  the code refer to the lines of that pseudo code.
  """

  @doc """
  Computes the shortest distances from `start` to every node reachable from it.

    * `get_neighbours` - 1-arity function returning the list of neighbours of a node.
    * `get_distance` - 2-arity function returning the edge weight between a
      node and one of its neighbours.
    * `on_step` - optional 1-arity callback. It is called with
      `{node, dist, prev}` each time a node is taken from the queue for
      processing, and once with `{:done, dist, prev}` when the search has
      finished. Its return value is ignored. Defaults to a no-op.

  Returns `{dist, prev}`, where `dist` maps each reached node to its distance
  from `start` and `prev` maps each reached node (except `start`) to its
  predecessor on a shortest path.
  """
  def dijkstra(start, get_neighbours, get_distance, on_step \\ fn _step -> :ok end) do
    # 2      dist[source] ← 0                           // Initialization
    # 8              dist[v] ← INFINITY                 // Unknown distance from source to v
    # Unknown distances are simply absent from the map (see the comparison below).
    dist = %{start => 0}
    # 4      create vertex priority queue Q
    q = PrioQ.new([{0, start}])
    # 9              prev[v] ← UNDEFINED                // Predecessor of v
    prev = %{}

    loop(q, dist, prev, get_neighbours, get_distance, on_step)
  end

  defp loop(q, dist, prev, get_neighbours, get_distance, on_step) do
    # 14     while Q is not empty:                      // The main loop
    case q |> PrioQ.extract_min() |> check_outdated(dist) do
      :empty ->
        on_step.({:done, dist, prev})
        #  23     return dist, prev
        {dist, prev}

      {:outdated, q} ->
        # Continue with the queue that no longer contains the stale entry;
        # recursing with the old queue would extract the same entry forever.
        loop(q, dist, prev, get_neighbours, get_distance, on_step)

      # 15         u ← Q.extract_min()                    // Remove and return best vertex
      {:ok, u, q} ->
        on_step.({u, dist, prev})

        {dist, prev, q} =
          # 16         for each neighbor v of u:              // Go through all v neighbors of u
          for v <- get_neighbours.(u),
              reduce: {dist, prev, q} do
            {dist, prev, q} ->
              # 17             alt ← dist[u] + Graph.Edges(u, v)
              alt = dist[u] + get_distance.(u, v)

              # 18             if alt < dist[v]:
              # 19                 dist[v] ← alt
              # 20                 prev[v] ← u
              # 21                 Q.decrease_priority(v, alt)
              # For an unseen node dist[v] is nil, which sorts after every
              # number in term order, so nil plays the role of INFINITY here.
              if alt < dist[v],
                do: {Map.put(dist, v, alt), Map.put(prev, v, u), PrioQ.add_with_priority(q, v, alt)},
                else: {dist, prev, q}
          end

        loop(q, dist, prev, get_neighbours, get_distance, on_step)
    end
  end

  # Classifies the result of `PrioQ.extract_min/1`.
  #
  # An extracted entry is outdated when a shorter distance for its node was
  # recorded after the entry was queued. The remaining queue is returned in
  # both cases so the caller always makes progress.
  defp check_outdated({{prio, u}, q}, dist) do
    if prio == dist[u], do: {:ok, u, q}, else: {:outdated, q}
  end

  defp check_outdated(:empty, _dist), do: :empty
end

Visualizing using VegaLite

As a bonus, let's re-build the Wikipedia example animation using kino_vega_lite.

Here's the entire livebook code.

defmodule DijkstraChart do
  @moduledoc """
  Animates a Dijkstra search on a grid with obstacles using VegaLite in Livebook.
  """

  # NOTE(review): `Vl` is assumed to be the usual short alias for the VegaLite
  # package; declared here so the module does not depend on an alias set up elsewhere.
  alias VegaLite, as: Vl

  # Blocked grid cells (drawn as black squares, never returned as neighbours).
  @block_1 for x <- 6..16, y <- 14..16, into: MapSet.new(), do: {x, y}
  @block_2 for x <- 14..16, y <- 9..13, into: MapSet.new(), do: {x, y}
  @block MapSet.union(@block_1, @block_2)
  @start {1, 1}
  @goal {18, 18}
  # Pause in milliseconds after each visited node is pushed to the chart.
  @interval 25

  @doc """
  Renders the chart and runs the search from the start cell, streaming every
  step into it.

  Returns whatever `Dijkstra.dijkstra/4` returns.
  """
  def chart() do
    chart =
      Vl.new(width: 400, height: 400)
      |> Vl.layers([
        # Layer 1: the obstacles.
        Vl.new()
        |> Vl.mark(:point, filled: true, size: 400, shape: :square, color: :black)
        |> Vl.data_from_values(Enum.map(@block, fn {x, y} -> %{x: x, y: y} end))
        |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [1, 20]])
        |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [1, 20]]),
        # Layer 2: visited nodes, coloured by distance; filled via the :distances dataset.
        Vl.new(data: [name: :distances])
        |> Vl.mark(:point, filled: true, size: 200, shape: :circle)
        |> Vl.encode_field(:x, "x", type: :quantitative)
        |> Vl.encode_field(:y, "y", type: :quantitative)
        |> Vl.encode_field(:color, "color",
          type: :quantitative,
          legend: false,
          scale: [scheme: :redyellowgreen, domain: [0, 30]]
        ),
        # Layer 3: the final path; filled via the :path dataset.
        Vl.new(data: [name: :path])
        |> Vl.mark(:line, stroke_width: 10)
        |> Vl.encode_field(:x, "x", type: :quantitative)
        |> Vl.encode_field(:y, "y", type: :quantitative),
        # Layer 4: markers for the start and goal cells.
        Vl.new()
        |> Vl.mark(:point, filled: true, size: 600, shape: "triangle-up", color: :blue)
        |> Vl.data_from_values(Enum.map([@start, @goal], fn {x, y} -> %{x: x, y: y} end))
        |> Vl.encode_field(:x, "x", type: :quantitative)
        |> Vl.encode_field(:y, "y", type: :quantitative)
      ])
      |> Kino.VegaLite.new()
      |> Kino.render()

    Process.sleep(1000)
    Dijkstra.dijkstra(@start, &get_neighbours/1, &get_distance/2, &on_step(chart, &1))
  end

  @doc """
  Step callback for the search.

  For `{{x, y}, dist, prev}` it pushes the node with its distance to the
  `:distances` dataset. For `{:done, dist, prev}` it pushes the path between
  the start and goal cells, point by point, to the `:path` dataset.
  """
  def on_step(chart, {{x, y}, dist, _prev}) do
    Kino.VegaLite.push(chart, %{x: x, y: y, color: dist[{x, y}]}, dataset: :distances)
    Process.sleep(@interval)
  end

  def on_step(chart, {:done, _, prev}) do
    for {x, y} <- get_path(prev, @goal, @start) do
      Kino.VegaLite.push(chart, %{x: x, y: y}, dataset: :path)
      Process.sleep(50)
    end
  end

  @doc """
  Returns the cells reachable from `{x, y}` in one step.

  Steps go right, up and diagonally right-up; blocked cells and cells beyond
  coordinate 20 are excluded.
  """
  def get_neighbours({x, y}) do
    [{1, 0}, {0, 1}, {1, 1}]
    |> Enum.map(fn {x2, y2} -> {x + x2, y + y2} end)
    |> Enum.filter(&(not MapSet.member?(@block, &1)))
    |> Enum.filter(fn {x, y} -> x <= 20 && y <= 20 end)
  end

  @doc "Every step costs 1, regardless of direction."
  def get_distance(_, _), do: 1

  @doc """
  Follows the `prev` links from `current` back until `goal` is reached.

  Nodes are prepended while walking, so the result lists `goal` first and the
  initial `current` last. Returns `[]` when the chain of predecessors ends
  before `goal` is reached, i.e. when the two nodes are not connected in `prev`.
  """
  def get_path(prev, current, goal, path \\ [])

  def get_path(_prev, current, current, path), do: [current | path]

  # No predecessor recorded: stop instead of looking up `prev[nil]` forever.
  def get_path(_prev, nil, _goal, _path), do: []

  def get_path(prev, current, goal, path) do
    get_path(prev, prev[current], goal, [current | path])
  end
end

# Render the chart and start the animated search.
DijkstraChart.chart()