I debugged the ActivationsAndGradients class, but I still can't solve the problem.
@ross-Hr Hello! This problem has been bothering me for a few days; I have tried many things but couldn't solve it.
I carefully stepped through the tutorial code at https://jacobgil.github.io/pytorch-gradcam-book/Class%20Activation%20Maps%20for%20Semantic%20Segmentation.html and compared it with my own to try to find the cause.
This is my error output. The activations list is empty (length 0), and I don't know what causes it:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
And below is the debug result from the official tutorial:
If you have any idea what might have caused this, please let me know.
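For reference, PyTorch raises this RuntimeError whenever backward() is called on a tensor that has no path back to any tensor requiring gradients. A minimal sketch that reproduces it with toy tensors (illustrative only, not taken from the model in this thread):

```python
import torch


def backward_error_message(t: torch.Tensor) -> str:
    """Call backward() on `t` and return the error message, or '' on success."""
    try:
        t.backward()
        return ""
    except RuntimeError as err:
        return str(err)


# A scalar built from a tensor that does not require grad has no grad_fn.
detached_score = torch.zeros(3).sum()
print(backward_error_message(detached_score))
# element 0 of tensors does not require grad and does not have a grad_fn

# The same reduction over a leaf that requires grad backpropagates fine.
leaf = torch.zeros(3, requires_grad=True)
print(repr(backward_error_message(leaf.sum())))  # ''
```

In Grad-CAM terms, the scalar returned by the target must stay connected to the layer being visualized; if it is built from a detached or non-differentiable tensor, this is the error that appears.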
I found that, when visualizing my own model with Grad-CAM, the hook registered with register_forward_hook is never executed during return self.model(x) in class ActivationsAndGradients, which means save_activation is never called.
class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targetted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use backward hook to record gradients.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # You can only register hooks on tensor requires grad.
            return

        # Gradients are computed in reverse order
        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            self.gradients = [grad.cpu().detach()] + self.gradients

        output.register_hook(_store_grad)

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()
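The class above relies on two separate mechanisms: a forward hook on the target layer, which only runs when that exact module object is called inside model(x), and a tensor hook on the layer's output, which can only be attached when that output requires grad. A minimal standalone sketch of the same pattern on a hypothetical toy model (not the library's code):

```python
import torch
import torch.nn as nn

activations, gradients = [], []


def save_activation(module, inputs, output):
    # Forward hook: runs every time `module` itself is called.
    activations.append(output.detach())
    if output.requires_grad:
        # Tensor hook: runs during backward with d(score)/d(output).
        output.register_hook(lambda grad: gradients.append(grad.detach()))


# Hypothetical toy classifier standing in for a real model.
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))
target_layer = model[0]
handle = target_layer.register_forward_hook(save_activation)

x = torch.randn(1, 3, 8, 8)
score = model(x)[0, 1]  # scalar "target" for class index 1
score.backward()
handle.remove()

print(len(activations), activations[0].shape)  # 1 torch.Size([1, 4, 8, 8])
print(len(gradients), gradients[0].shape)      # 1 torch.Size([1, 4, 8, 8])
```

Read as a diagnostic: if activations stays empty, the target module is not executed by the wrapped model's forward; if activations fill but gradients stay empty, the layer output did not require grad.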
Hi all, I will be looking into it.
Some context will help: what model are you using? Is it a custom model? Object detection, or something else?
Are you using a model wrapper?
Anything else you can share?
Thank you for your reply. Grad-CAM caught my eye in some papers, and I wanted to produce reliable feature visualizations for my own models. So I tried to use Grad-CAM for feature visualization on my own implementation of a semantic segmentation model based on MaskDINO, a semantic segmentation variant of the DETR model.
In addition, my model has two backbones, one for RGB and one for thermal images. It is implemented on top of mmdetection, a framework built on PyTorch. My versions are pytorch=1.13.1 and grad-cam=1.4.8. I also ran the grad-cam semantic segmentation tutorial script, and its output is correct.
Here is my semantic segmentation feature visualization script based on pytorch-grad-cam:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))

import argparse
import copy
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Optional, Sequence, Union

import mmcv
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import Compose, Normalize, ToTensor

from mmengine.config import Config
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmdet.registry import DATASETS
from mmdet.evaluation import get_classes
from mmdet.registry import MODELS
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM, XGradCAM, EigenCAM, EigenGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
# init_detector is redefined below, so only inference_detector is imported here.
from mmdet.apis import inference_detector
# Supported grad-cam type map
METHOD_MAP = {
    'gradcam': GradCAM,
    'gradcam++': GradCAMPlusPlus,
    'xgradcam': XGradCAM,
    'eigencam': EigenCAM,
    'eigengradcam': EigenGradCAM,
    'layercam': LayerCAM,
}

DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

image_file = './data/MF_dataset/images/val/'
rgb_file = './data/MF_dataset/seperated/'
image_name = "01007N"
IMAGE_FILE_PATH = os.path.join(image_file, image_name + ".png")
RGB_FILE_PATH = os.path.join(rgb_file, image_name + "_rgb.png")
THER_FILE_PATH = os.path.join(rgb_file, image_name + "_th.png")

MEAN = [0.535, 0.520, 0.581]
STD = [0.149, 0.111, 0.104]
# MEAN = [123.675, 116.28, 103.53]
# STD = [58.395, 57.12, 57.375]

CONFIG = 'projects/TMM/configs/tmm_r152-MF.py'
CHECKPOINT = 'heatmap/checkpoints/MF/20231117_061123/best_mIoU_epoch_13.pth'
PREVIEW_MODEL = True
TARGET_LAYERS = ["model.model.backbone_r.layer3"]
METHOD = 'gradcam'

# for MFN dataset
SEM_CLASSES = [
    'unlabeled', 'car', 'person', 'bike', 'curve', 'car_stop', 'guardrail',
    'color_cone', 'bump'
]
TARGET_CATEGORY = 'person'
VIS_CAM_RESULTS = True
CAM_SAVE_PATH = "/heatmap/output"
LIKE_VIT = False
PRINT_MODEL_PRED_SEG = False


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('--img', default=IMAGE_FILE_PATH, help='Image file')
    parser.add_argument('--config', default=CONFIG, help='Config file')
    parser.add_argument('--checkpoint',
                        default=CHECKPOINT,
                        help='Checkpoint file')
    parser.add_argument(
        '--target_layers',
        default=TARGET_LAYERS,
        nargs='+',
        type=str,
        help='The target layers to get CAM, if not set, the tool will '
        'specify the norm layer in the last block. Backbones '
        'implemented by users are recommended to manually specify'
        ' target layers on the command line.')
    parser.add_argument('--preview_model',
                        default=PREVIEW_MODEL,
                        help='To preview all the model layers')
    parser.add_argument('--method',
                        default=METHOD,
                        help='Type of method to use, supports '
                        f'{", ".join(list(METHOD_MAP.keys()))}.')
    parser.add_argument('--sem_classes',
                        default=SEM_CLASSES,
                        nargs='+',
                        type=str,
                        help='All class names the model predicts.')
    parser.add_argument(
        '--target_category',
        default=TARGET_CATEGORY,
        type=str,
        help='The target category to get CAM, default to use result '
        'get from given model.')
    parser.add_argument('--aug_mean',
                        default=MEAN,
                        nargs='+',
                        type=float,
                        help='augmentation mean')
    parser.add_argument('--aug_std',
                        default=STD,
                        nargs='+',
                        type=float,
                        help='augmentation std')
    parser.add_argument(
        '--cam_save_path',
        default=CAM_SAVE_PATH,
        type=str,
        help='The path to save visualize cam image, default not to save.')
    parser.add_argument('--vis_cam_results', default=VIS_CAM_RESULTS)
    parser.add_argument('--device', default=DEVICE, help='Device to use cpu')
    parser.add_argument('--like_vision_transformer',
                        default=LIKE_VIT,
                        help='Whether the target model is a ViT-like network.')
    parser.add_argument('--print_model_pred_seg',
                        default=PRINT_MODEL_PRED_SEG,
                        help='Whether to print the raw segmentation prediction.')
    args = parser.parse_args()
    if args.method.lower() not in METHOD_MAP.keys():
        raise ValueError(f'invalid CAM type {args.method},'
                         f' supports {", ".join(list(METHOD_MAP.keys()))}.')
    return args
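

def norm_img(img, mean, std):
    # `img` is an H x W x 4 array: 3 RGB channels followed by 1 thermal channel.
    image = img.copy()
    image = np.array(image)
    image = image.transpose(2, 0, 1)  # -> 4 x H x W
    image = torch.tensor(image)
    data_rgb = image[:3, :, :]
    tmp = image[3:4, :, :]
    # Repeat the single thermal channel 3 times so it can share the 3-channel MEAN/STD.
    data_t = torch.cat([tmp] * 3, dim=0)
    preprocessing = Compose([
        # ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    rgb = preprocessing(data_rgb.float()).unsqueeze(0)
    t = preprocessing(data_t.float()).unsqueeze(0)
    # Final input: 1 x 6 x H x W (normalized RGB, then normalized thermal).
    input_tensor = torch.cat((rgb, t), dim=1)
    return input_tensor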


def make_input_tensor(image_file_path, mean, std, device):
    if not os.path.exists(image_file_path):
        raise FileNotFoundError(f"{image_file_path} does not exist!")
    image = np.asarray(Image.open(image_file_path))
    img = np.float32(image) / 255
    img = torch.tensor(img)
    # NOTE: the raw 0-255 `image` (not the 0-1 `img`) is normalized here,
    # while MEAN/STD above are on a 0-1 scale.
    input_tensor = norm_img(image, mean, std)
    # input_tensor = preprocess_image(rgb_img, mean=mean, std=std)
    if device == torch.device('cuda:0'):
        input_tensor = input_tensor.to(device)
        print(f"input_tensor has been moved to {device}")
    return input_tensor, img


def make_model(config_path, checkpoint_path, device):
    model = init_detector(config_path, checkpoint_path, device=device)
    print('Network setup complete: The trained weights were successfully loaded')
    return model


def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'none',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Initialize a detector from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        palette (str): Color palette used for visualization. If palette
            is stored in checkpoint, use checkpoint's palette first, otherwise
            use externally passed palette. Currently, supports 'coco', 'voc',
            'citys' and 'random'. Defaults to none.
        device (str): The device where the anchors will be put on.
            Defaults to cuda:0.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    else:
        # Skip pretrained initialization for both backbones;
        # the checkpoint loaded below provides the weights.
        if 'init_cfg' in config.model.backbone_r:
            config.model.backbone_r.init_cfg = None
        if 'init_cfg' in config.model.backbone_i:
            config.model.backbone_i.init_cfg = None

    scope = config.get('default_scope', 'mmdet')
    if scope is not None:
        init_default_scope(scope)

    model = MODELS.build(config.model)
    model = revert_sync_batchnorm(model)
    if checkpoint is None:
        warnings.simplefilter('once')
        warnings.warn('checkpoint is None, use COCO classes by default.')
        model.dataset_meta = {'classes': get_classes('coco')}
    else:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        checkpoint_meta = checkpoint.get('meta', {})

        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint_meta:
            # mmdet 3.x, all keys should be lowercase
            model.dataset_meta = {
                k.lower(): v
                for k, v in checkpoint_meta['dataset_meta'].items()
            }
        elif 'CLASSES' in checkpoint_meta:
            # < mmdet 3.x
            classes = checkpoint_meta['CLASSES']
            model.dataset_meta = {'classes': classes}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

    # Priority: args.palette -> config -> checkpoint
    if palette != 'none':
        model.dataset_meta['palette'] = palette
    else:
        test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
        # lazy init. We only need the metainfo.
        test_dataset_cfg['lazy_init'] = True
        metainfo = DATASETS.build(test_dataset_cfg).metainfo
        cfg_palette = metainfo.get('palette', None)
        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        else:
            if 'palette' not in model.dataset_meta:
                warnings.warn(
                    'palette does not exist, random is used by default. '
                    'You can also set the palette to customize.')
                model.dataset_meta['palette'] = 'random'

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


from torch.nn import functional as F


class SegmentationModelOutputWrapper(torch.nn.Module):
    def __init__(self, model):
        super(SegmentationModelOutputWrapper, self).__init__()
        self.model = model

    def reshape_output(self, sem_seg, num_classes):
        # sem_seg: 1 x H x W label map -> 1 x num_classes x H x W one-hot mask.
        # NOTE: `sem_seg == i` is a comparison, so this mask has no grad_fn;
        # gradients cannot flow back to the network through it.
        multi_channel_mask = torch.zeros(
            (1, num_classes, sem_seg.shape[1], sem_seg.shape[2]),
            dtype=sem_seg.dtype,
            device=sem_seg.device)
        for i in range(num_classes):
            multi_channel_mask[0, i, :, :] = (sem_seg == i)
        return multi_channel_mask

    def forward(self, x):
        num_classes = self.model.num_classes
        result = self.model(x.squeeze(0))
        pred_sem_seg = self.reshape_output(result.sem_seg, num_classes)
        # out = pred_sem_seg
        out = F.interpolate(pred_sem_seg.float(),
                            size=x.shape[-2:],
                            mode='bilinear',
                            align_corners=False)
        return out


class SemanticSegmentationTarget:
    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        # model_output: num_classes x H x W for a single image.
        # Move the mask to the output's device instead of forcing CUDA,
        # so the target also works on CPU.
        mask = self.mask.to(model_output.device)
        return (model_output[self.category, :, :] * mask).sum()


def reshape_transform_fc(in_tensor):
    # B x N x C token sequence -> B x C x sqrt(N) x sqrt(N) feature map.
    result = in_tensor.reshape(in_tensor.size(0),
                               int(np.sqrt(in_tensor.size(1))),
                               int(np.sqrt(in_tensor.size(1))),
                               in_tensor.size(2))
    result = result.transpose(2, 3).transpose(1, 2)
    return result


def main():
    args = parse_args()
    input_tensor, img = make_input_tensor(args.img,
                                          args.aug_mean,
                                          args.aug_std,
                                          device=args.device)
    rgb_img = mmcv.imread(RGB_FILE_PATH, channel_order='rgb')  # rgb
    th_img = mmcv.imread(THER_FILE_PATH)  # ther
    cfg = args.config
    checkpoint = args.checkpoint
    model = make_model(cfg, checkpoint, device=args.device)
    results = inference_detector(model, args.img)
    if args.print_model_pred_seg:
        pprint(results)
    if args.preview_model:
        pprint([name for name, _ in model.named_modules()])

    model = SegmentationModelOutputWrapper(model)
    output = model(input_tensor)

    sem_classes = args.sem_classes
    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
    if len(sem_classes) == 1:
        output = torch.sigmoid(output).cpu()
        pred_mask = torch.where(output > 0.3, torch.ones_like(output),
                                torch.zeros_like(output))
        pred_mask = pred_mask.detach().cpu().numpy()
    else:
        output = torch.nn.functional.softmax(output, dim=1).cpu()
        pred_mask = output[0, :, :, :].argmax(dim=0).detach().cpu().numpy()
    category = sem_class_to_idx[args.target_category]
    mask_float = np.float32(pred_mask == category)

    # # visual
    # car_mask_uint8 = 255 * np.uint8(pred_mask == category)
    #
    # both_images = np.hstack((rgb_img, np.repeat(car_mask_uint8[:, :, None], 3, axis=-1)))
    # image = Image.fromarray(both_images)
    # image.save('output.png')

    reshape_transform = reshape_transform_fc if args.like_vision_transformer else None
    target_layers = [model.model.backbone_r.layer4]
    targets = [SemanticSegmentationTarget(category, mask_float)]
    GradCAM_Class = METHOD_MAP[args.method.lower()]
    with GradCAM_Class(model=model,
                       target_layers=target_layers,
                       use_cuda=torch.cuda.is_available(),
                       reshape_transform=reshape_transform) as cam:
        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=targets,
                            aug_smooth=True,
                            eigen_smooth=True)[0, :]
        # show_cam_on_image expects a float32 RGB image in the range [0, 1].
        cam_image = show_cam_on_image(np.float32(rgb_img) / 255, grayscale_cam, use_rgb=True)

    vir_image = Image.fromarray(cam_image)
    if args.vis_cam_results:
        vir_image.show()
    cam_save_path = f"{args.cam_save_path}/{os.path.basename(args.config).split('.')[0]}"
    if not os.path.exists(cam_save_path):
        os.makedirs(cam_save_path)
    vir_image.save(
        os.path.join(cam_save_path,
                     f"{os.path.basename(args.img).split('.')[0]}.png"))


if __name__ == '__main__':
    main()
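One thing worth checking in SegmentationModelOutputWrapper: reshape_output builds the per-class mask from sem_seg == i. A comparison yields a boolean tensor with no grad_fn, so if result.sem_seg is an argmax label map (an assumption about this custom model), the wrapper's output is disconnected from the network and any target computed from it cannot be backpropagated. A toy sketch contrasting that with returning per-class scores; ToySegNet and both helper functions are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySegNet(nn.Module):
    """Hypothetical stand-in: returns per-class logits of shape (N, C, H, W)."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.num_classes = num_classes
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        return self.head(x)


def one_hot_from_labels(logits):
    # Mirrors reshape_output: argmax + '==' comparisons, neither differentiable.
    labels = logits.argmax(dim=1)  # (N, H, W) integer label map
    masks = [(labels == i) for i in range(logits.shape[1])]
    return torch.stack(masks, dim=1).float()


def scores_from_logits(logits):
    # Keeps the graph: softmax is differentiable w.r.t. the logits.
    return F.softmax(logits, dim=1)


net = ToySegNet()
x = torch.randn(1, 3, 16, 16)
logits = net(x)

label_out = one_hot_from_labels(logits)
score_out = scores_from_logits(logits)
print(label_out.requires_grad, label_out.grad_fn)  # False None
print(score_out.requires_grad)                     # True

# SemanticSegmentationTarget-style score: masked sum over one class channel.
mask = torch.ones(16, 16)
(score_out[0, 1] * mask).sum().backward()
print(net.head.weight.grad is not None)            # True
```

If the custom model can expose its pre-argmax class scores, feeding those to F.interpolate instead of the label-map one-hot would keep the graph intact; whether and how that is possible depends on the MaskDINO/mmdetection implementation.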
Before using Grad-CAM, I tried using register_forward_hook myself to extract intermediate-layer features from my model, and it worked:
conv_features_r_3, conv_features_t_3 = [], []
hooks = [
    model.backbone_r.layer3.register_forward_hook(
        lambda self, input, output: conv_features_r_3.append(output)
    ),
    model.backbone_t.layer3.register_forward_hook(
        lambda self, input, output: conv_features_t_3.append(output)
    )
]
result = inference_detector(model, img)
for hook in hooks:
    hook.remove()
But when I use Grad-CAM, the hook registered with register_forward_hook in class ActivationsAndGradients is never called. Does pytorch-grad-cam place any restrictions on the model? Thank you for your attention.
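One difference between the two experiments that may matter: a forward hook fires regardless of grad mode, so feature extraction works even if the forward pass runs under torch.no_grad(), but under no_grad the hooked output has requires_grad=False, and save_gradient in ActivationsAndGradients then returns early. Whether inference_detector or the custom model's forward disables gradients internally is an assumption to verify. A toy check (run_and_inspect is a hypothetical helper):

```python
import torch
import torch.nn as nn


def run_and_inspect(model, layer, x, grad_enabled):
    """Run model(x); return one requires_grad flag per call of `layer`."""
    seen = []
    handle = layer.register_forward_hook(
        lambda module, inputs, output: seen.append(output.requires_grad))
    with torch.set_grad_enabled(grad_enabled):
        model(x)
    handle.remove()
    return seen


# Hypothetical toy model standing in for the real backbone.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 2, 1))
x = torch.randn(1, 3, 10, 10)
print(run_and_inspect(model, model[0], x, grad_enabled=True))   # [True]
print(run_and_inspect(model, model[0], x, grad_enabled=False))  # [False]
```

Applied to the real wrapper, an empty list would mean the target module is not on the code path the wrapper's forward executes; [False] would mean something on that path disables autograd.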
I'm using a custom model based on YOLOv7.
This code works:
cam = EigenCAM(model, target_layers)
When I change the following line in eigen_cam.py to pass uses_gradients=True, it reports the same error:
super(EigenCAM, self).__init__(model, target_layers, reshape_transform, uses_gradients=True)
Where can I manually set requires_grad=True for a custom model?
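requires_grad is a per-tensor flag, so for a custom model it is usually set on the parameters or on the input tensor, and the forward must run outside torch.no_grad(); model.eval() by itself does not disable gradients. A sketch of the check on a hypothetical frozen toy model:

```python
import torch
import torch.nn as nn

# Hypothetical toy model; freezing simulates weights with requires_grad=False.
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(), nn.Conv2d(4, 1, 1))
for p in model.parameters():
    p.requires_grad_(False)

x = torch.randn(1, 3, 8, 8)
print(model(x).requires_grad)  # False -> backward() would raise the RuntimeError

# Option 1: unfreeze all parameters in place.
model.requires_grad_(True)
print(model(x).requires_grad)  # True

# Option 2 (parameters still frozen): make the input itself require grad.
model.requires_grad_(False)
print(model(x.clone().requires_grad_(True)).requires_grad)  # True
```

Note that a gradient with respect to a target layer's output only exists if something upstream of that layer (its parameters or the input) requires grad.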