While a confusion matrix is undoubtedly a crucial tool in machine learning, the purpose of this post is not to explore its interpretation. Instead, I want to share my personal journey with confusion matrix visualization.
Several years ago, during my initial machine learning project, I started by displaying a confusion matrix as raw numbers, keeping it simple yet informative. As I progressed, I explored different implementations, eventually using Seaborn’s heatmap() function to create more visually appealing representations of the confusion matrix.
Recently, I discovered the power of Sankey Diagrams and their interactive nature for visualizing the confusion matrix. The ability to hover over nodes and links to access numerical details added a new level of understanding to the evaluation process.
Throughout this journey, I continuously refined my visualization techniques and collected all the resulting code in my repository. This collection now forms a comprehensive set of tools for visualizing confusion matrices, reflecting how my skills and understanding of machine learning have evolved.
Confusion Matrix
A confusion matrix is a popular and essential tool in the field of machine learning and statistics used to evaluate the performance of a classification model. It provides a comprehensive and easy-to-understand summary of the model’s predictions and their accuracy on a given dataset by contrasting predicted labels against actual labels. The name “confusion” stems from the possibility of the classifier mistakenly identifying one class as another, leading to errors.
Axes Convention
A confusion matrix typically takes the form of a square table with rows and columns representing different classes in a classification problem. In the literature, two common variants for representing samples in a confusion matrix exist:
- The first variant arranges the matrix where each row corresponds to samples in the actual class, and each column corresponds to samples in the predicted class.
- The second variant reverses this arrangement, with each row representing samples in the predicted class and each column representing samples in the actual class.
Personally, I prefer the first variant, in which the actual labels run along the vertical axis (one row per actual class) and the predicted labels run along the horizontal axis (one column per predicted class). As an example, let’s consider a binary classification problem with two classes: 0 (Negative) and 1 (Positive). The confusion matrix is then organized with TN and FP in the first row and FN and TP in the second:
- TN – True Negative
- FP – False Positive
- FN – False Negative
- TP – True Positive
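To see this convention in code, here is a minimal NumPy sketch. The helper confusion_counts is hypothetical (it is not part of metrics_utilities.py) and the labels are made up; it simply counts (actual, predicted) pairs into rows and columns. scikit-learn’s confusion_matrix() uses the same rows-are-actual layout, so in the binary case flattening the matrix yields the four cells in the order TN, FP, FN, TP.

```python
import numpy as np


def confusion_counts(y_true, y_pred, n_classes=2):
    """Hypothetical helper: count (actual, predicted) label pairs.

    Rows index the actual class, columns index the predicted class.
    """
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


# Toy labels, made up purely for illustration
y_true = [0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 0, 1, 0, 1, 1]

cm = confusion_counts(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()  # row-major order: TN, FP, FN, TP
print(cm)
# [[2 1]
#  [1 3]]
```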
Helper Functions
To streamline the visualization of confusion matrices, I have created several helper functions in the metrics_utilities.py module:
- cm_cr() – Displays unnormalized and normalized confusion matrix DataFrames side by side, optionally followed by the classification report.
- plot_cm() – Plots a single unnormalized or normalized confusion matrix heatmap.
- plot_cm_unnorm_and_norm() – Plots one model’s confusion matrix heatmaps without and with normalization, side by side.
- plot_conf_matrices() – Plots heatmaps of normalized (default) or unnormalized confusion matrices for multiple models.
- plot_cm_sankey() – Creates an interactive confusion matrix using a Sankey diagram.
Throughout the rest of the post, I will be referring to these functions to demonstrate and analyze various confusion matrix visualizations. These helper functions aim to simplify the process of understanding and interpreting classification results effectively.
The Basics
When I started working on my initial machine learning project, I learned the basic and widely used method of presenting a confusion matrix as an array of raw numbers. Using the scikit-learn library and its confusion_matrix() function, I obtained the following output:
[[1979 410]
[ 198 413]]
This output lacks visual appeal, and this is only a binary classification example. In a multi-class problem, the confusion matrix grows with the number of classes, which makes it much harder to extract meaningful insights from the raw numbers alone.
Seaborn heatmap()
To enhance the interpretability of the above confusion matrix array, we can make it more visually appealing by plotting it as a color-encoded matrix. For this purpose, I rely on the Seaborn library in Python due to its user-friendly interface and efficiency.
Let’s proceed and create a heatmap of the matrix above using the Seaborn heatmap() function. For that, we can use the helper function plot_cm().
The resulting heatmap provides a clearer and more intuitive representation of the confusion matrix.
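The actual plot_cm() lives in the repository and is not reproduced here. As a rough idea of what such a helper involves, the hypothetical heatmap_cm below wraps seaborn.heatmap() with cell annotations and an optional row normalization; the colormap, axis labels, and figure size are arbitrary choices for this sketch.

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def heatmap_cm(cm, class_names, normalize=False, ax=None):
    """Hypothetical helper: draw one confusion matrix as an annotated heatmap.

    Rows are actual classes and columns are predicted classes.
    """
    cm = np.asarray(cm)
    if normalize:
        # Divide each row by its total so that every row sums to 1
        cm = cm / cm.sum(axis=1, keepdims=True)
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 3.5))
    sns.heatmap(
        cm,
        annot=True,                       # write the value inside each cell
        fmt=".2f" if normalize else "d",  # proportions vs. integer counts
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("Actual label")
    return ax


# Raw counts from the binary example above
ax = heatmap_cm([[1979, 410], [198, 413]], ["0 (Negative)", "1 (Positive)"])
```

Passing normalize=True switches the annotations from counts to row proportions.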
Normalized Confusion Matrix
Since real-life data often tends to be imbalanced, utilizing a confusion matrix without normalization could lead to misleading or improper conclusions. The class distribution imbalance can overshadow the actual model performance, making it crucial to normalize the confusion matrix to obtain more accurate insights.
To compute the normalized version of the confusion matrix, we begin with the confusion matrix array shown above and follow these steps:
- Sum each row. Because each row holds all samples of one actual (true) class, the row sum is the total number of actual samples for that class label.
- Divide each element in the row by that row sum.
By normalizing the matrix, we obtain, for each actual (true) class, the proportion of its samples that the model assigned to each predicted class, so every row sums to 1. This representation helps us better understand the model’s performance for each class, independently of how many samples each class has.
[[0.83 0.17]
[0.32 0.68]]
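The two steps above reduce to a single broadcast division in NumPy. This minimal sketch starts from the raw counts shown earlier and reproduces the normalized values; when the original label arrays are at hand, scikit-learn’s confusion_matrix() can return the same result directly with normalize="true".

```python
import numpy as np

# Raw counts from the binary example above (rows: actual, columns: predicted)
cm = np.array([[1979, 410],
               [198, 413]])

# keepdims=True keeps the sums as a column, so the division is applied row by row
row_totals = cm.sum(axis=1, keepdims=True)  # [[2389], [611]]
cm_norm = cm / row_totals

print(np.round(cm_norm, 2))
# [[0.83 0.17]
#  [0.32 0.68]]
```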
Seaborn heatmap()
To enhance the visual appeal, we can once again utilize the helper function plot_cm(), but this time with the option to plot the normalized matrix.
Displaying the unnormalized and normalized matrices one above the other might not be very practical for comparison. Instead, we should aim to present both matrices side by side to facilitate a more convenient analysis.
To achieve this, I implemented the helper function plot_cm_unnorm_and_norm() using the Seaborn heatmap() function. It lets us visualize the unnormalized and normalized confusion matrices side by side, making it easier to spot differences in per-class performance.
Indeed, the plot effectively showcases the utility of the normalized confusion matrix, making it much simpler to compare and contrast both matrices side by side. This visual representation helps highlight the impact of normalization on the evaluation of the model’s performance for each class, providing valuable insights at a glance.
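For readers who want the same layout without the repository, here is a hypothetical sketch (not the author’s plot_cm_unnorm_and_norm()): it creates a 1 × 2 grid with plt.subplots() and draws the raw counts on the left and the row-normalized proportions on the right. Titles, colormap, and figure size are my choices.

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def heatmaps_side_by_side(cm, class_names):
    """Hypothetical helper: raw counts on the left, row-normalized on the right."""
    cm = np.asarray(cm)
    cm_norm = cm / cm.sum(axis=1, keepdims=True)  # each row sums to 1

    fig, (ax_raw, ax_norm) = plt.subplots(1, 2, figsize=(9, 3.5))
    panels = [
        (ax_raw, cm, "d", "Unnormalized"),
        (ax_norm, cm_norm, ".2f", "Normalized"),
    ]
    for ax, data, fmt, title in panels:
        sns.heatmap(data, annot=True, fmt=fmt, cmap="Blues",
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set(title=title, xlabel="Predicted label", ylabel="Actual label")
    fig.tight_layout()
    return fig


fig = heatmaps_side_by_side([[1979, 410], [198, 413]], ["Negative", "Positive"])
```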
Pandas DataFrame
We can also display both the unnormalized and normalized confusion matrices side by side using a Pandas DataFrame. To achieve this, I use the cm_cr() function, which creates DataFrames for both the unnormalized and normalized confusion matrices. This approach offers a convenient and comprehensive view of the model’s performance, allowing easy comparison between the two versions of the confusion matrix.
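As a rough illustration of the idea behind cm_cr() (not its actual code), the following pandas sketch labels both matrices and joins them horizontally with pd.concat(). The keys argument adds a top-level column header, so the unnormalized and normalized blocks remain distinguishable; the class names are placeholders.

```python
import numpy as np
import pandas as pd

cm = np.array([[1979, 410],
               [198, 413]])
labels = ["Negative", "Positive"]  # placeholder class names

# Rows are actual classes, columns are predicted classes
df_raw = pd.DataFrame(cm, index=labels, columns=labels)
# Divide every row by its own total (axis=0 aligns the Series with the row index)
df_norm = df_raw.div(df_raw.sum(axis=1), axis=0).round(2)

# A two-level column header keeps the two tables side by side but distinct
side_by_side = pd.concat([df_raw, df_norm], axis=1,
                         keys=["Unnormalized", "Normalized"])
side_by_side.index.name = "Actual"
print(side_by_side)
```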
Combining with Classification Report
Using a Pandas DataFrame for the confusion matrix has another advantage: it combines seamlessly with the scikit-learn classification_report(), since both produce plain-text output. By calling the cm_cr() helper function, this time with the classification report option enabled, we can conveniently display the unnormalized and normalized matrices side by side, followed by the classification report just below them. This integrated view offers a comprehensive overview of the model’s performance, making it easier to assess its effectiveness for each class label along with the additional evaluation metrics from the classification report.
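The scikit-learn side of this is straightforward. The sketch below is not the author’s cm_cr(); it uses made-up labels and only standard scikit-learn calls: confusion_matrix() for the raw and row-normalized matrices (the normalize argument exists since scikit-learn 0.22) and classification_report(), which returns a formatted string.

```python
from sklearn.metrics import classification_report, confusion_matrix

# Toy labels, made up for illustration only
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 0, 1, 0, 1, 0, 1, 0, 1, 0]
target_names = ["Negative", "Positive"]

# Both calls use the same sorted label order, so matrix rows match report rows
cm = confusion_matrix(y_true, y_pred)
cm_norm = confusion_matrix(y_true, y_pred, normalize="true")  # rows sum to 1

print(cm)
print(cm_norm.round(2))
# Precision, recall, F1-score and support per class, as plain text
print(classification_report(y_true, y_pred, target_names=target_names))
```

Note that the per-class recall in the report equals the diagonal of the row-normalized matrix (5/6 ≈ 0.83 for Negative and 3/4 = 0.75 for Positive here), which is one reason the two views complement each other.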
Comparing Multiple Models
When working on a project involving multiple models, it is beneficial to compare their confusion matrices in a single plot. The first time I encountered this requirement, I created a helper function within the notebook, which functioned well.
However, when I reused the same function for the next project, it encountered some issues and did not work as expected. One of the problems was that the initial version could only handle an even number of matrices and could not correctly display just two matrices. To address these issues, I made the necessary corrections in the code, ensuring that the function can now properly handle any number of matrices, whether it’s an odd number or just two. After these adjustments, everything worked smoothly, and I could successfully visualize and compare the confusion matrices for multiple models without any difficulties.
In my subsequent project, I took the initiative to consolidate all the previously employed confusion matrix visualization functions into a single module. This decision served two main purposes: enhancing reusability and eliminating the need for code duplication.
Now, to plot the confusion matrices of multiple models, we can use the function plot_conf_matrices(). By default, it displays normalized matrices, but it can also plot unnormalized ones if desired. With the earlier issues fixed, the function now handles any number of matrices greater than one, so we can conveniently visualize and compare the performance of multiple models.
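plot_conf_matrices() is in the repository, and the post does not show the original bug. One common cause of exactly these symptoms (an assumption here) is that plt.subplots(1, 2) returns a one-dimensional array of axes, so code written for axes[row, col] indexing breaks with two matrices, while a grid sized only for even counts misbehaves when the count is odd. The hypothetical sketch below sidesteps both: it sizes the grid with a ceiling division, requests a 2-D axes array with squeeze=False, and hides any unused panels.

```python
import math

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def grid_shape(n_plots, max_cols=3):
    """Rows and columns needed for n_plots panels, at most max_cols per row."""
    if n_plots < 1:
        raise ValueError("need at least one matrix")
    n_cols = min(n_plots, max_cols)
    n_rows = math.ceil(n_plots / n_cols)
    return n_rows, n_cols


def plot_many_cms(cms, titles, class_names, max_cols=3):
    """Hypothetical helper: one row-normalized heatmap per model."""
    n_rows, n_cols = grid_shape(len(cms), max_cols)
    # squeeze=False always returns a 2-D array of axes, even for a 1 x 2 grid
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(4 * n_cols, 3.5 * n_rows), squeeze=False)
    flat_axes = axes.ravel()
    for ax, cm, title in zip(flat_axes, cms, titles):
        cm = np.asarray(cm)
        cm_norm = cm / cm.sum(axis=1, keepdims=True)
        # A shared 0..1 color scale makes shades comparable across models
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues", vmin=0, vmax=1,
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set(title=title, xlabel="Predicted label", ylabel="Actual label")
    # Hide leftover panels when the matrices do not fill the last row
    for ax in flat_axes[len(cms):]:
        ax.axis("off")
    fig.tight_layout()
    return fig


fig = plot_many_cms(
    [[[50, 10], [5, 35]], [[45, 15], [8, 32]], [[55, 5], [12, 28]]],  # made-up counts
    ["Model A", "Model B", "Model C"],
    ["Negative", "Positive"],
)
```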
Sankey Diagram
As mentioned earlier, the traditional way to display a confusion matrix is as raw numbers in an array, or as a heatmap. However, there exists an elegant and interactive alternative for visualizing a confusion matrix called the Sankey Diagram.
Just as I was about to wrap up my repository and begin writing this post, I came across an insightful article titled Enrich Your Confusion Matrix With A Sankey Diagram. It sparked my curiosity, and I dug deeper into Sankey Diagrams. Inspired by what I learned, I decided to create a helper function that visualizes a confusion matrix as a Sankey Diagram.
By using Python and Plotly, I was able to construct our Sankey confusion matrix, which presents the confusion matrix in a dynamic and informative manner, making it easier to comprehend and interpret the classification results.
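plot_cm_sankey() itself is in the repository. To illustrate how a confusion matrix maps onto a Sankey diagram, here is a hypothetical sketch: each actual class becomes a left node, each predicted class a right node, and every non-zero cell cm[i][j] becomes a link from actual class i to predicted class j, colored green on the diagonal (correct) and red off it (misclassified). The node labels, colors, and the split into two functions are my assumptions, and the counts reuse the earlier binary example purely as sample input.

```python
def sankey_spec(cm, class_names):
    """Hypothetical helper: turn a confusion matrix into Sankey nodes and links.

    Nodes 0..k-1 are actual classes, nodes k..2k-1 are predicted classes.
    """
    k = len(class_names)
    labels = ([f"ACTUAL {name}" for name in class_names]
              + [f"PREDICTED {name}" for name in class_names])
    source, target, value, color = [], [], [], []
    for i in range(k):
        for j in range(k):
            if cm[i][j] == 0:
                continue  # no samples flow between these two classes
            source.append(i)
            target.append(k + j)
            value.append(int(cm[i][j]))
            # Diagonal cells are correct predictions, off-diagonal are errors
            color.append("rgba(0, 160, 0, 0.5)" if i == j else "rgba(200, 0, 0, 0.5)")
    return labels, source, target, value, color


def sankey_figure(cm, class_names):
    """Build an interactive Plotly Sankey figure from the spec above."""
    # Imported here so sankey_spec() can be used without Plotly installed
    import plotly.graph_objects as go

    labels, source, target, value, color = sankey_spec(cm, class_names)
    return go.Figure(go.Sankey(
        node=dict(label=labels, pad=20, thickness=20),
        link=dict(source=source, target=target, value=value, color=color),
    ))


labels, source, target, value, color = sankey_spec([[1979, 410], [198, 413]],
                                                   ["Stays", "Exits"])
```

Keeping the node and link construction in plain Python makes it easy to check. In a notebook, sankey_figure(cm, ["Stays", "Exits"]).show() should render the interactive diagram, where hovering over a link shows its count.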
How to Interpret Sankey Confusion Matrix?
Interpreting the Sankey confusion matrix involves understanding the flow of information it represents. The Sankey Diagram visualizes the transitions between the actual (true) class labels and the predicted class labels, providing a clear representation of how the model’s predictions align with the ground truth.
Here is how to interpret our Sankey confusion matrix, with Stays representing the negative class and Exits representing the positive class:
- Nodes – The left-side nodes represent the actual class labels (“ACTUAL Stays” and “ACTUAL Exits”), while the right-side nodes correspond to the predicted class labels (“PREDICTED Stays” and “PREDICTED Exits”).
- Node size – The size of each node is indicative of the number of samples belonging to that specific class, providing a visual sense of class distribution.
- Links – The links between nodes show the flow of samples during the classification process. The width is proportional to the number of samples classified correctly (shown in green) or incorrectly (indicated in red) between the respective class labels. A wider green link indicates a higher number of correctly classified samples, while a wider red link indicates a larger number of misclassified samples.
- Tooltips – Hovering over the nodes and links within the diagram provides users with numerical and textual representations of the confusion matrix. This interactive feature facilitates a deeper understanding of the classification results by offering detailed information about each classification outcome.
By observing the Sankey confusion matrix, we can quickly identify how well the model performs for each class and detect any patterns of confusion between different classes. It helps us gain insights into the model’s strengths and weaknesses, aiding in the evaluation and fine-tuning of the classification model.
Conclusion
In this post, I am delighted to share a collection of my favorite confusion matrix visualization techniques. The choice of which technique to use depends entirely on your specific needs, the nature of your data, the project requirements, and the goals you aim to achieve. Each visualization approach offers its unique advantages and insights, empowering you to select the most suitable one to effectively evaluate and communicate the performance of your machine learning models.
I hope you enjoyed reading this post. For those interested in exploring the full code, you can find it in this repository.
Please note that Plotly visualizations may not render on GitHub, but they should work perfectly on nbviewer, offering an interactive experience in many cases. To test the interactive functionality of our Sankey confusion matrix, you can use the following links to access properly rendered notebooks:
- Confusion Matrix Visualization.ipynb – This notebook showcases various visualization techniques for confusion matrices using a collection of helper functions.
- Confusion Matrix as Sankey Diagram.ipynb – In this notebook, you will be guided step by step through the process of creating an interactive Sankey confusion matrix using Plotly.
Feel free to explore these notebooks to gain insights into different visualization approaches and harness the interactive capabilities of Plotly in understanding your classification results.