
Comments (6)

Luoxd1996 commented on July 2, 2024

Hi shijianjian,
Thanks for your attention, but I am not sure what problem you ran into. First, ACE Loss extends AC Loss; Xu Chen et al. and Jun Ma et al. have shown that AC-based loss functions work well for single-class segmentation, so, following them, we only demonstrate its usefulness on binary segmentation tasks. For the multi-class case there are several parameters that need to be tuned carefully; it may also work well there, but we have not tested it. Second, if you want to combine it with Dice Loss, replace "sum" with "mean"; this tip is provided in the README. Finally, you can read the original paper to learn more details.
Best,
Xiangde.
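To make the "replace sum with mean" tip concrete, here is a minimal numpy sketch of combining a soft Dice loss with a mean-reduced AC-style region term. The function names, the toy inputs, and the 0.5 weighting are all hypothetical illustrations, not the repository's API; with "sum" reduction the contour term would dwarf the Dice term, while "mean" keeps the two on comparable scales.

```python
import numpy as np

def soft_dice_loss(pred, target, eps=1e-6):
    # standard soft Dice loss on probabilities in [0, 1]
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def region_term_mean(pred, target):
    # mean-reduced region term of the AC/ACE family:
    # penalizes foreground probability on background pixels and vice versa
    region_in = np.abs(np.mean(pred * (target - 1.0) ** 2))
    region_out = np.abs(np.mean((1.0 - pred) * target ** 2))
    return region_in + region_out

# toy 2x2 probability map and binary label (illustrative values)
pred = np.array([[0.9, 0.8], [0.2, 0.1]])
target = np.array([[1.0, 1.0], [0.0, 0.0]])

# hypothetical weighting of the contour term against Dice
total = soft_dice_loss(pred, target) + 0.5 * region_term_mean(pred, target)
```

Because both terms are per-pixel averages here, neither dominates the gradient as the image size grows, which is the practical point of the "mean" tip.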


shijianjian commented on July 2, 2024

@Luoxd1996 Thanks for the quick response. The main mistake I made was feeding logits rather than the sigmoid/softmax outputs into the loss function. I can use "mean" to integrate the loss now. Thanks for your advice.

I am doing an organ segmentation task and am currently using beta = 1, which performs worse than Dice loss alone (Dice only = 0.96, Dice + ACE = 0.93). I noticed in the paper that beta is suggested to be in (2, 10) for this kind of task. Can you share your thoughts on how to select the best beta value?

Thank you!

FYI, the code:

import torch
import torch.nn as nn


class ACELossVM(nn.Module):
    """
    Active Contour with Elastica (ACE) Loss,
    based on total variation and mean curvature.

    To use these terms only as constraints (combined with Dice or CE loss),
    set reduction='mean' so that torch.mean() replaces torch.sum().

    Suggested ranges for β:
        a) 0 < β < 2 for curvilinear or tubular structure segmentation
        b) 2 < β < 10 for non-tubular structures
    """
    def __init__(self, u=1, a=1e-3, b=1, from_logits=True, reduction='sum') -> None:
        super().__init__()
        self.u = u
        self.a = a
        self.b = b
        self.from_logits = from_logits
        self.reduction = reduction

    def first_derivative(self, input):
        u = input
        m = u.shape[2]
        n = u.shape[3]

        ci_0 = (u[:, :, 1, :] - u[:, :, 0, :]).unsqueeze(2)
        ci_1 = u[:, :, 2:, :] - u[:, :, 0:m - 2, :]
        ci_2 = (u[:, :, -1, :] - u[:, :, m - 2, :]).unsqueeze(2)
        ci = torch.cat([ci_0, ci_1, ci_2], 2) / 2

        cj_0 = (u[:, :, :, 1] - u[:, :, :, 0]).unsqueeze(3)
        cj_1 = u[:, :, :, 2:] - u[:, :, :, 0:n - 2]
        cj_2 = (u[:, :, :, -1] - u[:, :, :, n - 2]).unsqueeze(3)
        cj = torch.cat([cj_0, cj_1, cj_2], 3) / 2

        return ci, cj

    def second_derivative(self, input, ci, cj):
        u = input
        # m = u.shape[2]
        n = u.shape[3]

        cii_0 = (u[:, :, 1, :] + u[:, :, 0, :] - 2 * u[:, :, 0, :]).unsqueeze(2)
        cii_1 = u[:, :, 2:, :] + u[:, :, :-2, :] - 2 * u[:, :, 1:-1, :]
        cii_2 = (u[:, :, -1, :] + u[:, :, -2, :] - 2 * u[:, :, -1, :]).unsqueeze(2)
        cii = torch.cat([cii_0, cii_1, cii_2], 2)

        cjj_0 = (u[:, :, :, 1] + u[:, :, :, 0] - 2 * u[:, :, :, 0]).unsqueeze(3)
        cjj_1 = u[:, :, :, 2:] + u[:, :, :, :-2] - 2 * u[:, :, :, 1:-1]
        cjj_2 = (u[:, :, :, -1] + u[:, :, :, -2] - 2 * u[:, :, :, -1]).unsqueeze(3)

        cjj = torch.cat([cjj_0, cjj_1, cjj_2], 3)

        cij_0 = ci[:, :, :, 1:n]
        cij_1 = ci[:, :, :, -1].unsqueeze(3)

        cij_a = torch.cat([cij_0, cij_1], 3)
        cij_2 = ci[:, :, :, 0].unsqueeze(3)
        cij_3 = ci[:, :, :, 0:n - 1]
        cij_b = torch.cat([cij_2, cij_3], 3)
        cij = cij_a - cij_b

        return cii, cjj, cij

    def region(self, y_pred, y_true, u=1):
        label = y_true.float()
        c_in = torch.ones_like(y_pred)
        c_out = torch.zeros_like(y_pred)
        if self.reduction == 'mean':
            region_in = torch.abs(torch.mean(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.mean((1 - y_pred) * ((label - c_out) ** 2)))
        elif self.reduction == 'sum':
            region_in = torch.abs(torch.sum(y_pred * ((label - c_in) ** 2)))
            region_out = torch.abs(torch.sum((1 - y_pred) * ((label - c_out) ** 2)))
        else:
            raise ValueError(f"unsupported reduction: {self.reduction}")
        region = u * region_in + region_out
        return region

    def elastica(self, input, a=1, b=1):
        ci, cj = self.first_derivative(input)
        cii, cjj, cij = self.second_derivative(input, ci, cj)
        beta = 1e-8
        length = torch.sqrt(beta + ci ** 2 + cj ** 2)
        curvature = (beta + ci ** 2) * cjj + (beta + cj ** 2) * cii - 2 * ci * cj * cij
        curvature = torch.abs(curvature) / ((ci ** 2 + cj ** 2) ** 1.5 + beta)
        if self.reduction == 'mean':
            elastica = torch.mean((a + b * (curvature ** 2)) * torch.abs(length))
        elif self.reduction == 'sum':
            elastica = torch.sum((a + b * (curvature ** 2)) * torch.abs(length))
        else:
            raise ValueError(f"unsupported reduction: {self.reduction}")
        return elastica

    def forward(self, y_pred, y_true):
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        loss = self.region(y_pred, y_true, u=self.u) + self.elastica(y_pred, a=self.a, b=self.b)
        return loss
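The slicing in `first_derivative` is easier to read on a single 2-D map. The numpy sketch below mirrors what I understand the row-direction scheme to do (an assumption about the torch code, simplified to one channel): interior pixels get central differences `(u[i+1] - u[i-1]) / 2`, while the boundary rows get one-sided differences that are still divided by 2, i.e. half the true boundary slope.

```python
import numpy as np

def first_derivative_rows(u):
    # row-direction derivative of a 2-D map, mirroring the torch slicing:
    # one-sided difference at the first/last row, central difference inside,
    # with the whole stack divided by 2
    m = u.shape[0]
    top = (u[1, :] - u[0, :]) / 2
    mid = (u[2:, :] - u[:m - 2, :]) / 2
    bot = (u[-1, :] - u[m - 2, :]) / 2
    return np.vstack([top, mid, bot])

# linear ramp: u[i, j] = 4*i + j, so the true row derivative is 4 everywhere
u = np.arange(16, dtype=float).reshape(4, 4)
ci = first_derivative_rows(u)
# interior rows give 4, boundary rows give 2 (half the true slope)
```

On this ramp the interior rows recover the exact slope, while the boundary rows come out halved, which is worth knowing when the contour of interest touches the image border.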


Luoxd1996 commented on July 2, 2024

Hi shijianjian,
In my understanding, there is no good guidance for selecting the best values of u, a, b; having too many parameters is the main drawback of AC-based loss functions. You can search for them in a pre-defined space. In addition, we recommend reading Prof. Xuecheng Tai and Prof. Tony Chan's paper, "Image segmentation using Euler's elastica as the regularization". By the way, our work's main contribution is not how to select the best parameters, so we only evaluate robustness across a few different values.
Best,
Xiangde.
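Searching a pre-defined space for u, a, b can be sketched with a plain grid search. The grid values and the `score_of` function below are placeholders (in practice it would train a model and return the validation DSC), not anything from the repository:

```python
import itertools

# hypothetical pre-defined search space for the three ACE hyperparameters
grid_u = [0.5, 1, 2]
grid_a = [1e-4, 1e-3, 1e-2]
grid_b = [1, 2, 5]

def score_of(u, a, b):
    # placeholder: replace with "train with (u, a, b), return validation DSC";
    # this dummy just prefers u = 1 and b = 1 so the example is deterministic
    return -abs(u - 1) - abs(b - 1)

# pick the configuration with the best score
best = max(itertools.product(grid_u, grid_a, grid_b),
           key=lambda cfg: score_of(*cfg))
```

A full sweep here is 27 runs; with real training that cost is exactly why the hyperparameter count matters.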


Luoxd1996 commented on July 2, 2024

Hi shijianjian,
Also, thanks for your rewritten ACE code; we will update it and add more details about its usage later. Thank you again!
Sincerely,
Xiangde.


shijianjian commented on July 2, 2024

I did not do a grid search but tried a few parameter combinations. With ACE loss only, I can hardly make it work for my project. Interestingly, the DSC stops improving after several epochs (approx. 3 epochs):

  1. a=0.001, b=1 >>>> DSC 0.6271
  2. a=0.001, b=2 >>>> DSC 0.5432
  3. a=0.001, b=5 >>>> DSC 0.5763
  4. a=0.0001, b=1 >>>> DSC 0.6107
  5. a=0.0001, b=2 >>>> DSC 0.5763

Additionally, I tried adding some randomness during training, but no luck there either. As simple as the following:

        loss = self.region(y_pred, y_true, u=self.u * np.random.randint(5)) + self.elastica(
            y_pred, a=self.a * np.random.randint(100), b=self.b * np.random.randint(5))

To me, ACE loss is not easy to make work, and it seems very sensitive to the hyperparameter settings. I think this could be a critical obstacle to getting more people to use ACE loss. I hope it can be improved in your next work!


Luoxd1996 commented on July 2, 2024

Hi shijianjian,
@shijianjian, thanks for your suggestions. I agree that ACE Loss is sensitive to its hyperparameters; that comes from the properties of the ACE model, but we have reduced them to three. In addition, I do not know what your task is or what your target's structure looks like, so I cannot tell what the problem is. You can try other loss functions for your task; good luck.
Best,
Xiangde.

