mjdmahasneh / simple-pytorch-semantic-segmentation-cnns

PyTorch implementation of semantic segmentation CNNs: this repository includes key architectures such as UNet, DeepLabv3+, SegNet, FCN, and PSPNet, and is intended as a starting point for semantic segmentation tasks in PyTorch.

Python 100.00%
cnn deep-learning deeplab-v3-plus deeplabv3plus fcn kaggle pspnet pytorch segmentation semantic-segmentation

simple-pytorch-semantic-segmentation-cnns's Issues

IoU score is not improving

I am training the model for boundary extraction from satellite imagery. I have a labelled dataset, and I changed the config file to set the number of classes to 2, treating background as one class and boundary as the second. But after 100 epochs I am not getting usable results: the mean IoU is 0.47, with 0.94 for class 1 and 0.0093 for class 2. Please suggest what to try; I am using DeepLab.
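The reported mean IoU of 0.47 matches the average of the two per-class scores, (0.94 + 0.0093) / 2 ≈ 0.47, so the mean is carried almost entirely by the background class. The minimal sketch below uses a hypothetical toy mask and a hypothetical `per_class_iou` helper (not a repository function) to illustrate how a prediction that labels every pixel as background can still reach a high background IoU while a thin boundary class scores zero:

```python
import torch


def per_class_iou(pred: torch.Tensor, target: torch.Tensor, n_classes: int, eps: float = 1e-6) -> list:
    """Per-class IoU = |pred AND true| / |pred OR true| for integer label maps of equal shape."""
    ious = []
    for cls in range(n_classes):
        p = pred == cls
        t = target == cls
        inter = (p & t).sum().item()
        union = (p | t).sum().item()
        ious.append(inter / (union + eps))
    return ious


# Hypothetical 10x10 ground truth: the boundary class (1) is a single row of pixels.
target = torch.zeros(10, 10, dtype=torch.long)
target[5, :] = 1
# A degenerate prediction that labels every pixel as background (0).
pred = torch.zeros_like(target)

ious = per_class_iou(pred, target, n_classes=2)
mean_iou = sum(ious) / len(ious)
print([round(v, 2) for v in ious], round(mean_iou, 2))  # [0.9, 0.0] 0.45
```

Looking at the per-class values, rather than only the mean, is what exposes this kind of class imbalance.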

Other evaluation metrics

Thank you for the amazing work. I would like to compute precision, recall, and F1 as well; how can I add them to the code?
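For reference, with per-class pixel counts of true positives $TP_c$, false positives $FP_c$, and false negatives $FN_c$ for class $c$, the standard definitions are as follows. The last identity shows that, for counts taken over the same set of pixels, F1 (the Dice score) is a monotonic function of IoU; averaging per-batch ratios will generally not preserve it exactly.

```latex
\mathrm{Precision}_c = \frac{TP_c}{TP_c + FP_c}, \qquad
\mathrm{Recall}_c = \frac{TP_c}{TP_c + FN_c}, \qquad
\mathrm{IoU}_c = \frac{TP_c}{TP_c + FP_c + FN_c}

F1_c = \frac{2\,\mathrm{Precision}_c\,\mathrm{Recall}_c}{\mathrm{Precision}_c + \mathrm{Recall}_c}
     = \frac{2\,TP_c}{2\,TP_c + FP_c + FN_c}
     = \frac{2\,\mathrm{IoU}_c}{1 + \mathrm{IoU}_c}
```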

Hi,

Hi,

I didn't have time to test this solution, but you can give it a go and modify it if needed:

import torch
from tqdm import tqdm


@torch.inference_mode()
def evaluate_iou(net, dataloader, device, amp):
    # Assumes `net.n_classes` exists and batches are dicts with 'image' and 'mask' keys
    net.eval()
    num_val_batches = len(dataloader)
    eps = 1e-6  # avoids division by zero for empty classes

    classwise_iou = [0.0] * net.n_classes
    classwise_precision = [0.0] * net.n_classes
    classwise_recall = [0.0] * net.n_classes
    classwise_f1 = [0.0] * net.n_classes

    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='IoU evaluation', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)
            # Logits (N, C, H, W) -> label map (N, H, W), computed once per batch
            mask_pred = net(image).argmax(dim=1)

            for cls in range(net.n_classes):
                mask_pred_cls = (mask_pred == cls).float()
                mask_true_cls = (mask_true == cls).float()

                tp = (mask_pred_cls * mask_true_cls).sum()  # true positives
                pred_pos = mask_pred_cls.sum()              # TP + FP
                true_pos = mask_true_cls.sum()              # TP + FN
                union = pred_pos + true_pos - tp            # TP + FP + FN

                # A class absent from both prediction and ground truth in a batch scores 0 here
                iou_cls = tp / (union + eps)
                precision_cls = tp / (pred_pos + eps)
                recall_cls = tp / (true_pos + eps)
                f1_cls = 2 * (precision_cls * recall_cls) / (precision_cls + recall_cls + eps)

                # .item() converts 0-dim tensors to Python floats, which are safe to log or print
                classwise_iou[cls] += iou_cls.item()
                classwise_precision[cls] += precision_cls.item()
                classwise_recall[cls] += recall_cls.item()
                classwise_f1[cls] += f1_cls.item()

    # Put the model back in training mode, since this is called from the training loop
    net.train()

    # Averaging the metrics over all batches
    num_batches = max(num_val_batches, 1)
    classwise_iou = [iou / num_batches for iou in classwise_iou]
    classwise_precision = [prec / num_batches for prec in classwise_precision]
    classwise_recall = [rec / num_batches for rec in classwise_recall]
    classwise_f1 = [f1 / num_batches for f1 in classwise_f1]

    # Optionally, you can calculate overall precision, recall, and F1 across all classes, but that depends on your evaluation strategy.
    return classwise_iou, classwise_precision, classwise_recall, classwise_f1

Originally posted by @MjdMahasneh in #1 (comment)
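The code above averages per-batch ratios. An alternative strategy, in line with its closing comment about overall metrics, is to accumulate true-positive, false-positive, and false-negative counts over the whole validation set and compute the metrics once at the end. A minimal sketch, assuming integer label maps with values in `[0, n_classes)`; `update_confusion` and `metrics_from_confusion` are hypothetical helpers, not repository functions, and the two batches are made-up data:

```python
import torch


def update_confusion(conf, pred, target, n_classes):
    """Add one batch of integer label maps to an (n_classes x n_classes) count matrix.

    Rows index the true class, columns the predicted class. `conf` is updated in place.
    """
    idx = target.reshape(-1) * n_classes + pred.reshape(-1)
    conf += torch.bincount(idx, minlength=n_classes * n_classes).reshape(n_classes, n_classes)
    return conf


def metrics_from_confusion(conf, eps=1e-6):
    """Per-class precision, recall, F1 and IoU computed from dataset-level counts."""
    conf = conf.double()
    tp = conf.diag()
    fp = conf.sum(dim=0) - tp  # predicted as class c, truly another class
    fn = conf.sum(dim=1) - tp  # truly class c, predicted as another class
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)
    return precision, recall, f1, iou


n_classes = 2
conf = torch.zeros(n_classes, n_classes, dtype=torch.long)
# Two hypothetical batches of (predicted, ground-truth) label maps.
batches = [
    (torch.tensor([[0, 1], [1, 0]]), torch.tensor([[0, 1], [0, 0]])),
    (torch.tensor([[0, 0], [0, 1]]), torch.tensor([[0, 0], [1, 1]])),
]
for pred, target in batches:
    update_confusion(conf, pred, target, n_classes)

precision, recall, f1, iou = metrics_from_confusion(conf)
print(conf.tolist())  # [[4, 1], [1, 2]]
print(iou.tolist())
print('macro F1:', f1.mean().item())
```

In an evaluation loop, each batch's predicted label map (the argmax over the class dimension of the network output) and its ground-truth mask would be passed to `update_confusion`, with `conf` created on the same device as the masks. Macro averages are then simply `precision.mean()`, `recall.mean()`, and so on.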
I tried to add them to train.py, but it failed. How can I solve this? Also, when I run predict.py, no output images are produced.
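The failure in train.py is not described, so the following is only a general pattern, not a fix for that specific error. It sketches one way a validation call can sit inside a training loop: the helper switches the model to eval mode, converts metric tensors to Python numbers with `.item()`, and restores the previous mode before training continues. `TinySegNet`, `evaluate_pixel_accuracy`, and the list-of-dicts `loader` are hypothetical stand-ins, and pixel accuracy replaces the full metric set to keep the example short:

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for a segmentation network that exposes `n_classes`."""

    def __init__(self, in_channels: int = 3, n_classes: int = 2):
        super().__init__()
        self.n_classes = n_classes
        self.head = nn.Conv2d(in_channels, n_classes, kernel_size=1)

    def forward(self, x):
        return self.head(x)  # logits of shape (N, n_classes, H, W)


@torch.inference_mode()
def evaluate_pixel_accuracy(net, dataloader, device):
    """Hypothetical validation helper; any metric function can follow the same pattern."""
    was_training = net.training
    net.eval()
    correct, total = 0, 0
    for batch in dataloader:
        image = batch['image'].to(device=device, dtype=torch.float32)
        mask_true = batch['mask'].to(device=device, dtype=torch.long)
        mask_pred = net(image).argmax(dim=1)
        correct += (mask_pred == mask_true).sum().item()  # plain Python int
        total += mask_true.numel()
    net.train(was_training)  # hand the model back in the mode the caller used
    return correct / max(total, 1)


device = torch.device('cpu')
torch.manual_seed(0)
net = TinySegNet().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
# Hypothetical data mimicking a DataLoader that yields dicts with 'image' and 'mask' keys.
loader = [{'image': torch.randn(2, 3, 8, 8), 'mask': torch.randint(0, 2, (2, 8, 8))} for _ in range(3)]

for epoch in range(2):
    net.train()
    for batch in loader:
        optimizer.zero_grad()
        loss = criterion(net(batch['image'].to(device)), batch['mask'].to(device))
        loss.backward()
        optimizer.step()
    # The training data is reused as "validation" data here only for brevity.
    val_acc = evaluate_pixel_accuracy(net, loader, device)
    print(f'epoch {epoch}: val pixel accuracy {val_acc:.3f}')
```

Restoring the mode matters because layers such as BatchNorm and Dropout behave differently in eval mode; a model left in eval mode after validation would keep training with the wrong behaviour.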
