Data Visualization in Python with matplotlib, Seaborn and Bokeh

Originally posted on machinelearningmastery.

Data visualization is an important aspect of all AI and machine learning applications. You can gain key insights into your data through different graphical representations. In this tutorial, we’ll talk about a few options for data visualization in Python. We’ll use the MNIST dataset and the TensorFlow library for number crunching and data manipulation. To illustrate various methods for creating different types of graphs, we’ll use Python’s graphing libraries, namely matplotlib, Seaborn and Bokeh.

After completing this tutorial, you will know:

  • How to visualize images in matplotlib
  • How to make scatter plots in matplotlib, Seaborn and Bokeh
  • How to make multiline plots in matplotlib, Seaborn and Bokeh

Let’s get started.


Tutorial Overview

This tutorial is divided into 7 parts; they are:

  • Preparation of scatter data
  • Figures in matplotlib
  • Scatter plots in matplotlib and Seaborn
  • Scatter plots in Bokeh
  • Preparation of line plot data
  • Line plots in matplotlib, Seaborn, and Bokeh
  • More on visualization

Preparation of scatter data

In this post, we will use matplotlib, Seaborn, and Bokeh. They are all external libraries that need to be installed. To install them using pip, run the following command:

pip install matplotlib seaborn bokeh

For demonstration purposes, we will also use the MNIST handwritten digits dataset. We will load it from TensorFlow and run the PCA algorithm on it. Hence we will also need to install TensorFlow and pandas:

pip install tensorflow pandas

The code that follows assumes the following imports have been executed:

We load the MNIST dataset from the keras.datasets module. To keep things simple, we’ll retain only the subset of data containing the first three digits (0, 1 and 2). We’ll also ignore the test set for now.
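The filtering step can be sketched with a NumPy boolean mask. This is a minimal sketch, not the tutorial’s own code: small random arrays stand in for the MNIST training set, and the helper name keep_digits is hypothetical. With the real data, x_train and y_train would come from tf.keras.datasets.mnist.load_data().

```python
import numpy as np

def keep_digits(x, y, digits=(0, 1, 2)):
    """Hypothetical helper: keep only the samples whose label is in `digits`."""
    mask = np.isin(y, digits)  # True where the label is one of the wanted digits
    return x[mask], y[mask]

# Random stand-in for the MNIST training set: 100 grayscale images of 28x28 pixels
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8)
y_train = rng.integers(0, 10, size=100)

# With the real data you would instead run:
# (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

x_small, y_small = keep_digits(x_train, y_train)
print(x_small.shape, np.unique(y_small))
```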



Figures in matplotlib

Seaborn is built on top of matplotlib. Therefore, you need to understand how matplotlib handles plots even if you’re using Seaborn.

Matplotlib calls its canvas the figure. You can divide the figure into several sections called subplots, so you can put two visualizations side-by-side.

As an example, let’s visualize the first 16 images of our MNIST dataset using matplotlib. We’ll create 2 rows and 8 columns using the subplots() function, which creates an axes object for each cell of the grid. Then we display each image on its own axes object using the imshow() method. Finally, the figure is shown using the show() function.


First 16 images of the training dataset displayed in 2 rows and 8 columns



Here we can see a few properties of matplotlib. There is a default figure and a default axes in matplotlib, and the pyplot submodule defines a number of functions for plotting on the default axes. If we want to plot on a particular axes, we can use the plotting methods of that axes object instead. Manipulating a figure is procedural: matplotlib keeps an internal data structure, and our operations mutate it. The show() function simply displays the result of a series of operations. Because of that, we can gradually fine-tune many details of the figure. In the example above, we hid the “ticks” (i.e., the markers on the axes) by setting xticks and yticks to empty lists.
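A minimal sketch of the 2 x 8 image grid described above. Random 28x28 arrays stand in for the first 16 MNIST images (in the tutorial these would be x_train[:16]), and the Agg backend is selected only so the sketch runs without a display; remove that line and call plt.show() to see the window.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Random stand-in for the first 16 MNIST images (28x28 grayscale)
rng = np.random.default_rng(0)
images = rng.random((16, 28, 28))

# subplots() returns the figure and a 2x8 array of axes objects
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for img, ax in zip(images, axes.ravel()):
    ax.imshow(img, cmap="gray")   # draw on this particular axes, not the default one
    ax.set(xticks=[], yticks=[])  # hide the tick markers

fig.tight_layout()
# plt.show()  # display the figure in an interactive session
```

Because each call mutates the figure’s internal state, details such as the ticks can be adjusted at any point before show() is called.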

Scatter plots in matplotlib and Seaborn

One of the common visualizations we use in machine learning projects is the scatter plot.

As an example, we apply PCA to the MNIST dataset and extract the first three principal components of each image. Specifically, we compute the eigenvectors and eigenvalues from the dataset, then project the data of each image along the direction of the eigenvectors, and store the result in x_pca. For simplicity, we didn’t normalize the data to zero mean and unit variance before computing the eigenvectors. This omission does not affect our purpose of visualization.
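The projection can be sketched with NumPy. The tutorial itself does this computation with TensorFlow; this sketch swaps in np.linalg.eigh and uses a random 200 x 784 matrix as a stand-in for the flattened training images. As in the text, the data is not centered, so the decomposition is of x^T x.

```python
import numpy as np

# Random stand-in for the flattened training images: 200 samples x 784 pixels
rng = np.random.default_rng(0)
x = rng.random((200, 784))

# eigh() is meant for symmetric matrices such as x^T x. It returns the
# eigenvalues in ascending order and the eigenvectors as columns.
eigenvalues, eigenvectors = np.linalg.eigh(x.T @ x)

# Project every image onto all eigenvectors. Because of the ascending order,
# the last columns of x_pca correspond to the largest eigenvalues.
x_pca = x @ eigenvectors
print(x_pca.shape)
print(eigenvalues[-3:])  # the three largest eigenvalues
```

Since the eigenvectors are orthonormal, multiplying x_pca by the transposed eigenvector matrix recovers the original data.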

The eigenvalues printed are as follows:

The array x_pca has shape 18623 x 784. Let’s use the last two columns as the x- and y-coordinates and plot each row as a point. We can further color each point according to the digit it corresponds to.

A scatter plot in matplotlib is created using the axes object’s scatter() function, which takes the x- and y-coordinates as its first two arguments. The c argument to scatter() specifies a value that is mapped to each point’s color, and the s argument specifies the point size. We also create a legend and add a title to the plot.
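A minimal sketch of such a scatter plot, assuming three synthetic clusters in place of x_pca[:, -1], x_pca[:, -2] and the digit labels. The legend is built with the PathCollection.legend_elements() helper, which is one way to get one entry per class; the tutorial may build its legend differently.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in: 300 points in three clusters labeled 0, 1, 2
rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
xs = rng.normal(loc=labels * 3.0)
ys = rng.normal(loc=labels * -2.0)

fig, ax = plt.subplots(figsize=(8, 6))
# c is mapped to a color through the colormap; s sets the marker size
scatter = ax.scatter(xs, ys, c=labels, s=5, cmap="viridis")
# legend_elements() returns one handle/label pair per distinct value of c
handles, names = scatter.legend_elements()
ax.legend(handles, names, title="Digit")
ax.set_title("2D scatter plot of synthetic data")
# plt.show()
```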


2D scatter plot generated using Matplotlib



Putting it all together, the following is the complete code to generate the 2D scatter plot using matplotlib:

Matplotlib also allows a 3D scatter plot to be produced. To do so, you need to create an axes object with a 3D projection first. Then the 3D scatter plot is created with the scatter3D() function, with the x-, y-, and z-coordinates as the first three arguments. Here we use the data projected along the eigenvectors corresponding to the three largest eigenvalues. Instead of creating a legend, we create a colorbar.
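A minimal sketch of a 3D scatter plot with a colorbar, assuming synthetic clusters in place of the three projections with the largest eigenvalues. The Axes3D import is only needed to register the "3d" projection on older matplotlib versions.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers "3d" on old versions)

rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
# Synthetic stand-ins for the three projected coordinates
xs = rng.normal(loc=labels * 3.0)
ys = rng.normal(loc=labels * -2.0)
zs = rng.normal(loc=labels * 1.5)

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection="3d")  # axes object with a 3D projection
sc = ax.scatter3D(xs, ys, zs, c=labels, s=5, cmap="viridis")
fig.colorbar(sc, ax=ax, label="Digit")  # a colorbar instead of a legend
ax.set_title("3D scatter plot of synthetic data")
# plt.show()
```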


3D scatter plot generated using Matplotlib



The scatter3D() function just puts the points into the 3D space. Afterwards, we can still modify how the figure is displayed, such as the label of each axis and the background color. But in 3D plots, one common tweak is the viewport, namely, the angle from which we look at the 3D space. The viewport is controlled by the view_init() function of the axes object.

The viewport is controlled by the elevation angle (i.e., the angle to the horizontal plane) and the azimuthal angle (i.e., the rotation in the horizontal plane). By default, matplotlib uses an elevation of 30 degrees and an azimuth of -60 degrees, as in the plot above.
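Changing the viewport can be sketched as follows. The three plotted points and the chosen angles (elev=10, azim=45) are arbitrary values picked for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers "3d" on old versions)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter3D([0, 1, 2], [0, 1, 0], [0, 1, 2])
print("default view:", ax.elev, ax.azim)

# Look from 10 degrees above the horizontal plane, rotated 45 degrees around z
ax.view_init(elev=10, azim=45)
# plt.show()
```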

Putting everything together, the following is the complete code to create the 3D scatter plot in matplotlib:

Creating scatter plots in Seaborn is similarly easy. The scatterplot() method automatically creates a legend and uses different symbols for different classes when plotting the points. By default, the plot is created on the “current axes” from matplotlib, unless the axes object is specified by the ax argument.
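A minimal sketch with Seaborn, assuming synthetic clusters in place of the PCA coordinates. Here hue colors the points by class and style varies the marker by class; converting the labels to strings makes Seaborn treat them as categories rather than a continuous scale. The target axes is chosen with ax=.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100).astype(str)  # categorical labels "0", "1", "2"
xs = rng.normal(loc=np.repeat([0.0, 3.0, 6.0], 100))
ys = rng.normal(loc=np.repeat([0.0, -2.0, -4.0], 100))

fig, ax = plt.subplots(figsize=(8, 6))
# Seaborn draws on the axes passed via ax= and builds the legend itself
sns.scatterplot(x=xs, y=ys, hue=labels, style=labels, s=20, ax=ax)
ax.set_title("2D scatter plot with Seaborn")
# plt.show()
```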


2D scatter plot generated using Seaborn



The benefit of Seaborn over matplotlib is twofold. First, it has a polished default style. For example, if we compare the point style in the two scatter plots above, the Seaborn one draws a border around each dot to keep the many points from smearing together. Indeed, if we apply Seaborn’s style before calling any matplotlib functions, we can still use the matplotlib functions but get a better-looking figure.

Second, it is more convenient to use Seaborn if we hold our data in a pandas DataFrame. As an example, let’s convert our MNIST data from a tensor into a pandas DataFrame:

The DataFrame looks like the following:

Then we can reproduce Seaborn’s scatter plot with the following:

Here we do not pass arrays as coordinates to the scatterplot() function; instead, we pass the DataFrame as the data argument and refer to its columns by name.
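A minimal sketch of the DataFrame-based call. The column names pc1, pc2 and label are hypothetical, and the values are synthetic stand-ins for the PCA coordinates. sns.set_theme() (Seaborn 0.11 and later) applies Seaborn’s style to every subsequent matplotlib figure, which is the styling benefit mentioned above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

sns.set_theme()  # Seaborn's default style for all later matplotlib figures

rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
# Hypothetical column names; the values are synthetic
df = pd.DataFrame({
    "pc1": rng.normal(loc=labels * 3.0),
    "pc2": rng.normal(loc=labels * -2.0),
    "label": labels.astype(str),
})
print(df.head())

fig, ax = plt.subplots(figsize=(8, 6))
# x, y and hue are column names looked up in the DataFrame given as data
sns.scatterplot(data=df, x="pc1", y="pc2", hue="label", style="label", s=20, ax=ax)
# plt.show()
```

Seaborn also uses the column names as axis labels and as the legend title.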

The following is the complete code to generate a scatter plot using Seaborn with the data stored in pandas:

Seaborn, as a wrapper around some matplotlib functions, does not replace matplotlib entirely. Plotting in 3D, for example, is not supported by Seaborn, and we still need to resort to matplotlib functions for such purposes.

Scatter plots in Bokeh

The plots created by matplotlib and Seaborn are static images. If you need to zoom in, pan, or toggle the display of some part of the plot, you should use Bokeh instead.

Creating scatter plots in Bokeh is also easy: we generate a scatter plot and add a legend. The show() function from the Bokeh library opens a new browser window to display the plot. You can interact with the plot by scaling, zooming, scrolling and other options shown in the toolbar next to the rendered plot. You can also hide part of the scatter by clicking on the legend.

Bokeh produces the plot as HTML with JavaScript. All your actions to control the plot are handled by JavaScript functions. Its output looks like the following:

2D scatter plot generated using Bokeh in a new browser window. Note the various options on the right for interacting with the plot.


The following is the complete code to generate the above scatter plot using Bokeh:

If you are rendering the Bokeh plot in a Jupyter notebook, you may see the plot produced in a new browser window. To put the plot inside the notebook, you need to tell Bokeh that you are in a notebook environment by calling output_notebook() from bokeh.io before the Bokeh functions.

Also note that we create the scatter plot of the three digits in a loop, one digit at a time. This is required to make the legend interactive, since each call to scatter() creates a new renderer object. If we create all the scatter points at once in a single scatter() call, clicking on the legend will hide and show everything instead of only the points of one of the digits.
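A minimal sketch of the per-digit loop, assuming synthetic clusters in place of the PCA coordinates; the colors are arbitrary. Setting legend.click_policy to "hide" is what lets a click on a legend entry toggle that digit’s renderer. The show() call is left commented out so the sketch does not open a browser.

```python
import numpy as np
from bokeh.plotting import figure, show

rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
xs = rng.normal(loc=labels * 3.0)
ys = rng.normal(loc=labels * -2.0)

p = figure(title="2D scatter plot with Bokeh", x_axis_label="x", y_axis_label="y")
colors = ["navy", "firebrick", "olive"]
# One scatter() call per digit gives one renderer (and one legend entry) per digit
for digit, color in zip([0, 1, 2], colors):
    mask = labels == digit
    p.scatter(xs[mask], ys[mask], size=5, color=color, alpha=0.6,
              legend_label=f"Digit {digit}")
p.legend.click_policy = "hide"  # clicking an entry hides only that renderer
# show(p)  # opens the interactive plot in a browser
```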


Preparation of line plot data

Before we move on to show how we can visualize line plot data, let’s generate some data for illustration. We use a simple classifier built with the Keras library and train it to learn handwritten digit classification. The History object returned by the fit() method has a history attribute, a dictionary that contains the learning history of the training stage. For simplicity, we’ll train the model for only 10 epochs.
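A minimal sketch of such a training run, assuming TensorFlow 2.x with its bundled Keras. Random arrays stand in for the MNIST images, the layer sizes are arbitrary, training runs for 3 epochs rather than 10 to keep it quick, and validation_split=0.2 is an assumption used here to produce the val_* entries.

```python
import numpy as np
import tensorflow as tf

# Random stand-in for flattened, scaled MNIST images and their labels 0-9
rng = np.random.default_rng(0)
x = rng.random((256, 784)).astype("float32")
y = rng.integers(0, 10, size=256)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# fit() returns a History object; .history maps each metric to a per-epoch list.
# validation_split holds out the last 20% of the samples for the val_* metrics.
history = model.fit(x, y, epochs=3, validation_split=0.2, verbose=0)
print(sorted(history.history))
```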

Training produces a dictionary with the keys loss, accuracy, val_loss, and val_accuracy, as follows:


Line plots in matplotlib, Seaborn, and Bokeh

Let’s look at various options for visualizing the learning history obtained from training our classifier.

Creating a multi-line plot in matplotlib is trivial. We obtain the lists of training and validation accuracies from the history, and by default, matplotlib treats each list as sequential data (i.e., the x-coordinates are integers counting from 0 onwards).
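A minimal sketch of the multi-line plot. The dictionary below holds made-up placeholder curves standing in for history.history["accuracy"] and history.history["val_accuracy"]; they are not real training results.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Placeholder values only, not real training output
history = {
    "accuracy": [0.80 + 0.015 * i for i in range(10)],
    "val_accuracy": [0.78 + 0.012 * i for i in range(10)],
}

fig, ax = plt.subplots()
# Only y values are given, so matplotlib uses x = 0, 1, 2, ... automatically
ax.plot(history["accuracy"], label="Training accuracy")
ax.plot(history["val_accuracy"], label="Validation accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend()
# plt.show()
```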


Multi-line plot using Matplotlib



The complete code for creating the multi-line plot is as follows:

Similarly, we can do the same in Seaborn. As we have seen in the case of scatter plot, we can pass in the data to Seaborn as a series of values explicitly, or through a pandas DataFrame. Let’s plot the training loss and validation loss in the following using a pandas DataFrame:

It will print the following table, which is the DataFrame we created from the history:

And the plot it generates is as follows:

Multi-line plot using Seaborn


By default, Seaborn takes the column labels from the DataFrame and uses them as the legend. Here, we provide a new label for each plot instead. Moreover, the x-axis of the line plot is taken from the index of the DataFrame by default, which consists of integers running from 0 to 9 in our case, as we can see above.

The complete code of producing the plot in Seaborn is as follows:

As you might expect, we can also provide the arguments x and y together with data in our call to lineplot(), as in our Seaborn scatter plot example above, if we want to control the x- and y-coordinates precisely.

Bokeh can also generate multi-line plots. As we saw in the scatter plot example, we need to provide the x- and y-coordinates explicitly and draw one line at a time. Again, the show() function opens a new browser window to display the plot, and you can interact with it.
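A minimal sketch of a Bokeh multi-line plot. The accuracy values are made-up placeholders, and the colors and line widths are arbitrary choices; each series gets its own line() call with explicit x- and y-coordinates.

```python
from bokeh.plotting import figure, show

# Placeholder values only, not real training results
epochs = list(range(10))
acc = [0.80 + 0.015 * i for i in epochs]
val_acc = [0.78 + 0.012 * i for i in epochs]

p = figure(title="Accuracy per epoch", x_axis_label="Epoch", y_axis_label="Accuracy")
# One line() call per series, each with explicit x and y
p.line(epochs, acc, legend_label="Training accuracy", line_color="blue", line_width=2)
p.line(epochs, val_acc, legend_label="Validation accuracy", line_color="red", line_width=2)
# show(p)  # opens the interactive plot in a browser
```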


Multi-line plot using Bokeh. Note the options for user interaction shown on the toolbar on the right.



The complete code for making the Bokeh plot is as follows:


More on visualization

Each of the tools we introduced above has many more functions for controlling the details of a visualization. It is worth searching their respective documentation for ways to polish your plots, and equally worthwhile to study the example code in their documentation to learn how to make your visualizations better.

Without providing too much detail, here are some ideas that you may want to add to your visualization:

  • add auxiliary lines, such as a line marking the boundary between the training and validation portions of time series data. The axvline() function from matplotlib can draw a vertical line on a plot for this purpose.
  • add annotations, such as arrows and text labels, to identify key points on the plot. See the annotate() method of matplotlib axes objects.
  • control the transparency level in case of overlapping graphic elements. All the plotting functions we introduced above accept an alpha argument, a value between 0 (fully transparent) and 1 (fully opaque).
  • show some of the axes in log scale, if the data is better illustrated this way. This is usually called a log plot or semi-log plot.
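The ideas in this list can be combined in one matplotlib sketch. The exponential series, the split position at x=80 and the annotation text are arbitrary values chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(100)
series = np.exp(t / 20.0)  # synthetic, exponentially growing series

fig, ax = plt.subplots()
ax.plot(t, series, alpha=0.7)                   # partially transparent line
ax.axvline(x=80, color="gray", linestyle="--")  # e.g. a train/validation split point
ax.annotate("split", xy=(80, series[80]), xytext=(50, series[90]),
            arrowprops=dict(arrowstyle="->"))   # arrow plus text label
ax.set_yscale("log")                            # semi-log plot: log-scaled y-axis
# plt.show()
```

On the log-scaled y-axis, the exponential series appears as a straight line.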

Before we conclude this post, the following is an example of a side-by-side visualization in matplotlib, in which one of the subplots is created using Seaborn:
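A minimal sketch of the idea: a figure with two axes, where the left one is drawn with plain matplotlib and the right one by Seaborn through its ax argument. The accuracy curves are placeholders and the scatter data is synthetic.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
xs = rng.normal(loc=labels * 3.0)
ys = rng.normal(loc=labels * -2.0)
epochs = np.arange(10)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Left subplot: plain matplotlib line plot (placeholder curves)
ax1.plot(epochs, 0.80 + 0.015 * epochs, label="Training accuracy")
ax1.plot(epochs, 0.78 + 0.012 * epochs, label="Validation accuracy")
ax1.legend()
ax1.set_title("matplotlib")
# Right subplot: Seaborn draws on the axes passed via ax=
sns.scatterplot(x=xs, y=ys, hue=labels.astype(str), ax=ax2)
ax2.set_title("Seaborn")
fig.tight_layout()
# plt.show()
```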


Side-by-side visualization created using matplotlib and Seaborn


The equivalent in Bokeh is to create each subplot separately and then specify the layout when we show it:
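A minimal sketch of this approach, using bokeh.layouts.row to place two independently built figures side by side. The plotted values are arbitrary. gridplot() from the same module is an alternative when a grid of plots is needed.

```python
from bokeh.layouts import row
from bokeh.plotting import figure, show

epochs = list(range(10))

# Each subplot is an independent figure
p1 = figure(title="Line plot")
p1.line(epochs, [i * i for i in epochs])
p2 = figure(title="Scatter plot")
p2.scatter(epochs, [10 - i for i in epochs], size=6)

layout = row(p1, p2)  # arrange the two figures horizontally
# show(layout)  # opens both plots in one browser page
```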


Side-by-side plot created in Bokeh




Summary

In this tutorial, you discovered various options for data visualization in Python.

Specifically, you learned:

  • How to create subplots in different rows and columns
  • How to render images using matplotlib
  • How to generate 2D and 3D scatter plots using matplotlib
  • How to create 2D scatter plots using Seaborn and Bokeh
  • How to create multi-line plots using matplotlib, Seaborn and Bokeh

