Comments (3)
Dear @Monk5088,
If your dataset changes, and with it the number of classes, the code should adapt to the new number of classes.
In the following example, the parameter n_classes defines the number of classes.
RetinaNet(encoder, n_classes=data.train_ds.c, n_anchors=3, sizes=[32], chs=8, final_bias=-4., n_conv=3)
Dear @ChristianMarzahl,
I tried changing the number of classes when initialising RetinaNet, but when I call learn.load(), it throws a weight size mismatch error.
CODE:
batch_size = 64
do_flip = True
flip_vert = True
max_rotate = 90
max_zoom = 1.1
max_lighting = 0.2
max_warp = 0.2
p_affine = 0.75
p_lighting = 0.75
tfms = get_transforms(do_flip=do_flip,
                      flip_vert=flip_vert,
                      max_rotate=max_rotate,
                      max_zoom=max_zoom,
                      max_lighting=max_lighting,
                      max_warp=max_warp,
                      p_affine=p_affine,
                      p_lighting=p_lighting)
train, valid = ObjectItemListSlide(train_images), ObjectItemListSlide(valid_images)
item_list = ItemLists(".", train, valid)
lls = item_list.label_from_func(lambda x: x.y, label_cls=SlideObjectCategoryList)
lls = lls.transform(tfms, tfm_y=True, size=patch_size)
data = lls.databunch(bs=batch_size, collate_fn=bb_pad_collate, num_workers=0).normalize()
Here, the training dataset looks like this:
SlideLabelList (100 items)
x: ObjectItemListSlide
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: SlideObjectCategoryList
ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256)
Path: .
train_images is a list of object_detection_fastai.helper.wsi_loader.SlideContainer objects, created with the following function:
def create_wsi_container(annotations_df: pd.DataFrame):
    container = []
    for image_name in tqdm(annotations_df["file_name"].unique()):
        image_annos = annotations_df[annotations_df["file_name"] == image_name]
        bboxes = [box for box in image_annos["box"]]
        labels = [label for label in image_annos["cat"]]
        container.append(SlideContainer(image_folder/image_name, y=[bboxes, labels], level=res_level,
                                        width=patch_size, height=patch_size, sample_func=sample_function))
    return container
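The loop above scans the whole DataFrame once per image to collect its annotations. A single pass that groups boxes and labels per file is a cheaper alternative. The sketch below is only an illustration: it assumes the annotations are available as dict records (for example via annotations_df.to_dict("records")) with the same file_name, box and cat columns, and it returns the [bboxes, labels] pairs that create_wsi_container passes as y=. The helper name group_annotations is hypothetical; pandas' annotations_df.groupby("file_name") would be an equivalent route.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple


def group_annotations(records: Iterable[Dict[str, Any]]) -> Dict[str, Tuple[List[Any], List[Any]]]:
    """Group bounding boxes and labels by image file in a single pass (hypothetical helper)."""
    grouped: Dict[str, Tuple[List[Any], List[Any]]] = defaultdict(lambda: ([], []))
    for rec in records:
        # Each file maps to a pair of parallel lists: boxes and their labels.
        bboxes, labels = grouped[rec["file_name"]]
        bboxes.append(rec["box"])
        labels.append(rec["cat"])
    # dict preserves insertion order, i.e. the order in which each file first
    # appears, which matches annotations_df["file_name"].unique().
    return dict(grouped)
```

With this helper, the loop body could become `for image_name, (bboxes, labels) in group_annotations(annotations_df.to_dict("records")).items():` followed by the same SlideContainer(...) call.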
CODE FOR LEARNER:
backbone = "ResNet34"  # ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
backbone_model = models.resnet18
if backbone == "ResNet34":
    backbone_model = models.resnet34
if backbone == "ResNet50":
    backbone_model = models.resnet50
if backbone == "ResNet101":
    backbone_model = models.resnet101
if backbone == "ResNet152":
    backbone_model = models.resnet152  # torchvision provides resnet152, not resnet150
pre_trained_on_imagenet = False
encoder = create_body(backbone_model, pre_trained_on_imagenet, -2)
loss_function = "FocalLoss"
if loss_function == "FocalLoss":
    crit = RetinaNetFocalLoss(anchors)
channels = 128
final_bias = -4
n_conv = 3
model = RetinaNet(encoder, n_classes=3,
                  n_anchors=len(scales) * len(ratios),
                  sizes=[size[0] for size in sizes],
                  chs=channels,  # number of feature channels in the FPN and prediction heads
                  final_bias=final_bias,
                  n_conv=n_conv  # number of conv layers in each head before its output layer
                  )
voc = PascalVOCMetric(anchors, patch_size, [str(i) for i in data.train_ds.y.classes[1:]])
learn = Learner(data, model, loss_func=crit,
                callback_fns=[BBMetrics, ShowGraph, CSVLogger, partial(GradientClipping, clip=2.0)],
                metrics=[voc])
learn.load("PATH/to/.pth", strict=False)
ERROR:
/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in load(self, file, device, strict, with_opt, purge, remove_module)
271 model_state = state['model']
272 if remove_module: model_state = remove_module_load(model_state)
--> 273 get_model(self.model).load_state_dict(model_state, strict=strict)
274 if ifnone(with_opt,True):
275 if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1496 if len(error_msgs) > 0:
1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498 self.__class__.__name__, "\n\t".join(error_msgs)))
1499 return _IncompatibleKeys(missing_keys, unexpected_keys)
1500
RuntimeError: Error(s) in loading state_dict for RetinaNet:
size mismatch for classifier.3.weight: copying a param with shape torch.Size([2, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 128, 3, 3]).
size mismatch for classifier.3.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).
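Note that strict=False does not help here: in PyTorch's load_state_dict, strict only controls whether missing or unexpected keys raise an error. A key that exists in both the checkpoint and the model but with different shapes is always reported as a size mismatch, which is exactly the error above. To see which entries are affected, or to load everything except them, a small filter such as the one below may help. It is a minimal sketch, not part of fastai or this library; the helper name split_by_shape is hypothetical, and it accepts any values exposing a .shape attribute (such as torch.Tensor).

```python
from typing import Any, Dict, List, Mapping, Tuple


def split_by_shape(checkpoint_state: Mapping[str, Any],
                   model_state: Mapping[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Split checkpoint entries into ones the model can accept and ones it cannot (hypothetical helper)."""
    compatible: Dict[str, Any] = {}
    skipped: List[str] = []
    for key, value in checkpoint_state.items():
        target = model_state.get(key)
        # Keep an entry only if the model has the same key with the same shape;
        # size mismatches and keys the model does not have are skipped.
        if target is not None and tuple(target.shape) == tuple(value.shape):
            compatible[key] = value
        else:
            skipped.append(key)
    return compatible, skipped
```

Assuming the checkpoint layout the traceback shows (fastai reads the weights from state['model']), usage might look like `state = torch.load(path, map_location="cpu")`, then `compatible, skipped = split_by_shape(state["model"], learn.model.state_dict())` and `learn.model.load_state_dict(compatible, strict=False)`, with path pointing to your checkpoint. Here skipped should contain classifier.3.weight and classifier.3.bias. Those parameters keep their fresh initialization, so the final classification layer is untrained and its scores are not meaningful until the model is fine-tuned on the 3-class data.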
I trained the model on a 2-class dataset, but now I only want to use it for prediction on a 3-class dataset. Both datasets share the first and last class in the databunch, i.e., the databunch I originally trained RetinaNet on has the following classes:
['background', 'mitosis']
while the new dataset I need predictions for contains these classes:
['background', 'hard negative', 'mitosis']
Is it possible for my model to predict only mitosis on the new dataset?
Thanks
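One possible approach, sketched under assumptions rather than taken from the library: keep the model at n_classes=2 so the checkpoint loads without the size mismatch, run prediction, and translate the predicted class indices by name into the new dataset's class list. Because both datasets share 'background' and 'mitosis', the names line up, and 'hard negative' is simply never predicted. The function remap_class_indices is hypothetical and assumes predictions are integer indices into the training class list ['background', 'mitosis']. Check how your prediction code decodes labels first, since some RetinaNet implementations drop the background column and shift indices by one.

```python
from typing import List, Sequence


def remap_class_indices(pred_indices: Sequence[int],
                        old_classes: Sequence[str],
                        new_classes: Sequence[str]) -> List[int]:
    """Translate indices into old_classes to indices into new_classes, matching by class name."""
    name_to_new = {name: i for i, name in enumerate(new_classes)}
    remapped = []
    for idx in pred_indices:
        name = old_classes[idx]
        if name not in name_to_new:
            raise KeyError(f"class {name!r} has no counterpart in the new class list")
        remapped.append(name_to_new[name])
    return remapped


old = ["background", "mitosis"]
new = ["background", "hard negative", "mitosis"]
print(remap_class_indices([1, 0, 1], old, new))  # [2, 0, 2]
```

To keep only mitosis detections, you could then discard every box whose remapped index differs from new.index("mitosis").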