Comments (6)
Hi, shijianjian,
Thanks for your attention, but I am not sure what problem you have met. First, ACE Loss extends AC Loss; Xu Chen et al. and Jun Ma et al. have shown that AC-based loss functions work well for single-class segmentation, so, like them, we only demonstrate its usefulness on the binary segmentation task. For the multi-class problem there are several parameters that need to be tuned carefully; it may also work well, but we have not tested it. Second, if you want to combine it with Dice Loss, please use "mean" in place of "sum"; I provide this tip in the README. Finally, you can read the original paper to learn more details.
Best,
Xiangde.
from aceloss.
@Luoxd1996 Thanks for the quick response. My main mistake was passing logits rather than the sigmoid/softmax outputs to the loss function. With "mean" I can now combine the losses. Thanks for your advice.
I am doing an organ segmentation task and am currently using beta = 1, which does not perform as well as Dice loss (Dice only = 0.96, Dice + ACE = 0.93). I noticed that the paper suggests beta in (2, 10) for this kind of task. Can you share your thoughts on how to select the best beta value?
Thank you!
FYI, the code:
import torch
import torch.nn as nn


class ACELossVM(nn.Module):
    """
    Active Contour with Elastica Loss,
    based on total variation and mean curvature.
    To use these terms just as constraints (combined with Dice loss or CE loss),
    use torch.mean() in place of torch.sum() (reduction='mean').
    For instance, for image segmentation tasks:
    a) beta in (0, 2) gives better results for curvilinear or tubular structures;
    b) beta in (2, 10) for non-tubular structures.
    """

    def __init__(self, u=1, a=1e-3, b=1, from_logits=True, reduction='sum') -> None:
        super().__init__()
        self.u = u
        self.a = a
        self.b = b
        self.from_logits = from_logits
        self.reduction = reduction

    def first_derivative(self, input):
        # Central differences along H (ci) and W (cj);
        # one-sided differences at the image boundaries.
        u = input
        m = u.shape[2]
        n = u.shape[3]
        ci_0 = (u[:, :, 1, :] - u[:, :, 0, :]).unsqueeze(2)
        ci_1 = u[:, :, 2:, :] - u[:, :, 0:m - 2, :]
        ci_2 = (u[:, :, -1, :] - u[:, :, m - 2, :]).unsqueeze(2)
        ci = torch.cat([ci_0, ci_1, ci_2], 2) / 2
        cj_0 = (u[:, :, :, 1] - u[:, :, :, 0]).unsqueeze(3)
        cj_1 = u[:, :, :, 2:] - u[:, :, :, 0:n - 2]
        cj_2 = (u[:, :, :, -1] - u[:, :, :, n - 2]).unsqueeze(3)
        cj = torch.cat([cj_0, cj_1, cj_2], 3) / 2
        return ci, cj

    def second_derivative(self, input, ci, cj):
        # Second derivatives along H (cii) and W (cjj), plus the mixed
        # derivative (cij) built from the first derivative ci.
        u = input
        n = u.shape[3]
        cii_0 = (u[:, :, 1, :] + u[:, :, 0, :] - 2 * u[:, :, 0, :]).unsqueeze(2)
        cii_1 = u[:, :, 2:, :] + u[:, :, :-2, :] - 2 * u[:, :, 1:-1, :]
        cii_2 = (u[:, :, -1, :] + u[:, :, -2, :] - 2 * u[:, :, -1, :]).unsqueeze(2)
        cii = torch.cat([cii_0, cii_1, cii_2], 2)
        cjj_0 = (u[:, :, :, 1] + u[:, :, :, 0] - 2 * u[:, :, :, 0]).unsqueeze(3)
        cjj_1 = u[:, :, :, 2:] + u[:, :, :, :-2] - 2 * u[:, :, :, 1:-1]
        cjj_2 = (u[:, :, :, -1] + u[:, :, :, -2] - 2 * u[:, :, :, -1]).unsqueeze(3)
        cjj = torch.cat([cjj_0, cjj_1, cjj_2], 3)
        cij_0 = ci[:, :, :, 1:n]
        cij_1 = ci[:, :, :, -1].unsqueeze(3)
        cij_a = torch.cat([cij_0, cij_1], 3)
        cij_2 = ci[:, :, :, 0].unsqueeze(3)
        cij_3 = ci[:, :, :, 0:n - 1]
        cij_b = torch.cat([cij_2, cij_3], 3)
        cij = cij_a - cij_b
        return cii, cjj, cij

    def region(self, y_pred, y_true, u=1):
        # Region term: penalizes foreground probability outside the label
        # and background probability inside it.
        label = y_true.float()
        c_in = torch.ones_like(y_pred)
        c_out = torch.zeros_like(y_pred)
        if self.reduction == 'mean':
            region_in = torch.abs(torch.mean(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.mean((1 - y_pred) * ((label - c_out) ** 2)))
        elif self.reduction == 'sum':
            region_in = torch.abs(torch.sum(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.sum((1 - y_pred) * ((label - c_out) ** 2)))
        else:
            raise ValueError(f"unsupported reduction: {self.reduction}")
        region = u * region_in + region_out
        return region

    def elastica(self, input, a=1, b=1):
        # Elastica term: integrates (a + b * curvature^2) along the contour
        # length; beta is a small constant for numerical stability.
        ci, cj = self.first_derivative(input)
        cii, cjj, cij = self.second_derivative(input, ci, cj)
        beta = 1e-8
        length = torch.sqrt(beta + ci ** 2 + cj ** 2)
        curvature = (beta + ci ** 2) * cjj + (beta + cj ** 2) * cii - 2 * ci * cj * cij
        curvature = torch.abs(curvature) / ((ci ** 2 + cj ** 2) ** 1.5 + beta)
        if self.reduction == 'mean':
            elastica = torch.mean((a + b * (curvature ** 2)) * torch.abs(length))
        elif self.reduction == 'sum':
            elastica = torch.sum((a + b * (curvature ** 2)) * torch.abs(length))
        else:
            raise ValueError(f"unsupported reduction: {self.reduction}")
        return elastica

    def forward(self, y_pred, y_true):
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        loss = self.region(y_pred, y_true, u=self.u) + self.elastica(y_pred, a=self.a, b=self.b)
        return loss
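A short note on why the "mean" reduction matters when combining with Dice loss: the sum-reduced region term grows with the number of pixels, while Dice is bounded in [0, 1], so a sum-reduced ACE term would dominate the combined loss. A minimal numpy sketch of the region term (standalone, not using the class above) illustrates the scale difference:

```python
import numpy as np

# Toy "probability map" and binary label on a 64x64 grid.
rng = np.random.default_rng(0)
y_pred = rng.random((64, 64))                        # values in [0, 1], as after sigmoid
y_true = (rng.random((64, 64)) > 0.5).astype(float)  # binary ground truth

# Region energies of the AC/ACE loss: inside vs. outside the label.
region_in = y_pred * (y_true - 1.0) ** 2
region_out = (1.0 - y_pred) * (y_true - 0.0) ** 2

region_sum = np.abs(region_in.sum()) + np.abs(region_out.sum())
region_mean = np.abs(region_in.mean()) + np.abs(region_out.mean())

# 'sum' scales with image area (here 64 * 64 = 4096x larger than 'mean'),
# while 'mean' stays O(1), i.e. comparable in magnitude to a Dice term.
print(region_sum / region_mean)
```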
Hi, shijianjian,
In my understanding, there is no good guidance for selecting the best values of "u, a, b"; this is the main drawback of AC-based loss functions (too many parameters). You can search for them in a pre-defined space. In addition, we recommend reading Prof. Xuecheng Tai and Prof. Tony Chan's paper, "Image segmentation using Euler's elastica as the regularization". By the way, our work's main contribution is not how to select the best parameters, so we only evaluate robustness over some different values.
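Searching in a pre-defined space can be sketched as a simple grid search. In this sketch, `train_and_validate` is a hypothetical placeholder for your own training loop (e.g. a short run of a few epochs) that returns a validation Dice score for given ACE weights; the grid values are just illustrative:

```python
import itertools

def train_and_validate(u, a, b):
    # Placeholder scoring function for illustration only; replace with a
    # real short training run that returns validation Dice for (u, a, b).
    return 1.0 / (1.0 + abs(u - 1) + abs(a - 1e-3) + abs(b - 1))

grid = {
    "u": [0.5, 1.0, 2.0],
    "a": [1e-4, 1e-3, 1e-2],
    "b": [1.0, 2.0, 5.0],
}

best_score, best_params = -1.0, None
for u, a, b in itertools.product(grid["u"], grid["a"], grid["b"]):
    score = train_and_validate(u, a, b)
    if score > best_score:
        best_score, best_params = score, (u, a, b)

print(best_params, best_score)
```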
Best,
Xiangde.
Hi, shijianjian,
Also, thanks for your rewritten ACE code; we will update it and add more details about the usage later. Thank you again!
Sincerely,
Xiangde.
I did not do a grid search, but I tried a few parameter combinations. With ACE loss alone, I can hardly make it work for my project. Interestingly, the DSC stops improving after several epochs (approx. 3 epochs):
- a=0.001, b=1 >>>> DSC 0.6271
- a=0.001, b=2 >>>> DSC 0.5432
- a=0.001, b=5 >>>> DSC 0.5763
- a=0.0001, b=1 >>>> DSC 0.6107
- a=0.0001, b=2 >>>> DSC 0.5763
Additionally, I tried adding some randomness while training, but had no luck there either. As simple as the following:
loss = self.region(y_pred, y_true, u=self.u * np.random.randint(5)) + self.elastica(
    y_pred, a=self.a * np.random.randint(100), b=self.b * np.random.randint(5))
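Multiplying the weights by fresh random integers on every forward pass changes the loss landscape between steps, which tends to destabilize optimization. A more common alternative is random search: sample one (u, a, b) triple per training run and keep it fixed. A hedged sketch, with illustrative sampling ranges:

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_ace_params():
    # One fixed triple per run; a spans orders of magnitude, so sample it
    # log-uniformly, while u and b are sampled uniformly.
    u = rng.uniform(0.5, 2.0)
    a = 10.0 ** rng.uniform(-4, -2)   # between 1e-4 and 1e-2
    b = rng.uniform(1.0, 5.0)
    return u, a, b

# Draw a few candidate configurations, one per training run.
params = [sample_ace_params() for _ in range(5)]
for u, a, b in params:
    print(f"u={u:.3f}, a={a:.2e}, b={b:.2f}")
```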
To me, ACE loss is not easy to make work, and it seems very sensitive to hyperparameter settings. I think this could be a critical obstacle to getting more people to use ACE loss. I hope it can be improved in your next work!
Hi, shijianjian,
@shijianjian, thanks for your suggestions. I agree that ACE Loss is sensitive to its hyperparameters; this comes from the properties of the ACE model, although we have reduced the parameters to three. In addition, I do not know your task or your target's structure, so I cannot tell what the problem is. You can try other loss functions for your task. Good luck.
Best,
Xiangde.