chenyaofo / pytorch-cifar-models
Pretrained models on CIFAR10/100 in PyTorch
License: BSD 3-Clause "New" or "Revised" License
```python
import requests

url = (
    "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip"
)
# Streaming, so we can iterate over the response.
r = requests.get(url, stream=True, verify=False)
```
The downloaded state_dicts.zip file cannot be extracted. Is that because I used verify=False in the requests.get() call?
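For reference, `verify=False` only tells requests to skip TLS certificate verification; it does not change the bytes that are downloaded. One possible cause of an archive that will not extract (an assumption, not confirmed in this thread) is that the server returned something other than the zip, such as an HTML page, or that the transfer was cut short. A minimal sketch that streams the file to disk and then checks it with the standard-library `zipfile` module; the helper names `download_file` and `check_archive` are hypothetical.

```python
import zipfile
from pathlib import Path


def download_file(url: str, dest: Path, chunk_size: int = 1 << 20) -> Path:
    """Stream `url` to `dest` chunk by chunk instead of holding it in memory."""
    import requests  # third-party; imported here so check_archive works without it

    # TLS verification stays on (the default); verify=False only skips cert checks.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()  # fail on 4xx/5xx instead of saving an error page
        print("Content-Type:", r.headers.get("Content-Type"))
        with open(dest, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return dest


def check_archive(path: Path) -> bool:
    """Return True only if `path` is a zip archive whose members pass a CRC check."""
    if not zipfile.is_zipfile(path):
        # Show the first bytes: an HTML page saved as .zip starts with "<!DOCTYPE" or "<html".
        with open(path, "rb") as f:
            print("Not a zip; first bytes:", f.read(64))
        return False
    with zipfile.ZipFile(path) as zf:
        return zf.testzip() is None  # None means no corrupted member was found
```

Usage, assuming `requests` is installed: `archive = download_file(url, Path("state_dicts.zip"))`, then extract only if `check_archive(archive)` is True, e.g. with `zipfile.ZipFile(archive).extractall("state_dicts")`.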
Thanks for the pre-trained models!
I'm testing the pre-trained models on CIFAR-10 using torchvision.datasets.CIFAR10() and found that you use different normalization parameters from the ones on pytorch.org: yours are mean [0.4914, 0.4822, 0.4465] and std [0.2023, 0.1994, 0.2010], while pytorch.org uses transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).
I can reproduce your results using the same normalization parameters as you. It took me some time to locate the issue, so I think stating the normalization parameters somewhere in the project README would help other people.
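torchvision's `transforms.Normalize` computes `(x - mean[c]) / std[c]` for each channel `c`, so the two parameter sets map the same input pixel to different values, which is enough to shift a pretrained model's accuracy. A small stdlib-only sketch of that arithmetic; the constant and function names are hypothetical, and the values are the ones quoted above.

```python
from typing import Sequence, Tuple

# Values quoted in this issue (illustrative constant names).
REPO_MEAN = (0.4914, 0.4822, 0.4465)
REPO_STD = (0.2023, 0.1994, 0.2010)
TUTORIAL_MEAN = (0.5, 0.5, 0.5)
TUTORIAL_STD = (0.5, 0.5, 0.5)


def normalize_pixel(rgb: Sequence[float], mean: Sequence[float],
                    std: Sequence[float]) -> Tuple[float, ...]:
    """Same per-channel arithmetic as transforms.Normalize: (x - mean) / std."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


gray = (0.5, 0.5, 0.5)  # a mid-gray pixel after ToTensor() has scaled it to [0, 1]
print(normalize_pixel(gray, TUTORIAL_MEAN, TUTORIAL_STD))  # → (0.0, 0.0, 0.0)
print(normalize_pixel(gray, REPO_MEAN, REPO_STD))  # roughly (0.043, 0.089, 0.266)
```

In a torchvision pipeline, the matching evaluation transform would be `transforms.Compose([transforms.ToTensor(), transforms.Normalize(REPO_MEAN, REPO_STD)])`.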
pip install throws the following error because requirements.txt is missing:
FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
Hi, I tried to evaluate your models' accuracy in two ways. In your report, the accuracy of the cifar100_MobileNetV2_x1_0 model is 74.20, but when I tried start_on_colab and the following code:
```python
import torch
from torchvision import datasets, transforms

# data will be downloaded in data_directory.
data_directory = './data'
batchsize = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'

normalize = transforms.Normalize(mean=[0.507, 0.4865, 0.4409],
                                 std=[0.2673, 0.2564, 0.2761])

train_dataset = datasets.CIFAR100(root=data_directory,
                                  train=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      normalize]),
                                  download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batchsize,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=2)

test_dataset = datasets.CIFAR100(root=data_directory,
                                 train=False,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     normalize]),
                                 download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batchsize,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=2)


def validate(model, loader):
    """Return top-1 accuracy (%) of `model` on `loader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images).argmax(dim=1)  # predicted class index
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    model.train()  # restore training mode
    return correct / total * 100


model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_mobilenetv2_x1_0", pretrained=True).to(device)
print(validate(model, test_loader))
```
and I got
Can you check your models?
Thanks for your pretrained models on CIFAR10. I want to run some tests with a pretrained model; however, the inference accuracy (99.624%) seems too high (using torchvision.datasets.CIFAR10). I think your train/val split differs from torchvision's. Could you tell me how your dataset is divided?
All models worked (for both CIFAR-10 and CIFAR-100) except mobilenetv2_x0_75 trained on CIFAR-100.
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/archive/master.zip" to C:\Users\xxx/.cache\torch\hub\master.zip
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_75-9ab3e178.pt" to C:\Users\xxx/.cache\torch\hub\checkpoints\cifar100_mobilenetv2_x0_75-9ab3e178.p
...
...
File "C:\ProgramData\Anaconda3\envs\PyTorchEight\lib\urllib\request.py", line 649, in http_error_default
raise HTTPError(req.full_url, code, msg, hdrs, fp)
HTTPError: Not Found
I thought it was the .pt filename, but it looks correct to me.
Update:
When I try to download the .pt model directly (by clicking the link to the model), it does not download and I get the message "Failed - No file".
Direct download works for all the other models.
Can you provide train.py?
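The suffix in `cifar100_mobilenetv2_x0_75-9ab3e178.pt` follows the torch.hub naming convention `<name>-<hash>.<ext>`, where `<hash>` is the first eight or more hex digits of the file's SHA256; `torch.hub.load_state_dict_from_url(..., check_hash=True)` performs this check itself. It does not explain the 404, but if the checkpoint is obtained some other way, a stdlib-only sketch like the following (helper names hypothetical) can confirm that the manually downloaded file is the expected one.

```python
import hashlib
import re
from pathlib import Path

# "<name>-<hex prefix>.<ext>" at the end of the filename.
HASH_RE = re.compile(r"-([a-f0-9]{8,})\.\w+$")


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in chunks so large checkpoints fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def matches_filename_hash(path: Path) -> bool:
    """True if the file's SHA256 starts with the hash prefix embedded in its name."""
    m = HASH_RE.search(path.name)
    if m is None:
        raise ValueError(f"no hash prefix in {path.name!r}")
    return sha256_of(path).startswith(m.group(1))
```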
There is an error: http.client.RemoteDisconnected: Remote end closed connection without response
Code:
```python
import torch

net = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
```
Environment:
PyTorch 1.11
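`http.client.RemoteDisconnected` is raised when the server closes the connection before sending a response, and it subclasses `ConnectionResetError`. If the failure is transient (an assumption; a proxy or firewall could also be responsible), retrying with exponential backoff is a common workaround. A minimal sketch with a hypothetical `with_retries` helper:

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")

# Connection-level errors worth retrying; RemoteDisconnected is a ConnectionError.
TRANSIENT: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError)


def with_retries(fn: Callable[[], T], attempts: int = 4, base_delay: float = 2.0,
                 retry_on: Tuple[Type[BaseException], ...] = TRANSIENT) -> T:
    """Call fn(); after a transient error wait base_delay * 2**i seconds and retry."""
    if attempts < 1:
        raise ValueError("attempts must be >= 1")
    for i in range(attempts):
        try:
            return fn()
        except retry_on as exc:
            if i == attempts - 1:
                raise  # out of attempts: surface the last error
            delay = base_delay * 2 ** i
            print(f"attempt {i + 1} failed ({exc!r}); retrying in {delay:.0f}s")
            time.sleep(delay)
    raise AssertionError("unreachable")
```

For example: `net = with_retries(lambda: torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True))`. Only connection-level errors are retried; an HTTP 404 (`urllib.error.HTTPError`) is raised immediately, since retrying would not help.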
Hi, I ran
!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20
on Colab, and it gives me the following error:
Traceback (most recent call last):
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/content/image-classification-codebase/entry/run.py", line 6, in <module>
main(get_args())
File "/content/image-classification-codebase/codebase/main.py", line 234, in main
main_worker(local_rank, ngpus_per_node, args, args.conf)
File "/content/image-classification-codebase/codebase/main.py", line 203, in main_worker
prepare_for_training(conf, args.output_dir, local_rank)
File "/content/image-classification-codebase/codebase/main.py", line 142, in prepare_for_training
basic_bs = optimizer_config.pop("basic_bs")
File "/usr/local/lib/python3.7/dist-packages/pyhocon/config_tree.py", line 274, in pop
value = self.get(key, UndefinedKey)
File "/usr/local/lib/python3.7/dist-packages/pyhocon/config_tree.py", line 236, in get
return self._get(ConfigTree.parse_key(key), 0, default)
File "/usr/local/lib/python3.7/dist-packages/pyhocon/config_tree.py", line 177, in _get
u"No configuration setting found for key {key}".format(key='.'.join(key_path[:key_index + 1])))
pyhocon.exceptions.ConfigMissingException: 'No configuration setting found for key basic_bs'
Does anyone know how to solve this?
Could you let me know how I can cite your work?
Hi, I was wondering how many epochs were the pre-trained models trained for? And what other hyperparameters (learning rate, optimizer, scheduler, etc.) were used? I'm particularly interested in ResNet-20.
Thanks!
I cannot find in the training logs how the data is preprocessed for the training and test datasets. Does this mean that only normalization has been applied to both?
I couldn't find in the code or documentation the mean/std values used for normalization, which I need in order to test on new images. Are they the same ones used in the image-classification-codebase code, i.e.
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]?
Side note: the get_vit_val_transforms(mean, std, img_size) and get_vit_train_transforms(mean, std, img_size) functions seem to have a small bug: they ignore the mean/std values passed in and use hard-coded ones instead (unless this is intended behaviour, but the mean/std parameters threw me off).
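If new images should be normalized with statistics of the training data rather than 0.5, the per-channel mean and standard deviation can be estimated directly. A stdlib-only sketch; it assumes pixels are already scaled to [0, 1] (as after `ToTensor()`) and stored as nested `H x W x 3` lists, and the function names are hypothetical. With torchvision, the raw training pixels are available as `datasets.CIFAR10(...).data`, a uint8 array of shape (50000, 32, 32, 3) that would need dividing by 255 first. The numbers computed this way are not guaranteed to equal the constants the pretrained models were trained with; those should come from the training configuration.

```python
from math import sqrt
from typing import Iterable, Iterator, List, Sequence, Tuple

Image = Sequence[Sequence[Sequence[float]]]  # H x W x C, values in [0, 1]


def iter_pixels(images: Iterable[Image]) -> Iterator[Sequence[float]]:
    """Yield every per-pixel channel vector of every image."""
    for img in images:
        for row in img:
            yield from row


def channel_stats(images: Iterable[Image], channels: int = 3) -> Tuple[List[float], List[float]]:
    """Per-channel mean and population std over all pixels of all images."""
    n = 0
    sums = [0.0] * channels
    sq_sums = [0.0] * channels
    for px in iter_pixels(images):
        n += 1
        for c in range(channels):
            sums[c] += px[c]
            sq_sums[c] += px[c] * px[c]
    if n == 0:
        raise ValueError("no pixels")
    mean = [s / n for s in sums]
    # Var[x] = E[x^2] - E[x]^2, clamped at 0 against floating-point cancellation.
    std = [sqrt(max(q / n - m * m, 0.0)) for q, m in zip(sq_sums, mean)]
    return mean, std
```

The sum-of-squares formula is simple but can lose precision through cancellation; for very large datasets a streaming method such as Welford's algorithm is more robust.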
I downloaded some of the networks and the accuracy differs from what you reported.
For example, after downloading the ResNet I checked the accuracy and it was only 0.8 (on the training set).
Could you please check whether this problem happens only to me?
(I downloaded the model via the README and then loaded it into Python.)
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar_resnet20", pretrained=true)
should be
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar_resnet20", pretrained="cifar10")
or
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar_resnet20", pretrained="cifar100")