Comments (9)
@PanPan0210 have you done it?
from magic-vnet.
I ask because I have only one class: when I use VNet(out_channels=1) the loss does not change, but when I put the background in a separate channel as its own class, it seems to work, though not very well.
Delete the softmax function (Magic-VNet/magic_vnet/blocks.py, line 208 at commit 4604b04) and retry.
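You can check numerically why removing the softmax matters when out_channels=1. The sketch below assumes the softmax in blocks.py is taken over the channel dimension (dim=1) of an (N, C, D, H, W) output; the random tensors stand in for real VNet outputs. A softmax over a single channel is exp(x) / exp(x) = 1 at every voxel, so the prediction is constant and no gradient reaches the network. With two channels, the foreground probability depends on the difference of the two logits. Fixing the background logit at 0 makes it exactly a sigmoid, which is consistent with the two-class setup training at all.

```python
import torch

# Stand-in for a single-channel VNet output: (batch, channels, depth, height, width)
logits = torch.randn(2, 1, 4, 8, 8, requires_grad=True)

# Softmax over a 1-channel dimension is exp(x) / exp(x) = 1 at every voxel
probs_softmax = torch.softmax(logits, dim=1)
print(torch.allclose(probs_softmax, torch.ones_like(probs_softmax)))  # True

# The output is constant, so no gradient flows back to the logits
probs_softmax.sum().backward()
print(logits.grad.abs().max().item())  # 0.0

# Two channels with the background logit fixed at 0: softmax reduces to sigmoid
x = torch.randn(2, 1, 4, 8, 8)
two_channel = torch.cat([torch.zeros_like(x), x], dim=1)
fg_prob = torch.softmax(two_channel, dim=1)[:, 1:]
print(torch.allclose(fg_prob, torch.sigmoid(x), atol=1e-6))  # True
```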
Thank you for the reply! I use VNet with out_channels=1 and the following Dice loss:
```python
def dice_loss(input, target):
    smooth = 1.
    loss = 0.
    # print(torch.unique(input), torch.unique(target))
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    loss += (1 - ((2. * intersection + smooth) /
                  (iflat.sum() + tflat.sum() + smooth)))
    return loss
```
But I still have issues: the loss fluctuates. Can you help with that, maybe by providing some demo code? Thank you.
You need to apply the sigmoid function in your loss function, or you can use my implementation of BinaryDiceLoss.
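One way to follow this advice is to move the sigmoid inside the loss, so the raw network output can be passed straight in. The class below is a hypothetical sketch named SoftDiceLossWithLogits, not the BinaryDiceLoss shipped with Magic-VNet (its exact signature is not shown in this thread). It also computes Dice per sample and averages over the batch, and it uses reshape instead of view so non-contiguous tensors do not raise an error.

```python
import torch
import torch.nn as nn


class SoftDiceLossWithLogits(nn.Module):
    """Soft Dice loss for one foreground class, taking raw logits.

    Hypothetical helper for illustration; not magic_vnet's BinaryDiceLoss.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Map logits to foreground probabilities in (0, 1) before measuring overlap
        probs = torch.sigmoid(logits)
        # Flatten each sample separately: (N, C, D, H, W) -> (N, C*D*H*W)
        probs = probs.reshape(probs.shape[0], -1)
        target = target.reshape(target.shape[0], -1).float()
        intersection = (probs * target).sum(dim=1)
        denominator = probs.sum(dim=1) + target.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        # Average the per-sample Dice loss over the batch
        return (1.0 - dice).mean()


# Example with random logits and a random binary target
criterion = SoftDiceLossWithLogits()
logits = torch.randn(2, 1, 4, 8, 8, requires_grad=True)
target = (torch.rand(2, 1, 4, 8, 8) > 0.5).float()
loss = criterion(logits, target)
loss.backward()
print(loss.item())  # a value between 0 and 1
```

With this class, pass the raw output directly; applying a sigmoid before calling it would squash the probabilities twice.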
```python
import torch
import torch.optim as optim

# VNet comes from Magic-VNet; train_loader, val_loader and criterion
# (e.g. the dice_loss above) are defined elsewhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VNet(in_channels=1, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
nTrain = len(train_loader.dataset)
nVal = len(val_loader.dataset)

for epoch in range(50):
    model.train()
    nProcessed = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data.float())
        loss = criterion(output, target.float())
        loss.backward()
        optimizer.step()
        nProcessed += len(data)
        print('Train Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
            epoch, nProcessed, nTrain, 100. * nProcessed / nTrain,
            loss.item()))

    model.eval()
    nProcessed = 0
    with torch.no_grad():  # validation needs no gradients or optimizer calls
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data.float())
            loss = criterion(output, target.float())
            nProcessed += len(data)
            # Progress is measured against the validation set, not the training set
            print('Val Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
                epoch, nProcessed, nVal, 100. * nProcessed / nVal,
                loss.item()))

    # Note: this overwrites the checkpoint every epoch, regardless of the loss
    torch.save(model, './best_model.pth')
    print('Model saved!')
```
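The validation loop above prints one loss per batch, and single-batch Dice values on 3D volumes can swing a lot, which can look like a fluctuating loss. Here is a sketch of an epoch-level alternative, assuming the loader yields (data, target) pairs as in the loop and the criterion returns a batch-mean scalar. evaluate is a hypothetical helper that averages the loss over the whole validation set, weighted by batch size, under torch.no_grad(). The tiny Conv3d, the random list used as a loader and BCEWithLogitsLoss exist only to exercise it; in practice you would pass the VNet, val_loader and your own criterion.

```python
import torch
import torch.nn as nn


def evaluate(model, loader, criterion, device):
    """Return the validation loss averaged over every sample in loader."""
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():  # no autograd graph is needed for validation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data.float())
            loss = criterion(output, target.float())
            # Weight by batch size so a smaller last batch is not over-counted
            total_loss += loss.item() * data.shape[0]
            n_samples += data.shape[0]
    return total_loss / n_samples


# Stand-ins only, to exercise the helper: a tiny 3D conv and three random batches
torch.manual_seed(0)
model = nn.Conv3d(1, 1, kernel_size=3, padding=1)
loader = [
    (torch.randn(2, 1, 4, 8, 8), (torch.rand(2, 1, 4, 8, 8) > 0.5).float())
    for _ in range(3)
]
val_loss = evaluate(model, loader, nn.BCEWithLogitsLoss(), torch.device("cpu"))
print("Val loss (epoch mean): {:.6f}".format(val_loss))
```

Comparing the returned value against the best one seen so far, and only then calling torch.save (ideally on model.state_dict()), would make the best_model.pth file name match its contents.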
Can you take a look? Is it right?
I ask because I only get empty masks.
The target should be a binary mask, and you should apply the sigmoid function to the output before computing the Dice coefficient, e.g. `loss = criterion(torch.sigmoid(output), target.float())`.
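Here is a minimal sketch of those two requirements, plus a hard Dice score for checking thresholded predictions. All three helper names are hypothetical. to_binary_target assumes masks store background as 0 and foreground as any non-zero value (for example 0/255 PNG masks). predict_mask thresholds sigmoid probabilities at 0.5, which is equivalent to thresholding the logits at 0. Printing pred.sum() per batch is a quick way to confirm whether predicted masks are really all empty.

```python
import torch


def to_binary_target(mask: torch.Tensor) -> torch.Tensor:
    """Map a mask with 0 = background and non-zero = foreground to {0., 1.}."""
    return (mask > 0).float()


def predict_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Turn raw single-channel logits into a hard {0., 1.} mask."""
    return (torch.sigmoid(logits) > threshold).float()


def hard_dice(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-7) -> float:
    """Dice coefficient between two binary masks; 1.0 when both are empty."""
    pred, target = pred.reshape(-1), target.reshape(-1)
    intersection = (pred * target).sum()
    total = pred.sum() + target.sum()
    return ((2 * intersection + eps) / (total + eps)).item()


# Example: a 0/255 mask and hand-picked logits for a 2x2 slice
raw_target = torch.tensor([[0, 255], [255, 0]], dtype=torch.uint8)
target = to_binary_target(raw_target)
logits = torch.tensor([[-3.0, 2.0], [0.5, -1.0]])
pred = predict_mask(logits)
print(hard_dice(pred, target))  # 1.0
```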