
Comments (2)

Hiroid commented on June 14, 2024

Hi there, I hit the same bug when running one class per task, for example with iCaRL. It seems to be caused by the .squeeze() call on line 170, which changes dists.shape from (batch_size, n_classes) to (batch_size,) when n_classes is 1. Keeping dists as a 2-D tensor fixes it!

Try the following code instead! :)

# _, pred_label = dists.min(1) # "IndexError: Dimension out of range" when n_classes is 1
_, pred_label = dists.min(1) if len(dists.shape) == 2 else dists.unsqueeze(1).min(1)
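
The shape problem can be reproduced in isolation. This is only a minimal sketch, assuming PyTorch is installed; nearest_mean_index is a hypothetical helper written for illustration, not a function from the repository.

```python
import torch


def nearest_mean_index(dists: torch.Tensor) -> torch.Tensor:
    """Return the column index of the smallest distance for each example.

    Hypothetical helper: restores the class dimension if it was squeezed away.
    """
    if dists.dim() == 1:
        # (batch_size,) -> (batch_size, 1): a single class mean per example
        dists = dists.unsqueeze(1)
    _, pred_index = dists.min(1)
    return pred_index


batch_size = 4
dists_2d = torch.rand(batch_size, 1)  # one class seen so far
squeezed = dists_2d.squeeze()         # shape collapses to (batch_size,)

try:
    squeezed.min(1)                   # no dimension 1 left to reduce over
    raised = False
except IndexError:
    raised = True

pred_index = nearest_mean_index(squeezed)
```

With a single class, every prediction is index 0, which is the expected result. Note that this guard, like the one-liner above, assumes the squeezed dimension was the class dimension; a batch of size 1 with several classes would also become 1-D after .squeeze() and would need different handling.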


hsilva664 commented on June 14, 2024

No, this is not the same bug. I will give an example.

Suppose self.old_classes = [1,2,3,26,24,22], meaning these were the classes seen before the evaluate() call. When you use ncm_trick (or any nearest-mean-based classifier, such as iCaRL), you have a mean tensor containing the feature mean of each class in self.old_classes (in this case, mean has shape (6, n_features)). These means are used to compute the distance from each test example's features to each class mean, and the result is stored in the tensor dists, with shape (batch_size, 6).

Then what I tried to explain in the other message happens: _, pred_label = dists.min(1) returns a pred_label tensor of shape (batch_size,), but with each entry in range(6) instead of in [1,2,3,26,24,22]. pred_label has to be compared against batch_y, whose entries are in [1,2,3,26,24,22]. Therefore, on line 175, the indices are converted to class labels so that the comparison is done correctly.

What I'm pointing out is that this conversion does not change the variable pred_label, which is still in range(6). Furthermore, in the error_analysis part of the code (line 182 of the same file), which only runs if you pass error_analysis = True in the args, this same pred_label is compared against batch_y one more time, but now without conversion, even though the entries are in different "ranges". This comparison is used to get information such as "how many old classes were classified as new classes" (section 8.3 of their paper). As a result, the error_analysis results will be inconsistent.
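
The mismatch can be illustrated with a toy batch. This is a minimal sketch assuming PyTorch, not the repository's actual code: the distances and batch_y values are made up, and pred_index is a hypothetical name used here to keep the raw indices apart from the converted labels.

```python
import torch

# Class labels seen so far, taken from the example above.
old_classes = [1, 2, 3, 26, 24, 22]

# Made-up ground-truth labels for a batch of three test examples.
batch_y = torch.tensor([26, 1, 22])

# Made-up distances of shape (batch_size, 6); the smallest entry per row wins.
dists = torch.tensor([
    [5.0, 4.0, 3.0, 0.1, 2.0, 6.0],  # nearest mean is column 3 -> class 26
    [0.2, 4.0, 3.0, 1.0, 2.0, 6.0],  # nearest mean is column 0 -> class 1
    [5.0, 4.0, 3.0, 1.0, 2.0, 0.3],  # nearest mean is column 5 -> class 22
])

_, pred_index = dists.min(1)                         # entries in range(6)
pred_label = torch.tensor(old_classes)[pred_index]   # entries in old_classes

# Comparing raw indices with labels mixes two different "ranges".
raw_correct = (pred_index == batch_y).sum().item()
# Comparing converted labels with labels is the intended comparison.
mapped_correct = (pred_label == batch_y).sum().item()
```

Here all three predictions are right once converted, yet the unconverted comparison counts none of them as correct. Reusing the converted labels in any later comparison against batch_y would avoid that discrepancy.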

