I didn't have time to test this solution, but you can try it and modify it if needed:
@torch.inference_mode()
def evaluate_iou(net, dataloader, device, amp):
    """Compute per-class IoU, precision, recall and F1 for ``net`` over ``dataloader``.

    Scores are derived from true-positive / false-positive / false-negative
    pixel counts accumulated over the whole dataloader instead of being
    averaged per batch. Every pixel therefore carries the same weight, so a
    short last batch no longer counts as much as a full one. A batch in which
    a class is absent also no longer contributes a score of 0 for that class.

    Args:
        net: Segmentation model exposing ``n_classes`` and ``training``.
            With ``n_classes > 1``, each pixel's predicted class is the argmax
            over dim 1 of ``net(image)``. With ``n_classes == 1``, the model
            is treated as binary: a pixel is foreground when its single output
            value is > 0, i.e. sigmoid > 0.5.
            NOTE(review): this assumes the binary output is a raw logit;
            confirm against the model.
        dataloader: Iterable of dicts with ``'image'`` and ``'mask'`` tensors.
            ``'mask'`` holds an integer class label per pixel.
        device: ``torch.device`` the evaluation runs on.
        amp: Whether to enable ``torch.autocast`` mixed precision.

    Returns:
        ``(iou, precision, recall, f1)``: four lists of 0-d tensors, with one
        entry per scored class in class-index order. For a binary model the
        single entry describes the foreground (label 1). A class that never
        occurs in either prediction or ground truth scores 0.
    """
    was_training = net.training
    net.eval()
    n_classes = net.n_classes
    binary = n_classes == 1
    # A 1-channel argmax is always 0, so binary models score their foreground label instead.
    class_ids = [1] if binary else list(range(n_classes))
    num_val_batches = len(dataloader)
    eps = 1e-6

    # Per-class pixel counts, kept on the device so no host sync is needed per batch.
    tp = torch.zeros(len(class_ids), dtype=torch.long, device=device)
    fp = torch.zeros_like(tp)
    fn = torch.zeros_like(tp)

    try:
        with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
            for batch in tqdm(dataloader, total=num_val_batches, desc='IoU evaluation', unit='batch', leave=False):
                image, mask_true = batch['image'], batch['mask']
                image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                mask_true = mask_true.to(device=device, dtype=torch.long)
                mask_pred = net(image)

                # Hard label map, computed once per batch rather than once per class.
                if binary:
                    pred_labels = (mask_pred[:, 0] > 0).long()
                else:
                    pred_labels = mask_pred.argmax(dim=1)

                for i, cls in enumerate(class_ids):
                    pred_is_cls = pred_labels == cls
                    true_is_cls = mask_true == cls
                    tp[i] += (pred_is_cls & true_is_cls).sum()
                    fp[i] += (pred_is_cls & ~true_is_cls).sum()
                    fn[i] += (~pred_is_cls & true_is_cls).sum()
    finally:
        # Restore the caller's mode (e.g. training) even if evaluation fails.
        net.train(was_training)

    tp, fp, fn = tp.float(), fp.float(), fn.float()
    iou = tp / (tp + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    # 2TP / (2TP + FP + FN) is the harmonic mean of precision and recall.
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    # Overall (macro) scores, if wanted, are the mean of each returned list.
    return list(iou), list(precision), list(recall), list(f1)