Dijkstra’s algorithm in python: algorithms for beginners

Dijkstra’s algorithm finds the shortest path between two nodes in a graph. It’s a must-know for any programmer. Its Wikipedia page has nice animated GIFs and some history.

In this post I’ll use the time-tested implementation from Rosetta Code, changed a bit so that it can process both weighted and unweighted graph data and so that we can edit the graph on the fly. I’ll explain the code block by block.

The algorithm

The algorithm is pretty simple. Dijkstra created it in 20 minutes; now you can learn to code it in the same time.

  1. Mark all nodes unvisited and store them.
  2. Set the distance to zero for our initial node and to infinity for other nodes.
  3. Select the unvisited node with the smallest distance; it becomes the current node.
  4. Find the unvisited neighbors of the current node and calculate their distances through the current node. Compare each newly calculated distance to the currently assigned one and keep the smaller. For example, if node A has a distance of 6 and the A-B edge has length 2, then the distance to B through A is 6 + 2 = 8. If B was previously marked with a distance greater than 8, change it to 8.
  5. Mark the current node as visited and remove it from the unvisited set.
  6. Stop if the destination node has been visited (when planning a route between two specific nodes) or if the smallest distance among the unvisited nodes is infinity. Otherwise, repeat steps 3-6.
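To make step 4 concrete, here is a minimal sketch of the update for a single edge (often called "relaxation"), using the numbers from the example above. The `relax` helper and B's previous distance of 10 are hypothetical illustrations, not part of the Rosetta Code implementation.

```python
def relax(distances, previous, u, v, weight):
    """Step 4 for one edge u -> v: keep the smaller of the old and new distance."""
    candidate = distances[u] + weight
    if candidate < distances[v]:
        distances[v] = candidate
        previous[v] = u
        return True
    return False


# Numbers from step 4: A has distance 6 and the A-B edge has length 2.
# B's previous distance of 10 is an assumed value (anything greater than 8).
distances = {"A": 6, "B": 10}
previous = {"A": None, "B": None}

relax(distances, previous, "A", "B", 2)   # 6 + 2 = 8 < 10, so B is updated
print(distances["B"], previous["B"])      # 8 A

# A longer route (6 + 5 = 11) does not replace the stored 8.
print(relax(distances, previous, "A", "B", 5))  # False
```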

Python implementation

First, imports and data formats. The original implementation suggests using a namedtuple for storing edge data. We’ll do exactly that, but we’ll add a default value to the cost argument. There are many ways to do that; find the one that suits you best.

from collections import deque, namedtuple


# we'll use infinity as a default distance to nodes.
inf = float('inf')
Edge = namedtuple('Edge', 'start, end, cost')


def make_edge(start, end, cost=1):
    return Edge(start, end, cost)
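For comparison, here is a sketch of two other ways to give `cost` a default, assuming Python 3.7 or newer for namedtuple's `defaults` parameter. The names `EdgeNT` and `EdgeTyped` are hypothetical, chosen so they don't clash with `Edge`.

```python
from collections import namedtuple
from typing import NamedTuple

# Option 1: namedtuple's `defaults` argument (Python 3.7+).
# Defaults apply to the rightmost fields, so only `cost` gets the value 1.
EdgeNT = namedtuple('EdgeNT', 'start, end, cost', defaults=[1])


# Option 2: a typed NamedTuple class with a default field value.
class EdgeTyped(NamedTuple):
    start: str
    end: str
    cost: float = 1


# Both accept 2- or 3-item tuples, just like make_edge(*edge).
raw_edges = [("a", "b"), ("a", "c", 9)]
print([EdgeNT(*e) for e in raw_edges])
print([EdgeTyped(*e) for e in raw_edges])
```

The `make_edge` function in the post works on any Python 3 version; the `defaults` route keeps everything in one line, and `NamedTuple` adds type hints.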

Let’s initialize our data:

class Graph:
    def __init__(self, edges):
        # let's check that the data is right
        wrong_edges = [i for i in edges if len(i) not in [2, 3]]
        if wrong_edges:
            raise ValueError('Wrong edges data: {}'.format(wrong_edges))

        self.edges = [make_edge(*edge) for edge in edges]

Next, the vertices. In the original implementation the vertices are computed in __init__, but we need them to update when the edges change, so we’ll make them a property that is recomputed every time it is accessed. That’s probably not the best solution for big graphs, but it’s fine for small ones.

    @property
    def vertices(self):
        return set(
            # this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4]
            # the set above makes its elements unique.
            sum(
                ([edge.start, edge.end] for edge in self.edges), []
            )
        )
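To see the property behaviour described above, here is a minimal sketch with a hypothetical stripped-down `TinyGraph` class. It swaps the `sum(..., [])` flattening for `itertools.chain.from_iterable`, which flattens in linear time, whereas `sum` builds a new list on every concatenation and slows down as the edge list grows.

```python
from collections import namedtuple
from itertools import chain

Edge = namedtuple('Edge', 'start, end, cost')


class TinyGraph:
    """Hypothetical stripped-down graph: only edges and a vertices property."""

    def __init__(self, edges):
        self.edges = [Edge(*e) for e in edges]

    @property
    def vertices(self):
        # Recomputed on every access, so it always reflects self.edges.
        # chain.from_iterable flattens the (start, end) pairs lazily.
        return set(chain.from_iterable((e.start, e.end) for e in self.edges))


g = TinyGraph([("a", "b", 7), ("b", "c", 10)])
print(sorted(g.vertices))  # ['a', 'b', 'c']

g.edges.append(Edge("c", "z", 3))
print(sorted(g.vertices))  # ['a', 'b', 'c', 'z']
```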

Now let’s add methods for adding and removing edges.

    def get_node_pairs(self, n1, n2, both_ends=True):
        if both_ends:
            node_pairs = [[n1, n2], [n2, n1]]
        else:
            node_pairs = [[n1, n2]]
        return node_pairs

    def remove_edge(self, n1, n2, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        # iterate over a copy so that removing items doesn't skip any
        edges = self.edges[:]
        for edge in edges:
            if [edge.start, edge.end] in node_pairs:
                self.edges.remove(edge)

    def add_edge(self, n1, n2, cost=1, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        for edge in self.edges:
            if [edge.start, edge.end] in node_pairs:
                # raise, not return: a returned exception would be silently ignored
                raise ValueError('Edge {} {} already exists'.format(n1, n2))
        self.edges.append(Edge(start=n1, end=n2, cost=cost))
        if both_ends:
            self.edges.append(Edge(start=n2, end=n1, cost=cost))
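`remove_edge` loops over a copy (`self.edges[:]`) while removing from `self.edges`. This small standalone sketch, with hypothetical sample data, shows why: removing items from the list you are iterating over shifts the remaining items left, so the loop skips elements.

```python
from collections import namedtuple

Edge = namedtuple('Edge', 'start, end, cost')
edges = [Edge("a", "b", 1), Edge("b", "a", 1), Edge("a", "c", 2)]
node_pairs = [["a", "b"], ["b", "a"]]  # both_ends=True for the a-b edge

# Buggy: remove from the same list we iterate over.
# After a -> b is removed, b -> a slides into its slot and is never checked.
buggy = edges[:]
for edge in buggy:
    if [edge.start, edge.end] in node_pairs:
        buggy.remove(edge)
print(buggy)  # b -> a survives

# What remove_edge does: iterate over a copy, remove from the original.
fixed = edges[:]
for edge in fixed[:]:
    if [edge.start, edge.end] in node_pairs:
        fixed.remove(edge)
print(fixed)  # only a -> c is left
```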

Let’s find neighbors for every node:

    @property
    def neighbours(self):
        neighbours = {vertex: set() for vertex in self.vertices}
        for edge in self.edges:
            neighbours[edge.start].add((edge.end, edge.cost))
        return neighbours
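Because each edge is stored only under its start node, the graph is directed: the edge `("e", "f", 9)` makes `f` a neighbour of `e`, but not the other way round. The sketch below builds the same mapping for the post's sample data with `collections.defaultdict`, an assumed alternative to pre-populating the dict from `self.vertices`.

```python
from collections import defaultdict, namedtuple

Edge = namedtuple('Edge', 'start, end, cost')
edges = [Edge(*e) for e in [
    ("a", "b", 7), ("a", "c", 9), ("a", "f", 14), ("b", "c", 10),
    ("b", "d", 15), ("c", "d", 11), ("c", "f", 2), ("d", "e", 6),
    ("e", "f", 9)]]

# Same idea as the neighbours property: each edge is registered only
# under its start node, so edges are treated as one-way.
neighbours = defaultdict(set)
for edge in edges:
    neighbours[edge.start].add((edge.end, edge.cost))

print(sorted(neighbours["c"]))  # [('d', 11), ('f', 2)]
print(sorted(neighbours["f"]))  # [] because no edge starts at f
```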

It’s time for the algorithm! I renamed the variables so it would be easier to understand.

    def dijkstra(self, source, dest):
        assert source in self.vertices, "Such source node doesn't exist"
        # 1. Mark all nodes unvisited and store them.
        # 2. Set the distance to zero for our initial node
        # and to infinity for other nodes.
        distances = {vertex: inf for vertex in self.vertices}
        previous_vertices = {
            vertex: None for vertex in self.vertices
        }
        distances[source] = 0
        vertices = self.vertices.copy()
        # build the neighbours dict once instead of on every iteration
        neighbours = self.neighbours

        while vertices:
            # 3. Select the unvisited node with the smallest distance;
            # it becomes the current node.
            current_vertex = min(
                vertices, key=lambda vertex: distances[vertex])

            # 6. Stop if the smallest distance
            # among the unvisited nodes is infinity.
            if distances[current_vertex] == inf:
                break

            # 4. Find unvisited neighbors for the current node
            # and calculate their distances through the current node.
            for neighbour, cost in neighbours[current_vertex]:
                alternative_route = distances[current_vertex] + cost

                # Compare the newly calculated distance to the assigned one
                # and save the smaller one.
                if alternative_route < distances[neighbour]:
                    distances[neighbour] = alternative_route
                    previous_vertices[neighbour] = current_vertex

            # 5. Mark the current node as visited
            # and remove it from the unvisited set.
            vertices.remove(current_vertex)

        # walk back from dest through previous_vertices to rebuild the path
        path, current_vertex = deque(), dest
        while previous_vertices[current_vertex] is not None:
            path.appendleft(current_vertex)
            current_vertex = previous_vertices[current_vertex]
        if path:
            path.appendleft(current_vertex)
        return path
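Selecting the next vertex with `min()` scans every unvisited vertex, so each step costs O(V). A common alternative, swapped in here and not part of the Rosetta Code version, is a binary heap from the standard `heapq` module. The function below is a sketch assuming `neighbours` has the same shape as the `neighbours` property (node mapped to an iterable of `(neighbour, cost)` pairs); the name `dijkstra_heap` is hypothetical. It also applies the early stop from step 6 once `dest` is taken off the heap.

```python
import heapq
from collections import deque


def dijkstra_heap(neighbours, source, dest):
    """Sketch: Dijkstra with a binary heap instead of min() over a set."""
    distances = {source: 0}
    previous = {source: None}
    visited = set()
    heap = [(0, source)]
    while heap:
        dist, node = heapq.heappop(heap)
        if node in visited:
            continue  # stale entry: this node was already finalized
        visited.add(node)
        if node == dest:
            break  # step 6: the destination has been visited
        for neighbour, cost in neighbours.get(node, ()):
            alternative = dist + cost
            if alternative < distances.get(neighbour, float('inf')):
                distances[neighbour] = alternative
                previous[neighbour] = node
                heapq.heappush(heap, (alternative, neighbour))

    if dest not in previous:
        return deque()  # dest is unreachable from source
    path, node = deque(), dest
    while node is not None:
        path.appendleft(node)
        node = previous[node]
    return path


# The post's sample data as a plain adjacency dict.
neighbours = {
    "a": [("b", 7), ("c", 9), ("f", 14)],
    "b": [("c", 10), ("d", 15)],
    "c": [("d", 11), ("f", 2)],
    "d": [("e", 6)],
    "e": [("f", 9)],
}
print(dijkstra_heap(neighbours, "a", "e"))  # deque(['a', 'c', 'd', 'e'])
```

Unlike the class method, this sketch returns `deque([source])` when `source == dest`, and an empty deque when `dest` can't be reached.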

Let’s use it.

graph = Graph([
    ("a", "b", 7),  ("a", "c", 9),  ("a", "f", 14), ("b", "c", 10),
    ("b", "d", 15), ("c", "d", 11), ("c", "f", 2),  ("d", "e", 6),
    ("e", "f", 9)])

print(graph.dijkstra("a", "e"))
# Output: deque(['a', 'c', 'd', 'e'])
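As a sanity check, here is a small sketch with a hypothetical `path_cost` helper that adds up the edge costs along the returned path. It also shows why the seemingly cheaper route a-c-f-e is not chosen: the data only contains an e-to-f edge, and since the constructor stores each edge exactly as given, the graph is directed.

```python
costs = {
    ("a", "b"): 7, ("a", "c"): 9, ("a", "f"): 14, ("b", "c"): 10,
    ("b", "d"): 15, ("c", "d"): 11, ("c", "f"): 2, ("d", "e"): 6,
    ("e", "f"): 9,
}


def path_cost(path):
    """Sum the edge costs between consecutive nodes of a path."""
    nodes = list(path)
    return sum(costs[(u, v)] for u, v in zip(nodes, nodes[1:]))


print(path_cost(['a', 'c', 'd', 'e']))  # 9 + 11 + 6 = 26

# a -> c -> f -> e would cost 9 + 2 + 9 = 20, but there is no f -> e edge.
print(("f", "e") in costs)  # False
```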

The whole code from above:

from collections import deque, namedtuple


# we'll use infinity as a default distance to nodes.
inf = float('inf')
Edge = namedtuple('Edge', 'start, end, cost')


def make_edge(start, end, cost=1):
    return Edge(start, end, cost)


class Graph:
    def __init__(self, edges):
        # let's check that the data is right
        wrong_edges = [i for i in edges if len(i) not in [2, 3]]
        if wrong_edges:
            raise ValueError('Wrong edges data: {}'.format(wrong_edges))

        self.edges = [make_edge(*edge) for edge in edges]

    @property
    def vertices(self):
        return set(
            sum(
                ([edge.start, edge.end] for edge in self.edges), []
            )
        )

    def get_node_pairs(self, n1, n2, both_ends=True):
        if both_ends:
            node_pairs = [[n1, n2], [n2, n1]]
        else:
            node_pairs = [[n1, n2]]
        return node_pairs

    def remove_edge(self, n1, n2, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        edges = self.edges[:]
        for edge in edges:
            if [edge.start, edge.end] in node_pairs:
                self.edges.remove(edge)

    def add_edge(self, n1, n2, cost=1, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        for edge in self.edges:
            if [edge.start, edge.end] in node_pairs:
                raise ValueError('Edge {} {} already exists'.format(n1, n2))
        self.edges.append(Edge(start=n1, end=n2, cost=cost))
        if both_ends:
            self.edges.append(Edge(start=n2, end=n1, cost=cost))

    @property
    def neighbours(self):
        neighbours = {vertex: set() for vertex in self.vertices}
        for edge in self.edges:
            neighbours[edge.start].add((edge.end, edge.cost))
        return neighbours

    def dijkstra(self, source, dest):
        assert source in self.vertices, "Such source node doesn't exist"
        distances = {vertex: inf for vertex in self.vertices}
        previous_vertices = {
            vertex: None for vertex in self.vertices
        }
        distances[source] = 0
        vertices = self.vertices.copy()
        neighbours = self.neighbours  # built once, not on every iteration

        while vertices:
            current_vertex = min(
                vertices, key=lambda vertex: distances[vertex])
            vertices.remove(current_vertex)
            if distances[current_vertex] == inf:
                break
            for neighbour, cost in neighbours[current_vertex]:
                alternative_route = distances[current_vertex] + cost
                if alternative_route < distances[neighbour]:
                    distances[neighbour] = alternative_route
                    previous_vertices[neighbour] = current_vertex

        path, current_vertex = deque(), dest
        while previous_vertices[current_vertex] is not None:
            path.appendleft(current_vertex)
            current_vertex = previous_vertices[current_vertex]
        if path:
            path.appendleft(current_vertex)
        return path


graph = Graph([
    ("a", "b", 7),  ("a", "c", 9),  ("a", "f", 14), ("b", "c", 10),
    ("b", "d", 15), ("c", "d", 11), ("c", "f", 2),  ("d", "e", 6),
    ("e", "f", 9)])

print(graph.dijkstra("a", "e"))

P.S. For those of us who, like me, read more books about the Witcher than about algorithms, it’s Edsger Dijkstra, not Sigismund.

Source: dev