In this lab, you'll build on the previous lesson on confusion matrices and visualize a confusion matrix using matplotlib.
In this lab you will:
- Create a confusion matrix from scratch
- Create a confusion matrix using scikit-learn
- Craft functions that visualize confusion matrices
Recall that a confusion matrix represents the counts (or normalized counts) of your True Positives, False Positives, True Negatives, and False Negatives. Visualizing these counts can help when analyzing the effectiveness of a classification algorithm.
Here's an example of how a confusion matrix is displayed:
With that, let's look at some code for generating this kind of visual.
As usual, we start by fitting a model to the data: importing, normalizing, splitting into train and test sets, and then calling your chosen algorithm. All you need to do is run the following cell; the code should be familiar to you.
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import pandas as pd
# Load the data
df = pd.read_csv('heart.csv')

# Normalize the data (note: do this before defining X and y,
# so the model is actually trained on the normalized values)
for col in df.columns:
    df[col] = (df[col] - min(df[col])) / (max(df[col]) - min(df[col]))

# Define appropriate X and y
X = df[df.columns[:-1]]
y = df.target

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Fit a model
logreg = LogisticRegression(fit_intercept=False, C=1e12, solver='liblinear')
model_log = logreg.fit(X_train, y_train)

# Preview model params
print(model_log)

# Predict
y_hat_test = logreg.predict(X_test)

# Data preview
df.head()
To gain a better understanding of confusion matrices, complete the `conf_matrix()` function in the cell below. This function should:
- Take in two arguments:
  - `y_true`, an array of labels
  - `y_pred`, an array of model predictions
- Return a confusion matrix in the form of a dictionary, where the keys are `'TP'`, `'TN'`, `'FP'`, `'FN'`
def conf_matrix(y_true, y_pred):
pass
# Test the function
conf_matrix(y_test, y_hat_test)
# Expected output: {'TP': 39, 'TN': 24, 'FP': 9, 'FN': 4}
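If you get stuck, here is one possible implementation, a sketch that assumes binary labels coded as 0 and 1 (which matches the `heart.csv` target):

```python
def conf_matrix(y_true, y_pred):
    # Tally each (true, predicted) pair into one of the four cells
    cm = {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0}
    for true, pred in zip(y_true, y_pred):
        if true == 1 and pred == 1:
            cm['TP'] += 1
        elif true == 0 and pred == 0:
            cm['TN'] += 1
        elif true == 0 and pred == 1:
            cm['FP'] += 1
        else:  # true == 1 and pred == 0
            cm['FN'] += 1
    return cm

# Quick sanity check on toy data
print(conf_matrix([1, 0, 1, 0, 1], [1, 1, 0, 0, 1]))
# → {'TP': 2, 'TN': 1, 'FP': 1, 'FN': 1}
```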
To check your work, make use of the `confusion_matrix()` function found in `sklearn.metrics` and make sure that sklearn's results match up with your own from above.
- Import the `confusion_matrix()` function
- Use it to create a confusion matrix for `y_test` versus `y_hat_test`, as above
# Import confusion_matrix
# Print confusion matrix
cnf_matrix = None
print('Confusion Matrix:\n', cnf_matrix)
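Before comparing against your dictionary, note that for binary 0/1 labels, sklearn's `confusion_matrix()` returns a NumPy array laid out with true labels as rows and predicted labels as columns, i.e. `[[TN, FP], [FN, TP]]`. A quick toy example (these values are made up, not from `heart.csv`):

```python
from sklearn.metrics import confusion_matrix

y_true_demo = [0, 0, 1, 1, 1]
y_pred_demo = [0, 1, 1, 1, 0]

# Rows are true labels, columns are predicted labels: [[TN, FP], [FN, TP]]
cm_demo = confusion_matrix(y_true_demo, y_pred_demo)
print(cm_demo)
# → [[1 1]
#    [1 2]]
```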
Creating a pretty visual is a little more complicated. Generating the initial image is simple, but you'll have to use the `itertools` package to iterate over the matrix and add labels to the individual cells. In this example, `cnf_matrix` is the result of the scikit-learn confusion matrix from above.
import numpy as np
import itertools
import matplotlib.pyplot as plt
%matplotlib inline
# Create the basic matrix
plt.imshow(cnf_matrix, cmap=plt.cm.Blues)

# Add title and axis labels
plt.title('Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')

# Add appropriate axis scales
class_names = sorted(set(y))  # Get class labels to add to matrix (sorted for a stable order)
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

# Add labels to each cell
thresh = cnf_matrix.max() / 2.  # Used for text coloring below

# Here we iterate through the confusion matrix and append labels to our visualization
for i, j in itertools.product(range(cnf_matrix.shape[0]), range(cnf_matrix.shape[1])):
    plt.text(j, i, cnf_matrix[i, j],
             horizontalalignment='center',
             color='white' if cnf_matrix[i, j] > thresh else 'black')

# Add a legend
plt.colorbar()
plt.show()
Generalize the above code into a function that you can reuse to create confusion matrix visuals going forward. The function should take two arguments:
- `cm`: confusion matrix
- `classes`: the class labels
def plot_confusion_matrix(cm, classes,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    # Pseudocode/Outline:
    # Print the confusion matrix (optional)
    # Create the basic matrix
    # Add title and axis labels
    # Add appropriate axis scales
    # Add labels to each cell
    # Add a legend
    pass
When the `normalize` parameter is set to `True`, your function should display percentages for each class label in the visual rather than raw counts:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    # Check if normalize is set to True
    # If so, normalize the raw confusion matrix before visualizing
    print(cm)

    plt.imshow(cm, cmap=cmap)

    # Add title and axis labels
    plt.title(title)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    # Add appropriate axis scales
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Text formatting
    fmt = '.2f' if normalize else 'd'

    # Add labels to each cell
    thresh = cm.max() / 2.

    # Here we iterate through the confusion matrix and append labels to our visualization
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')

    # Add a legend
    plt.colorbar()
    plt.show()
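If you're unsure how to fill in the normalization step, a common choice is row-wise normalization: divide each row by its row total so the entries for each true class sum to 1. A sketch (the example counts here are made up):

```python
import numpy as np

def normalize_rows(cm):
    # Divide each row by its row sum so each true class's entries sum to 1
    cm = np.asarray(cm)
    return cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

raw = np.array([[24, 9],
                [4, 39]])
print(normalize_rows(raw))  # each row now sums to 1.0
```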
Call the function to visualize a normalized confusion matrix for `cnf_matrix`.
# Plot a normalized confusion matrix
Well done! In this lab, you created a confusion matrix from scratch and honed your `matplotlib` skills by visualizing confusion matrices!