
Comments (4)

hendrycks commented on August 23, 2024

We have code for OOD detection with multiclass classifiers at https://github.com/hendrycks/outlier-exposure.
While it was written for a different paper, there should be substantial overlap.
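For quick orientation, the MSP and max-logit detectors used in the script later in this thread each boil down to a couple of lines. A minimal sketch with toy logits (illustrative values, not from the paper):

```python
import torch
import torch.nn.functional as F

def msp_score(logits):
    # Negative maximum softmax probability: higher means more anomalous.
    return -F.softmax(logits, dim=1).max(dim=1).values

def maxlogit_score(logits):
    # Negative maximum logit; tends to work better with many classes.
    return -logits.max(dim=1).values

# Toy logits for a 5-class classifier (hypothetical values).
logits = torch.tensor([[4.0, 0.1, 0.1, 0.1, 0.1],   # confident prediction
                       [0.5, 0.4, 0.5, 0.4, 0.5]])  # near-uniform prediction
print(msp_score(logits), maxlogit_score(logits))
```

The confident example gets a lower (less anomalous) score under both detectors.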

from anomaly-seg.

gedoensmax commented on August 23, 2024

OK, just for clarification:
If I train on the full ImageNet train split and test on Gaussian noise with as many samples as the ImageNet test split, I can reproduce Table 1, row 1. Same approach for Places365: train on it and size the Gaussian noise / Rademacher / blob sets to its test split. For non-artificial data, do you just use the full test set of whatever OOD dataset? Just wondering, because the repository you mention contains neither a Places experiment nor an ImageNet experiment.
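For reference, the Gaussian-noise sets in the script later in this thread are built by sampling, scaling, clamping, and wrapping in a TensorDataset. A minimal sketch (sample count shrunk here for illustration; the script uses 10000 and sigma 0.5):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Build a Gaussian-noise OOD set sized to a chosen test split.
# num_samples is a placeholder; set it to the test-split size in practice.
num_samples = 32
noise = torch.clamp(torch.randn(num_samples, 3, 224, 224) * 0.5, -1, 1)
dummy_targets = torch.ones(num_samples)
ood_data = TensorDataset(noise, dummy_targets)
ood_loader = DataLoader(ood_data, batch_size=128, shuffle=True)
```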


hendrycks commented on August 23, 2024

I looked around and here is some code that I found. It's probably similar to what we used for the paper, but I think the exact version is on a collaborator's machine. Regardless, this should hopefully help.

import numpy as np
import sys
import os
import pickle
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from skimage.filters import gaussian as gblur
from PIL import Image as PILImage
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
from tqdm import tqdm

# Path hack so the `utils` imports below work when this is run as a script.
if __package__ is None:
    from os import path

    sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
    from utils.display_results import show_performance, get_measures, print_measures, print_measures_with_std
    import utils.lsun_loader as lsun_loader

parser = argparse.ArgumentParser(description='Evaluates an ImageNet OOD Detector',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Setup
parser.add_argument('--test_bs', type=int, default=128)
parser.add_argument('--num_to_avg', type=int, default=1, help='Average measures across num_to_avg runs.')
parser.add_argument('--validate', '-v', action='store_true', help='Evaluate performance on validation distributions.')
parser.add_argument('--method', '-m', default="maxlogit", choices=["msp", "maxlogit", "kl", "ent"])
# Loading details
parser.add_argument('--load', '-l', type=str, default='', help='Checkpoint path to resume / test.')
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=6, help='Pre-fetching threads.')
parser.add_argument('--dataset', '-d', type=str, default='imagenet', choices=['imagenet', 'imagenetrot', 'places'],
                    help='Choose the in-distribution dataset.')

args = parser.parse_args()
print(args)
# torch.manual_seed(1)
# np.random.seed(1)

if args.dataset == "imagenet":
    data_folder = "/home/hendrycks/datasets/imagenet/"
else:
    data_folder = "/home/hendrycks/datasets/places365/places365_standard/"


# mean and standard deviation of channels of ImageNet images
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

transform = trn.Compose([trn.Resize(256), trn.CenterCrop(224),
                         trn.ToTensor(), trn.Normalize(mean, std)])


val_data = dset.ImageFolder(root=data_folder + 'val', transform=transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.test_bs, shuffle=False,
                                         num_workers=args.prefetch, pin_memory=True)

test_data = dset.ImageFolder(root=data_folder + 'test', transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_bs, shuffle=False,
                                          num_workers=args.prefetch, pin_memory=True)

if args.dataset == "imagenet":
    num_classes = 1000
else:
    num_classes = 365

# Create model
if args.dataset == "imagenet":
    net = models.resnet50()
    net.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth',
                                           model_dir='/media/hendrycks/My Passport1/backup/hendrycks/datasets/models'))
elif args.dataset == "imagenetrot":
    net = models.resnet50(num_classes=1000)
    net.rot_head = nn.Linear(2048,4)
    checkpoint = torch.load("./rot_ood.pth.tar")
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    net.load_state_dict(state_dict)
else:
    net = models.resnet50(num_classes=365)
    checkpoint = torch.load("./resnet50_places365.pth")
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    net.load_state_dict(state_dict)

# start_epoch = 0
#
# # Restore model
# if args.load != '':
#     for i in range(1000 - 1, -1, -1):
#         if 'baseline' in args.method:
#             subdir = 'baseline'
#         elif 'oe_tune' in args.method:
#             subdir = 'oe_tune'
#         else:
#             subdir = 'oe_scratch'
#
#         model_name = os.path.join(os.path.join(args.load, subdir), args.method + '_epoch_' + str(i) + '.pt')
#         if os.path.isfile(model_name):
#             print("loading model ",model_name)
#             net.load_state_dict(torch.load(model_name))
#             print('Model restored! Epoch:', i)
#             start_epoch = i + 1
#             break
#     if start_epoch == 0:
#         assert False, "could not resume"

net.eval()

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    net.cuda()

cudnn.benchmark = True  # fire on all cylinders


concat = lambda x: np.concatenate(x, axis=0)
to_np = lambda x: x.data.cpu().numpy()


def _iterate_data(loader, _matrix, _matrix_counts):
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm(loader)):
            data = data.cuda()

            output = net(data)
            smax = to_np(F.softmax(output, dim=1))

            # Accumulate softmax vectors, binned by the predicted class
            # (the ground-truth target is ignored).
            for index in range(len(target)):
                label = np.argmax(smax[index])
                _matrix[label] += smax[index]
                _matrix_counts[label] += 1

        print("Classes with zero predictions:", np.where(_matrix_counts == 0)[0])

    return _matrix, _matrix_counts


def create_typicality_matrix(loader):
    _matrix = np.zeros([num_classes, num_classes])
    _matrix_counts = np.zeros([num_classes])

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm(loader)):
            data = data.cuda()

            output = net(data)
            smax = to_np(F.softmax(output, dim=1))

            # Bin by predicted class (the ground-truth target is ignored).
            for index in range(len(target)):
                label = np.argmax(smax[index])
                _matrix[label] += smax[index]
                _matrix_counts[label] += 1

        print("Zeroes check:", np.where(_matrix_counts == 0)[0])

    return (_matrix.T/_matrix_counts).T

matrix = []

if args.method == 'kl':
    if args.dataset == "imagenet":
        matrix_name = "imagenet_val_" + args.method
    else:
        matrix_name = "places_365_" + args.method

    if os.path.exists(matrix_name + ".npy"):
        matrix = np.load(matrix_name + ".npy")
    else:
        print('Estimating typical posterior distributions')
        matrix = create_typicality_matrix(val_loader)
        np.save(matrix_name, matrix)

    t_mat = torch.from_numpy(matrix).float().cuda()

# /////////////// Detection Prelims ///////////////


# ood_num_examples = len(val_data) // 5
# print("ood_num_examples = ", ood_num_examples)
ood_num_examples = 10000  # Setting it to 20K gave out of memory issues
expected_ap = ood_num_examples / (ood_num_examples + len(test_data))


def get_ood_scores(loader, in_dist=False):
    _score = []
    _right_score = []
    _wrong_score = []

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):

            if batch_idx >= ood_num_examples // args.test_bs and in_dist is False:
                break

            data = data.cuda()

            output = net(data)
            output_np = to_np(output)
            smax_torch = F.softmax(output, dim=1)
            smax = to_np(F.softmax(output, dim=1))

            if args.method == "msp":
                _score.append(-np.max(smax, axis=1))

            elif args.method == 'maxlogit':
                _score.append(-np.max(output_np, axis=1))

            elif args.method == 'ent':
                _score.append(np.sum(-smax * np.log(smax), axis=1))

            else:
                tmp = []
                for index in range(len(output)):
                    pred_class = torch.argmax(smax_torch[index])
                    row = to_np(torch.sum(-smax_torch[index] * torch.log(t_mat[pred_class]/smax_torch[index])))
                    tmp.append(row)
                _score.extend(tmp)

            if in_dist:
                preds = np.argmax(smax, axis=1)
                targets = target.numpy().squeeze()
                right_indices = preds == targets
                wrong_indices = np.invert(right_indices)

                if args.method == "msp":
                    _right_score.append(-np.max(smax[right_indices], axis=1))
                    _wrong_score.append(-np.max(smax[wrong_indices], axis=1))
                elif args.method == 'maxlogit':
                    _right_score.append(-np.max(output_np[right_indices], axis=1))
                    _wrong_score.append(-np.max(output_np[wrong_indices], axis=1))
                elif args.method == 'ent':
                    _right_score.append(np.sum(-smax[right_indices] * np.log(smax[right_indices]), axis=1))
                    _wrong_score.append(np.sum(-smax[wrong_indices] * np.log(smax[wrong_indices]), axis=1))
                else:
                    left_index = batch_idx*args.test_bs
                    right_index = (batch_idx+1)*args.test_bs

                    _right_score.extend(list(np.array(_score[left_index:right_index])[right_indices]))
                    _wrong_score.extend(list(np.array(_score[left_index:right_index])[wrong_indices]))

    if in_dist:
        if args.method == "kl":
            _score = np.asarray(_score).reshape(-1,1)
            _right_score = np.asarray(_right_score).reshape(-1,1)
            _wrong_score = np.asarray(_wrong_score).reshape(-1,1)
        # print("score shapes after", np.asarray(_right_score).shape, np.asarray(_wrong_score).shape)

        return concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy()
    else:
        if args.method == "kl":
            _score = np.asarray(_score).reshape(-1,1)
        return concat(_score)[:ood_num_examples].copy()


print("Getting scores for in-distribution examples")

in_score, right_score, wrong_score = get_ood_scores(test_loader, in_dist=True)
# in_score, right_score, wrong_score = get_ood_scores(val_loader, in_dist=True)

# Error rate is meaningless when using test_loader
#num_right = len(right_score)
#num_wrong = len(wrong_score)
#print('Error Rate {:.2f}'.format(100 * num_wrong / (num_wrong + num_right)))

# print("lengths", right_score.shape, wrong_score.shape)

# /////////////// End Detection Prelims ///////////////

if "imagenet" in args.dataset:
    print('\nUsing ImageNet as typical data')
else:
    print('\nUsing Places365 as typical data')


# /////////////// Error Detection ///////////////

#print('\n\nError Detection')
#show_performance(wrong_score, right_score, method_name=args.method)

# /////////////// OOD Detection ///////////////
auroc_list, aupr_list, fpr_list = [], [], []


def get_and_print_results(ood_loader, num_to_avg=args.num_to_avg):

    aurocs, auprs, fprs = [], [], []
    for _ in range(num_to_avg):
        out_score = get_ood_scores(ood_loader)
        measures = get_measures(out_score, in_score)
        aurocs.append(measures[0]); auprs.append(measures[1]); fprs.append(measures[2])

    auroc = np.mean(aurocs); aupr = np.mean(auprs); fpr = np.mean(fprs)
    auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr)

    if num_to_avg >= 5:
        print_measures_with_std(aurocs, auprs, fprs, args.method)
    else:
        print_measures(auroc, aupr, fpr, args.method)



# /////////////// Gaussian Noise ///////////////

dummy_targets = torch.ones(ood_num_examples * args.num_to_avg)
ood_data = torch.clamp(torch.randn(ood_num_examples * args.num_to_avg, 3, 224, 224) * 0.5, -1, 1)
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                         num_workers=args.prefetch, pin_memory=True)

print('\n\nGaussian Noise (sigma = 0.5) Detection')
get_and_print_results(ood_loader)

# /////////////// Rademacher Noise ///////////////

dummy_targets = torch.ones(ood_num_examples * args.num_to_avg)
ood_data = torch.sign(torch.rand(ood_num_examples * args.num_to_avg, 3, 224, 224) * 2 - 1)
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True)

print('\n\nRademacher Noise Detection')
get_and_print_results(ood_loader)

# /////////////// Blob ///////////////

ood_data = np.float32(np.random.binomial(n=1, p=0.7, size=(ood_num_examples * args.num_to_avg, 224, 224, 3)))
for i in range(ood_num_examples * args.num_to_avg):
    # Note: newer scikit-image versions replace `multichannel` with `channel_axis`.
    ood_data[i] = np.float32(gblur(ood_data[i], sigma=2, multichannel=False))
    ood_data[i][ood_data[i] < 0.75] = 0.0

dummy_targets = torch.ones(ood_num_examples * args.num_to_avg)
ood_data = torch.from_numpy(ood_data.transpose((0, 3, 1, 2))) * 2 - 1
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                         num_workers=args.prefetch, pin_memory=True)

print('\n\nBlob Detection')
get_and_print_results(ood_loader)

# /////////////// Textures ///////////////

ood_data = dset.ImageFolder(root="/home/hendrycks/datasets/dtd/images",
                            transform=transform)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                         num_workers=args.prefetch, pin_memory=True)

print('\n\nTexture Detection')
get_and_print_results(ood_loader)

# /////////////// LSUN ///////////////
if args.dataset == 'imagenet':
    ood_data = lsun_loader.LSUN("/home/hendrycks/datasets/LSUN/data", classes='test',
                                transform=transform)
    ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                             num_workers=args.prefetch, pin_memory=True)

    print('\n\nLSUN Detection')
    get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////

if "imagenet" in  args.dataset:
    data_folder = "/home/hendrycks/datasets/places365/places365_standard/test"
elif args.dataset == "places":
    data_folder = "/home/hendrycks/datasets/places365/extra69/images"

ood_data = dset.ImageFolder(root=data_folder, transform=transform)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                         num_workers=args.prefetch, pin_memory=True)

if "imagenet" in args.dataset:
    print('\n\nPlaces365 Detection')
else:
    print('\n\nPlaces365 Extra69 Detection')
get_and_print_results(ood_loader)


# /////////////// ImageNet-22K (held-out) ///////////////

#ood_data = dset.ImageFolder(
#    root="/home/hendrycks/datasets/imagenet22k/images", transform=transform)
#ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
#                                         num_workers=args.prefetch, pin_memory=True)


#print('\n\nHeld-out ImageNet-22K Classes Detection')
#get_and_print_results(ood_loader)

# /////////////// Mean Results ///////////////

print('\n\nMean Test Results')
print_measures(np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list), method_name=args.method)

# /////////////// OOD Detection of Validation Distributions ///////////////

if args.validate is False:
    exit()

auroc_list, aupr_list, fpr_list = [], [], []

# /////////////// Uniform Noise ///////////////

dummy_targets = torch.ones(ood_num_examples * args.num_to_avg)
ood_data = torch.from_numpy(
    np.random.uniform(size=(ood_num_examples * args.num_to_avg, 3, 224, 224),
                      low=-1.0, high=1.0).astype(np.float32))
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True)

print('\n\nUniform[-1,1] Noise Detection')
get_and_print_results(ood_loader)


# /////////////// Arithmetic Mean of Images ///////////////


class AvgOfPair(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.shuffle_indices = np.arange(len(dataset))
        np.random.shuffle(self.shuffle_indices)

    def __getitem__(self, i):
        random_idx = np.random.choice(len(self.dataset))
        while random_idx == i:
            random_idx = np.random.choice(len(self.dataset))

        return self.dataset[i][0] / 2. + self.dataset[random_idx][0] / 2., 0

    def __len__(self):
        return len(self.dataset)


ood_data = dset.ImageFolder(
    root="/share/data/vision-greg/ImageNet/clsloc/256/images/val", transform=test_transform)
ood_loader = torch.utils.data.DataLoader(AvgOfPair(ood_data),
                                         batch_size=args.test_bs, shuffle=True,
                                         num_workers=args.prefetch, pin_memory=True)

print('\n\nArithmetic Mean of Random Image Pair Detection')
get_and_print_results(ood_loader)


# /////////////// Geometric Mean of Images ///////////////


class GeomMeanOfPair(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.shuffle_indices = np.arange(len(dataset))
        np.random.shuffle(self.shuffle_indices)

    def __getitem__(self, i):
        random_idx = np.random.choice(len(self.dataset))
        while random_idx == i:
            random_idx = np.random.choice(len(self.dataset))

        return trn.Normalize(mean, std)(torch.sqrt(self.dataset[i][0] * self.dataset[random_idx][0])), 0

    def __len__(self):
        return len(self.dataset)


ood_data = dset.ImageFolder(
    root="/share/data/vision-greg/ImageNet/clsloc/256/images/val", transform=trn.ToTensor())
ood_loader = torch.utils.data.DataLoader(
    GeomMeanOfPair(ood_data), batch_size=args.test_bs, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)

print('\n\nGeometric Mean of Random Image Pair Detection')
get_and_print_results(ood_loader)

# /////////////// Jigsaw Images ///////////////

ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                         num_workers=args.prefetch, pin_memory=True)

jigsaw = lambda x: torch.cat((
    torch.cat((torch.cat((x[:, 16:32, :32], x[:, :16, :32]), 1),
               x[:, 32:, :32]), 2),
    torch.cat((x[:, 32:, 32:],
               torch.cat((x[:, :32, 48:], x[:, :32, 32:48]), 2)), 2),
), 1)

ood_loader.dataset.transform = trn.Compose([trn.ToTensor(), jigsaw, trn.Normalize(mean, std)])

print('\n\nJigsawed Images Detection')
get_and_print_results(ood_loader)

# /////////////// Speckled Images ///////////////

speckle = lambda x: torch.clamp(x + x * torch.randn_like(x), 0, 1)
ood_loader.dataset.transform = trn.Compose([trn.ToTensor(), speckle, trn.Normalize(mean, std)])

print('\n\nSpeckle Noised Images Detection')
get_and_print_results(ood_loader)

# /////////////// Pixelated Images ///////////////

pixelate = lambda x: x.resize((int(224 * 0.2), int(224 * 0.2)), PILImage.BOX).resize((224, 224), PILImage.BOX)
ood_loader.dataset.transform = trn.Compose([pixelate, trn.ToTensor(), trn.Normalize(mean, std)])

print('\n\nPixelate Detection')
get_and_print_results(ood_loader)

# /////////////// RGB Ghosted/Shifted Images ///////////////

rgb_shift = lambda x: torch.cat((x[1:2].index_select(2, torch.LongTensor([i for i in range(224 - 1, -1, -1)])),
                                 x[2:, :, :], x[0:1, :, :]), 0)
ood_loader.dataset.transform = trn.Compose([trn.ToTensor(), rgb_shift, trn.Normalize(mean, std)])

print('\n\nRGB Ghosted/Shifted Image Detection')
get_and_print_results(ood_loader)

# /////////////// Inverted Images ///////////////

# not done on all channels to make image ood with higher probability
invert = lambda x: torch.cat((x[0:1, :, :], 1 - x[1:2, :, :], 1 - x[2:, :, :]), 0)
ood_loader.dataset.transform = trn.Compose([trn.ToTensor(), invert, trn.Normalize(mean, std)])

print('\n\nInverted Image Detection')
get_and_print_results(ood_loader)

# /////////////// Mean Results ///////////////

print('\n\nMean Validation Results')
print_measures(np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list), method_name=args.method)
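Note that get_measures, print_measures, and friends come from utils.display_results in the outlier-exposure repo and aren't reproduced above. A rough numpy-only stand-in for the three reported numbers (my sketch, not the repo's implementation) could look like:

```python
import numpy as np

def get_measures_sketch(out_scores, in_scores, recall_level=0.95):
    """Rough stand-in for utils.display_results.get_measures.
    OOD examples are the positive class; higher score = more anomalous."""
    pos = np.asarray(out_scores, dtype=np.float64)
    neg = np.asarray(in_scores, dtype=np.float64)

    # AUROC: probability a random OOD score exceeds a random in-dist score.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    auroc = (greater + 0.5 * ties) / (len(pos) * len(neg))

    # AUPR via average precision over scores sorted descending.
    scores = np.concatenate([pos, neg])
    labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))])
    labels = labels[np.argsort(-scores)]
    tp = np.cumsum(labels)
    precision = tp / np.arange(1, len(labels) + 1)
    aupr = np.sum(precision * labels) / len(pos)

    # FPR at `recall_level` TPR: threshold keeping 95% of OOD scores above it.
    thresh = np.percentile(pos, 100 * (1 - recall_level))
    fpr = np.mean(neg >= thresh)
    return auroc, aupr, fpr
```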


gedoensmax commented on August 23, 2024

Thanks a lot, I think I can get all the parameters from there.

