Comments (9)

cruigo93 commented on September 18, 2024

@PanPan0210 have you done it?

from magic-vnet.

cruigo93 commented on September 18, 2024

Because I have just one class: when I use VNet(out_channels=1), the loss does not change, but when I add the background as a separate class in its own channel, it appears to work, though not very well.
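For reference, here is a minimal sketch of how the two setups relate (the tensor shapes are invented for illustration). With two channels (background and foreground), a softmax over the channel dimension gives a foreground probability equal to the sigmoid of the logit difference. A single-channel output followed by a sigmoid is therefore the binary counterpart of the two-channel softmax setup.

```python
import torch

# Hypothetical two-channel output: (batch, channels=[background, foreground], D, H, W)
logits = torch.randn(2, 2, 4, 4, 4)

# Foreground probability from a softmax over the channel dimension
p_softmax = torch.softmax(logits, dim=1)[:, 1]

# The same probability from a sigmoid of (foreground logit - background logit)
p_sigmoid = torch.sigmoid(logits[:, 1] - logits[:, 0])

print(torch.allclose(p_softmax, p_sigmoid, atol=1e-6))  # True
```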

Hsuxu commented on September 18, 2024

Delete the softmax function and retry:

```python
out = self.softmax(out)
```
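Why this matters with out_channels=1, as a minimal sketch that assumes the softmax is taken over the channel dimension (dim=1) of an `(N, C, D, H, W)` output (I have not checked magic-vnet's exact call): a softmax over a single channel returns 1.0 for every voxel regardless of the logits, so the prediction never changes and no gradient reaches the network. The tensor shape below is made up for illustration.

```python
import torch

# Hypothetical VNet-style output: batch of 2, ONE channel, 4x4x4 volume of raw logits
logits = torch.randn(2, 1, 4, 4, 4, requires_grad=True)

# Softmax across the channel dimension: each voxel is normalised over a single value
probs = torch.softmax(logits, dim=1)
print(torch.all(probs == 1.0).item())  # True: every voxel is exactly 1.0

# Any loss built on these probabilities gets zero gradient with respect to the logits
probs.sum().backward()
print(logits.grad.abs().max().item())  # 0.0
```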

cruigo93 commented on September 18, 2024

Thank you for the reply! I use VNet with out_channels=1 and the following dice loss:

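```python
import torch


def dice_loss(input, target):
    smooth = 1.
    # Flatten prediction and target so the dice is computed over all voxels
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    # Soft dice loss: 1 - (2 * intersection + smooth) / (sum(input) + sum(target) + smooth)
    # The formula is only bounded to [0, 1] when input holds probabilities in [0, 1]
    loss = 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))
    return loss
```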

But I still have issues: the loss keeps fluctuating. Can you help with that, maybe by providing some demo code? Thank you.

Hsuxu commented on September 18, 2024

You need to apply a sigmoid function in your loss function, or you can use my implementation of BinaryDiceLoss.
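A minimal sketch of a dice loss that applies the sigmoid internally, so it can be fed raw VNet logits. `sigmoid_dice_loss` is a hypothetical name for illustration, not the BinaryDiceLoss from magic-vnet; its smoothing term mirrors the `dice_loss` posted above. The toy tensors show why the sigmoid matters: on raw logits the dice term is not bounded by 1, so the "loss" can go negative and jump around.

```python
import torch


def sigmoid_dice_loss(logits, target, smooth=1.0):
    """Soft dice loss for one foreground class; `logits` are raw network outputs."""
    probs = torch.sigmoid(logits).reshape(-1)   # map logits to [0, 1] probabilities
    tflat = target.float().reshape(-1)          # binary {0, 1} target
    intersection = (probs * tflat).sum()
    dice = (2. * intersection + smooth) / (probs.sum() + tflat.sum() + smooth)
    return 1. - dice


def raw_dice_loss(input, target, smooth=1.0):
    """Same formula without the sigmoid, as in the dice_loss posted above."""
    iflat, tflat = input.reshape(-1), target.float().reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1. - (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)


logits = torch.tensor([3.0, -2.0])   # toy logits for two voxels
target = torch.tensor([1.0, 0.0])

print(raw_dice_loss(logits, target).item())      # about -1.333: outside [0, 1]
print(sigmoid_dice_loss(logits, target).item())  # about 0.054: a valid loss value
```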

cruigo93 commented on September 18, 2024

```python
import torch
import torch.optim as optim

# VNet, criterion, train_loader and val_loader are defined elsewhere in my script
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VNet(in_channels=1, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
nTrain = len(train_loader.dataset)
nVal = len(val_loader.dataset)
for epoch in range(50):
    # Training pass
    model.train()
    nProcessed = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data.float())
        loss = criterion(output, target.float())
        loss.backward()
        optimizer.step()
        nProcessed += len(data)
        print('Train Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
            epoch, nProcessed, nTrain, 100. * (batch_idx + 1) / len(train_loader),
            loss.item()))

    # Validation pass: no gradient tracking and no optimizer calls
    model.eval()
    nProcessed = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.to(device), target.to(device)
            output = model(data.float())
            loss = criterion(output, target.float())
            nProcessed += len(data)
            # Progress is measured against the validation loader
            print('Val Epoch {}:  [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
                epoch, nProcessed, nVal, 100. * (batch_idx + 1) / len(val_loader),
                loss.item()))
torch.save(model, './best_model.pth')
print('Model saved!')
```

cruigo93 commented on September 18, 2024

Can you take a look? Is it right?

cruigo93 commented on September 18, 2024

Because I only get empty masks.

Hsuxu commented on September 18, 2024

The target should be a binary mask, and you should apply a sigmoid to the output before computing the dice coefficient, e.g. `loss = criterion(torch.sigmoid(output), target.float())`.
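Putting both points together, a minimal sketch under two assumptions: the masks may be stored with a non-binary foreground value (255 is used here purely for illustration), and `criterion` is the `dice_loss` posted earlier in the thread, repeated so the snippet runs on its own. Binarizing the target and passing `torch.sigmoid(output)` keeps every term of the dice formula in [0, 1], so the loss stays in [0, 1].

```python
import torch


def dice_loss(input, target, smooth=1.0):
    # Soft dice loss from earlier in the thread; expects probabilities in [0, 1]
    iflat, tflat = input.reshape(-1), target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)


criterion = dice_loss

# Hypothetical batch: raw logits from a 1-channel VNet and a mask stored as {0, 255}
output = torch.randn(2, 1, 8, 8, 8)
mask = torch.zeros(2, 1, 8, 8, 8)
mask[:, :, 2:6, 2:6, 2:6] = 255

target = (mask > 0).float()                      # binary {0, 1} mask
loss = criterion(torch.sigmoid(output), target)  # sigmoid before the dice term
print(0.0 <= loss.item() <= 1.0)  # True
```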
