Comments (3)

ChristianMarzahl commented on June 22, 2024

Dear @Monk5088,

If your dataset changes, and with it the number of classes, the code should adapt to the new number of classes.

In the following example, the parameter n_classes defines the number of classes.

RetinaNet(encoder, n_classes=data.train_ds.c, n_anchors=3, sizes=[32], chs=8, final_bias=-4., n_conv=3)
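
To illustrate why a saved checkpoint is tied to n_classes, here is a minimal sketch. It assumes that the last convolution of the classification head has n_classes * n_anchors output channels, which matches the shapes reported later in this thread but is not confirmed here from the library source. The helper name classifier_out_channels is hypothetical.

```python
def classifier_out_channels(n_classes: int, n_anchors: int) -> int:
    """Output channels of the final classifier conv (assumed: n_classes * n_anchors)."""
    if n_classes < 1 or n_anchors < 1:
        raise ValueError("n_classes and n_anchors must both be at least 1")
    return n_classes * n_anchors


# One anchor per location is an assumption made only for this example.
two_class_head = classifier_out_channels(n_classes=2, n_anchors=1)
three_class_head = classifier_out_channels(n_classes=3, n_anchors=1)

# A checkpoint can only be copied into a head of identical size.
heads_compatible = two_class_head == three_class_head
print(two_class_head, three_class_head, heads_compatible)
```

Under this assumption, changing n_classes changes the size of the final layer, so weights saved for one class count do not fit a model built for another.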

from objectdetection.

Monk5088 commented on June 22, 2024

Dear @ChristianMarzahl,
I have tried changing the number of classes when initialising RetinaNet, but when I call learn.load(), it throws a weight mismatch error.
CODE:

batch_size = 64

do_flip = True
flip_vert = True 
max_rotate = 90 
max_zoom = 1.1 
max_lighting = 0.2
max_warp = 0.2
p_affine = 0.75 
p_lighting = 0.75 

tfms = get_transforms(do_flip=do_flip,
                      flip_vert=flip_vert,
                      max_rotate=max_rotate,
                      max_zoom=max_zoom,
                      max_lighting=max_lighting,
                      max_warp=max_warp,
                      p_affine=p_affine,
                      p_lighting=p_lighting)
train, valid = ObjectItemListSlide(train_images), ObjectItemListSlide(valid_images)
item_list = ItemLists(".", train, valid)
lls = item_list.label_from_func(lambda x: x.y, label_cls=SlideObjectCategoryList)
lls = lls.transform(tfms, tfm_y=True, size=patch_size)
data = lls.databunch(bs=batch_size, collate_fn=bb_pad_collate, num_workers=0).normalize()

The training dataset looks as follows:
SlideLabelList (100 items)
x: ObjectItemListSlide
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: SlideObjectCategoryList
ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256)
Path: .
train_images is a list of object_detection_fastai.helper.wsi_loader.SlideContainer objects, created with the following function:

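import pandas as pd
from tqdm import tqdm
from object_detection_fastai.helper.wsi_loader import SlideContainer

def create_wsi_container(annotations_df: pd.DataFrame):
    """Build one SlideContainer per unique image listed in annotations_df."""
    # image_folder, res_level, patch_size and sample_function are defined elsewhere.
    container = []

    for image_name in tqdm(annotations_df["file_name"].unique()):
        # All annotation rows that belong to the current image
        image_annos = annotations_df[annotations_df["file_name"] == image_name]

        bboxes = list(image_annos["box"])
        labels = list(image_annos["cat"])

        container.append(SlideContainer(image_folder/image_name, y=[bboxes, labels], level=res_level,
                                        width=patch_size, height=patch_size, sample_func=sample_function))

    return container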

CODE FOR LEARNER:

backbone = "ResNet34"  # one of: "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"

# torchvision provides resnet152; there is no resnet150.
backbone_models = {
    "ResNet18": models.resnet18,
    "ResNet34": models.resnet34,
    "ResNet50": models.resnet50,
    "ResNet101": models.resnet101,
    "ResNet152": models.resnet152,
}
# Fall back to ResNet18 for unknown names, as the original if-chain did.
backbone_model = backbone_models.get(backbone, models.resnet18)

pre_trained_on_imagenet = False
encoder = create_body(backbone_model, pre_trained_on_imagenet, -2)


loss_function = "FocalLoss" 

if loss_function == "FocalLoss":
    crit = RetinaNetFocalLoss(anchors)


channels = 128 


final_bias = -4 


n_conv = 3 
model = RetinaNet(encoder, n_classes=3,
                  n_anchors=len(scales) * len(ratios),
                  sizes=[size[0] for size in sizes],
                  chs=channels,  # number of channels in the head's convolutional layers
                  final_bias=final_bias,
                  n_conv=n_conv  # number of hidden convolutional layers in each head
                  )
voc = PascalVOCMetric(anchors, patch_size, [str(i) for i in data.train_ds.y.classes[1:]])
learn = Learner(data, model, loss_func=crit, 
                callback_fns=[BBMetrics,ShowGraph,CSVLogger,partial(GradientClipping, clip=2.0)],metrics=[voc])
learn.load("PATH/to/.pth",strict=False) 

ERROR:

/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in load(self, file, device, strict, with_opt, purge, remove_module)
    271             model_state = state['model']
    272             if remove_module: model_state = remove_module_load(model_state)
--> 273             get_model(self.model).load_state_dict(model_state, strict=strict)
    274             if ifnone(with_opt,True):
    275                 if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1496         if len(error_msgs) > 0:
   1497             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1499         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1500 

RuntimeError: Error(s) in loading state_dict for RetinaNet:
    size mismatch for classifier.3.weight: copying a param with shape torch.Size([2, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 128, 3, 3]).
    size mismatch for classifier.3.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).
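
The traceback shows that strict=False does not help here. In PyTorch that flag only tolerates missing or unexpected keys; a shape mismatch on a key present in both the checkpoint and the model still raises. One possible workaround, sketched below under stated assumptions, is to drop the mismatching entries from the checkpoint before loading. The helper keep_matching_shapes is hypothetical and only needs each value to expose a .shape attribute. FakeTensor is a stand-in so the sketch runs without a real checkpoint, and the classifier.0.weight shape is illustrative rather than taken from the error message.

```python
from collections import namedtuple


def keep_matching_shapes(checkpoint_state, model_state):
    """Split checkpoint entries into loadable ones and skipped keys.

    An entry is loadable when the model has the same key with the same shape.
    """
    kept, skipped = {}, []
    for key, value in checkpoint_state.items():
        target = model_state.get(key)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped


# Stand-in for a tensor: only the shape matters for this check.
FakeTensor = namedtuple("FakeTensor", ["shape"])

checkpoint_state = {
    "classifier.0.weight": FakeTensor((128, 128, 3, 3)),  # illustrative shape
    "classifier.3.weight": FakeTensor((2, 128, 3, 3)),    # from the error above
    "classifier.3.bias": FakeTensor((2,)),                # from the error above
}
model_state = {
    "classifier.0.weight": FakeTensor((128, 128, 3, 3)),
    "classifier.3.weight": FakeTensor((3, 128, 3, 3)),
    "classifier.3.bias": FakeTensor((3,)),
}

kept, skipped = keep_matching_shapes(checkpoint_state, model_state)
print(sorted(kept))  # entries that can be loaded
print(skipped)       # entries left at their fresh initialisation
```

With a real checkpoint, the two inputs would be the state['model'] entry that the traceback shows fastai reading and learn.model.state_dict(), followed by learn.model.load_state_dict(kept, strict=False). The skipped final layer keeps its fresh initialisation, so the model would need fine-tuning before its class scores are meaningful.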


Monk5088 commented on June 22, 2024

I trained it on a 2-class dataset, but I want it to just predict on the 3-class dataset. Both datasets share the same first and last class in the databunch, i.e., the first databunch I trained my RetinaNet on has the following classes:
['background', 'mitosis']
The new dataset on which I need predictions contains the following classes:
['background', 'hard negative', 'mitosis']
So, is it possible for my model to predict only mitosis on the new dataset?
Thanks
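
Since both class lists share 'background' and 'mitosis', one option is to keep the model at n_classes=2 so the checkpoint loads unchanged, then translate its predicted class indices into the new dataset's label space by name. This is a minimal sketch, not a confirmed recipe for this repository. map_classes_by_name is a hypothetical helper, and the sketch assumes that predicted indices index directly into the class list, which depends on how the pipeline treats the background class.

```python
def map_classes_by_name(old_classes, new_classes):
    """Map each index in old_classes to the index of the same name in new_classes."""
    new_index = {name: i for i, name in enumerate(new_classes)}
    missing = [name for name in old_classes if name not in new_index]
    if missing:
        raise KeyError(f"classes missing from the new dataset: {missing}")
    return {i: new_index[name] for i, name in enumerate(old_classes)}


old_classes = ['background', 'mitosis']
new_classes = ['background', 'hard negative', 'mitosis']

index_map = map_classes_by_name(old_classes, new_classes)
print(index_map)  # {0: 0, 1: 2}

# Example class indices as the 2-class model might emit them (made-up values).
predicted = [1, 0, 1]
remapped = [index_map[c] for c in predicted]
print(remapped)  # [2, 0, 2]
```

The model never predicts 'hard negative' with this approach, so any such annotations in the new dataset would count only as ground truth the model cannot match.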
