Git Product home page Git Product logo

pytorch-cifar10's Introduction

Pytorch Implementation of Deep Learning Models/Algorithms

This repositiory is for implementing and training/testing popular model architectures on the CIFAR10 dataset.

Environment

  • CUDA Version: 10.2
torch==1.5.0
torchvision==0.6.0
numpy==1.19.2

Usage

Training

To train a model, run train.py. If you need to speicfy the model, just use some args.

# train alexnet model with using gpu. 50 epochs
$ python train.py --model alexnet --epoch 50 --gpu

optional&required arguments

--data_dir      default='./data/train',
                help="Directory containing the dataset"
--model         required=True, type=str,
                help="The model you want to train"
--lr            type=float, default=0.001,
                help="Learning rate"
--epoch         type=int, default=50,
                help="Total training epochs"
--batch_size    type=int, default=256,
                help="batch size"
--gpu           action='store_true', default='False',
                help="GPU available"

Evaluate

To evaluate the model, run evaluate.py. If you need to speicfy the model, just use some args.

# evaluate alexnet model
$ python evaluate.py --model alexnet --weights ./results/alexnet/best.pth --gpu

optional&required arguments

--data_dir      default='./data/test',
                help="Directory containing the dataset"
--model         required=True, type=str,
                help="The model you want to test"
--weight        required=True,
                help="The weights file you want to test"
--batch_size    default=256,
                help="batch size"
--gpu           action='store_true', default='False',
                help="GPU available"

Results

Network epoch lr top1@prec(test) ModelSize(MB)
AlexNet 50 0.001 74.2578% 266MB
ZFNet 50 0.01 80.4395% 445MB
VGG - - - -
ResNet - - - -
Inception - - - -
GoogLeNet - - - -
- - - - -
- - - - -
- - - - -
- - - - -

pytorch-cifar10's People

Contributors

yskim0 avatar

Stargazers

 avatar  avatar  avatar

Watchers

 avatar  avatar

pytorch-cifar10's Issues

AlexNet/ ZFNet 코드 리뷰

pytorch는 처음이라서 그냥 보는데 전체적으로 주석을 잘 달아주셔서 코드 가독성이 좋은 것 같아요!

data_loader.py

transform.Normalize()의 변수는 직접 계산하셔서 입력하신 건가요?(파이토치를 잘 몰라서,,)
transform에서 데이터를 resize할 때 alexnet은 입력값이 227이라고 알고 있는데 사이즈를 256? 224? 조정하신 건가요?? centercrop은 데이터를 자른 것인지!? (이렇게 전처리한 이유는 무엇인지 궁금하네요!!)

cifar10 데이터셋이 제 컴에서 잘 안돌아가서 일부만 활용하느라 데이터셋 load를 그냥 너무 단순하게 한 것 같아서 저도 shuffle이나 augmentation 등 다양한 옵션을 고민해봐야겠네요

alexnet, zfnet

classifier과 feature extraction을 따로 구현해서 이 tensorflow랑 달라서 신기하네요.
forward() 도 신기하구요! 근데 이 함수가 호출이 안 되어도 되는건지?(forward함수 호출하는걸 못찾겠어서.. 그냥 실행이 되는 건가요?!! 제가 이해를 잘못한 거 일 수도 있어요!ㅜㅜ )

train.py

logging.info() 저도 넣어야겠다는 생각이 드네요ㅎㅎ
tensorflow-gpu로 할 때는 gpu가 그냥 구동이되는 것 같은데
pytorch라서 다른 건지 아니면 저렇게 정할 수 있는 옵션이 저도 되는지 한번 시도해봐야겠어요! tensorflow는 그냥 모델 돌리기엔 편한데 pytorch로 보면 확실히 여러 단계들을 세분화해서 볼 수 있다는 생각이 드네요!

utils.py

각각 모델 별로 optimizer을 다르게 로드하게 해두셔서 좋다는 생각을 했습니다!!!
alexnet은 adam인데 제가 sgd로 잘못했다는 것도 발견했네요.
save_checkpoints( )랑 best 모델 저장하는 것도 저도 적용해야겠어요!!


고생하셨습니다 👐 여러모로 배울 부분과 제게 적용해야할 부분이 많네요ㅎㅎ
제가 아직 공부할 게 많아서..ㅎㅎ 혹시 질문한 부분 중에 답주실 수 있으신 부분은 시간 나실 때 언제든 답 주셔도 좋아요ㅎㅎ 앞으로도 화이팅!!

AlexNet/ ZFNet 코드 리뷰

alexnet.py

코드가 깔끔하게 작성된 것 같습니다. Alexnet에서 중요한 ReLU, pooling, dropout 등이 코드에 잘 포함되어 있네요. 논문에는 커널을 96->256->384->384->256으로 설정하였는데 코드에서는 64->192->384->256->256으로 설정하신 이유가 궁금해요! mnist와 cifar-10의 차이인지 아니면 성능을 더 좋게 하기 위해서 하신건지 알고싶어요.

기본적인 틀은 다 갖추어져 있으니 성능 평가 지표를 더 향상시킬 수 있게 추가적으로 하이퍼 파라미터를 바꿔보거나 conv layer를 더 추가하는 작업을 해보셔도 좋을 것 같습니다.

ZFNet.py

ZFNet 코드를 보니까 텐서플로보다 파이토치가 훨씬 더 심플하게 표현하기 좋다는 생각이 드네요! 저도 다음에는 파이토치에 도전해봐야 겠습니다ㅎㅎ 군더더기 없이 잘 작성된 것 같습니다. 한 가지 제안하고 싶은건 전체적으로 주석을 달아서 각각의 어떤 역할을 하고 있는지 설명을 덧붙이면 더 좋을 것 같아요.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.