
resnet-tensorflow's Introduction

ResNet-Tensorflow

A simple TensorFlow implementation of pre-activation ResNet18, ResNet34, ResNet50, ResNet101, and ResNet152
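The depths in that list follow the stage layouts from the original ResNet paper. As a quick sanity check, the sketch below encodes those layouts and recounts the weighted layers. The dictionary and helper names are hypothetical and are not taken from this repository, which may organize its configuration differently.

```python
# Stage layouts from the original ResNet paper (He et al., 2015).
# RESNET_CONFIGS, CONVS_PER_BLOCK and count_layers are hypothetical names,
# not this repository's code.
RESNET_CONFIGS = {
    18: ("basic", [2, 2, 2, 2]),
    34: ("basic", [3, 4, 6, 3]),
    50: ("bottleneck", [3, 4, 6, 3]),
    101: ("bottleneck", [3, 4, 23, 3]),
    152: ("bottleneck", [3, 8, 36, 3]),
}

# Weighted layers per block: a basic block has two 3x3 convs,
# a bottleneck block has 1x1 -> 3x3 -> 1x1.
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def count_layers(res_n):
    """Stem conv + all block convs + final FC layer.

    1x1 projection shortcuts are not counted, matching the paper's convention.
    """
    block_type, blocks_per_stage = RESNET_CONFIGS[res_n]
    return 1 + CONVS_PER_BLOCK[block_type] * sum(blocks_per_stage) + 1


if __name__ == "__main__":
    for depth in RESNET_CONFIGS:
        print(depth, count_layers(depth))
```

ResNet50 and deeper reuse the ResNet34 stage layout pattern but switch to bottleneck blocks, which is why they need the wider channels*4 outputs discussed in the issues further down.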

Summary

Dataset

  • tiny_imagenet
  • cifar10, cifar100, mnist, fashion-mnist, loaded through Keras (pip install keras)

Train

  • python main.py --phase train --dataset tiny --res_n 18 --lr 0.1

Test

  • python main.py --phase test --dataset tiny --res_n 18 --lr 0.1

Related works

Author

Junho Kim

resnet-tensorflow's People

Contributors

taki0112


resnet-tensorflow's Issues

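A question about the ResNet training code

Hello. There is a part I find a little confusing, so I'd like to ask about it.
In the linked line, I can see that the test error is computed for every batch. It looks like you are using it as validation, so I read your test_loss and test_accuracy as val_loss and val_accuracy. However, in the data preprocessing, I noticed that the validation data (which you actually named test data) is normalized with the training data's mean and std.

My understanding is that validation data should be normalized with the validation set's own mean and std. Am I mistaken?

Or did you actually run the test data every batch, exactly as the code does? I'm curious.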
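For context, the common practice is to estimate normalization statistics on the training split only and reuse them unchanged for validation and test data: the network was trained on inputs scaled with those statistics, and computing separate statistics from held-out data would shift its inputs and leak information from that data into preprocessing. A minimal pure-Python sketch of that pattern, on a single scalar feature rather than per-channel image statistics; the helper names are hypothetical.

```python
import statistics


def fit_normalizer(train_values):
    """Estimate mean and (population) std on the training data only."""
    mean = statistics.mean(train_values)
    std = statistics.pstdev(train_values)
    return mean, std


def normalize(values, mean, std):
    """Apply one fixed (mean, std) pair to any split."""
    return [(v - mean) / std for v in values]


train = [1.0, 2.0, 3.0, 4.0]
val = [2.5, 5.0]

mean, std = fit_normalizer(train)   # statistics come from train only
train_n = normalize(train, mean, std)
val_n = normalize(val, mean, std)   # validation reuses the training statistics
print(mean, val_n[0])               # 2.5 0.0
```

Only the training split ends up with exactly zero mean; held-out data is expected to be close to, but not exactly, standardized.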

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) returns None

I ran this project on Ubuntu Server 16.04.
I installed the following requirements for this project:
tensorflow==1.14.0 keras==2.2.5 matplotlib==3.1.1 Pillow==6.1.0 scipy==1.1.0

I cloned this project and ran it as instructed in the README.md file, with the following command:
python main.py --phase train --dataset tiny --res_n 18 --lr 0.1
After it ran for a while, I encountered the message [*] Failed to find a checkpoint. It comes from the statement tf.train.get_checkpoint_state(checkpoint_dir).
So what is the correct checkpoint_dir argument for this function?
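tf.train.get_checkpoint_state(checkpoint_dir) looks for a small text index file, named checkpoint by default, that tf.train.Saver writes into the save directory, and it returns None when that file is missing. On a first training run, before anything has been saved to checkpoint_dir, a "Failed to find a checkpoint" message is therefore expected, assuming the loading code then falls back to training from scratch. The stdlib sketch below mimics that lookup to make the behaviour concrete; it is an illustration, not TensorFlow's implementation, and latest_checkpoint_path is a hypothetical name.

```python
import os
import re
import tempfile


def latest_checkpoint_path(checkpoint_dir, latest_filename="checkpoint"):
    """Hypothetical mimic of the lookup behind tf.train.get_checkpoint_state.

    Reads the index file in checkpoint_dir and returns the path stored in its
    model_checkpoint_path field, or None if there is no index file.
    """
    index_file = os.path.join(checkpoint_dir, latest_filename)
    if not os.path.isfile(index_file):
        return None  # nothing saved yet: the case where TensorFlow returns None
    with open(index_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                # Relative entries resolve against checkpoint_dir;
                # os.path.join leaves absolute entries unchanged.
                return os.path.join(checkpoint_dir, match.group(1))
    return None


# Usage: a fresh, empty directory has no index file yet.
with tempfile.TemporaryDirectory() as demo_dir:
    print(latest_checkpoint_path(demo_dir))  # None
```

In practice this means checkpoint_dir should point at the same directory the training run saves into; the lookup only succeeds once at least one checkpoint has been written there.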

ValueError: Dimensions must be equal, but are 128 and 32 for 'network/resblock0_0/add' (op: 'Add') with input shapes: [256,32,32,128], [256,32,32,32].

When I set --res_n 50, the code reports an error:

def resblock(x_init, channels, is_training=True, use_bias=True, downsample=False, scope='resblock') :
    with tf.variable_scope(scope) :

        x = batch_norm(x_init, is_training, scope='batch_norm_0')
        x = relu(x)

        if downsample :
            x = conv(x, channels, kernel=3, stride=2, use_bias=use_bias, scope='conv_0')
            x_init = conv(x_init, channels, kernel=1, stride=2, use_bias=use_bias, scope='conv_init')

        else :
            x = conv(x, channels, kernel=3, stride=1, use_bias=use_bias, scope='conv_0')

        x = batch_norm(x, is_training, scope='batch_norm_1')
        x = relu(x)
        x = conv(x, channels, kernel=3, stride=1, use_bias=use_bias, scope='conv_1')

        return x + x_init

Dimensions of x and x_init are not equal.
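The shapes in the error are consistent with a bottleneck block: with channels=32, the residual branch ends in a 1x1 convolution with channels*4 = 128 filters, while an identity shortcut still carries the 32-channel input, so the addition fails. The usual remedy, as in the original paper, is a 1x1 projection on the shortcut whenever the shapes differ. The pure-Python shape check below is a sketch under stated assumptions (NHWC layout, 'SAME'-style padding, expansion factor 4); the function names are hypothetical and it builds no TensorFlow graph.

```python
# Shape-only sketch. NHWC layout, 'SAME'-style padding and a bottleneck
# expansion factor of 4 are assumptions; function names are hypothetical.
EXPANSION = 4


def bottleneck_shapes(in_shape, channels, downsample, project_shortcut):
    """Return (residual_shape, shortcut_shape) for one bottleneck block."""
    n, h, w, _ = in_shape
    stride = 2 if downsample else 1
    # With 'SAME'-style padding the spatial size becomes ceil(size / stride).
    h_out, w_out = -(-h // stride), -(-w // stride)
    # 1x1 -> 3x3 -> 1x1 convs: the last 1x1 conv outputs channels * EXPANSION.
    residual = (n, h_out, w_out, channels * EXPANSION)
    if project_shortcut:
        # A 1x1 conv with the same stride maps the input onto the residual's shape.
        shortcut = (n, h_out, w_out, channels * EXPANSION)
    else:
        shortcut = tuple(in_shape)  # identity shortcut keeps the input shape
    return residual, shortcut


def needs_projection(in_shape, channels, downsample):
    """True when an identity shortcut could not be added to the residual branch."""
    residual, identity = bottleneck_shapes(in_shape, channels, downsample, False)
    return residual != identity


# The failing case from the error message: 32 input channels, channels=32.
print(bottleneck_shapes((256, 32, 32, 32), 32, False, False))
# ((256, 32, 32, 128), (256, 32, 32, 32))
print(needs_projection((256, 32, 32, 32), 32, False))  # True
```

Applied to the block, this suggests running the shortcut conv whenever needs_projection would be true (a channel change or a stride), not only when downsample is set.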

a mistake and a question

Hi taki0112,
I have been reading your code to learn from it. It is great code to read and learn from.
There is a mistake in the README.md file: "--ren_n" should be "--res_n".

Looking at your commits, could you explain why you changed
    x_init = avg_pooling(x_init)
    x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, scope='conv_init')
to
    x_init = conv(x_init, channels*4, kernel=1, stride=2, use_bias=use_bias, scope='conv_init')
?
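The two shortcut variants in that question reduce spatial resolution differently. A 1x1 convolution with stride 2 only reads the top-left position of every 2x2 window, so three quarters of the input positions never reach the shortcut, whereas average pooling followed by a stride-1 1x1 convolution first averages each window. The pure-Python sketch below compares just the two downsampling steps on one 4x4 channel, assuming avg_pooling uses a 2x2 window with stride 2; the 1x1 channel mixing is left out because both variants share it. It does not claim which variant the author intended, and the helper names are hypothetical.

```python
# Compare the spatial downsampling done by the two shortcut variants on a
# single 2-D channel. Assumes avg_pooling is 2x2 with stride 2; the 1x1
# channel mixing is omitted because both variants share it.


def strided_subsample(grid, stride=2):
    """What a 1x1 conv with stride 2 sees: every other row and column."""
    return [row[::stride] for row in grid[::stride]]


def avg_pool_2x2(grid):
    """2x2 average pooling with stride 2 (grid sides assumed even)."""
    return [
        [
            (grid[i][j] + grid[i][j + 1] + grid[i + 1][j] + grid[i + 1][j + 1]) / 4
            for j in range(0, len(grid[0]), 2)
        ]
        for i in range(0, len(grid), 2)
    ]


grid = [[4 * r + c for c in range(4)] for r in range(4)]  # values 0..15
print(strided_subsample(grid))  # [[0, 2], [8, 10]]
print(avg_pool_2x2(grid))       # [[2.5, 4.5], [10.5, 12.5]]
```

The strided version is cheaper and matches the projection shortcut described in the original paper; the pooled version keeps information from every input position at the cost of an extra op.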

regularization not working

I tried this code, and the resulting train/test accuracies do not change no matter how high I set the regularization parameter (in ops.py). A little digging in the TensorFlow documentation shows that, for regularization to take effect, you have to add the following line to build_model() in class ResNet:

self.train_loss+=tf.losses.get_regularization_loss()
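The reasoning in that issue generalizes: a weight penalty only influences training if it is part of the loss the optimizer minimizes. In TensorFlow 1.x, regularizers attached to variables place their terms in a graph collection that tf.losses.get_regularization_loss() sums, but they are not added to a loss you compute yourself. The framework-free sketch below (plain gradient descent on a one-weight least-squares problem, with an L2 penalty of the form scale * w**2 / 2, the same form as tf.nn.l2_loss) shows that the penalty strength has no effect until the term is included. The function and its parameters are hypothetical.

```python
# One-weight least-squares problem y ~ w * x, trained by gradient descent.
# `include_reg` plays the role of adding the regularization loss to train_loss.


def train(data, reg_scale, include_reg, steps=200, lr=0.1):
    w = 0.0
    for _ in range(steps):
        # Gradient of the mean squared error: mean(2 * (w*x - y) * x)
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        if include_reg:
            # Gradient of the L2 penalty reg_scale * w**2 / 2 is reg_scale * w
            grad += reg_scale * w
        w -= lr * grad
    return w


data = [(1.0, 2.0), (2.0, 4.0)]  # fit exactly by w = 2
print(round(train(data, reg_scale=100.0, include_reg=False), 6))  # 2.0
print(round(train(data, reg_scale=1.0, include_reg=True), 6))     # 1.666667
```

Even a very large reg_scale leaves the weight at 2.0 when the penalty is not added, which matches the symptom reported above; once it is included, the weight shrinks toward zero.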

Don't understand shortcut

shortcut = conv(shortcut, channels*4, kernel=1, stride=2, use_bias=use_bias, scope='conv_init')

shortcut = conv(shortcut, channels * 4, kernel=1, stride=1, use_bias=use_bias, scope='conv_init')

Hi Mr. Kim,
I am wondering why you pass the shortcut through a Conv2D layer in the bottleneck block, which does not happen in the normal block. I thought the paper (https://arxiv.org/pdf/1512.03385.pdf) says to pass the input directly to the end of the block and add it to the block output, for both types of block.
Best,
Jiayu Wang
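For reference, the paper linked above writes a block with an identity shortcut as Eqn. (1), and, for the case where the dimensions of x and F differ (for example when the number of channels changes), adds a linear projection W_s on the shortcut as Eqn. (2), which it implements with 1x1 convolutions. In its deeper bottleneck models, the paper uses projections where dimensions increase and identity shortcuts elsewhere, so a projection in some bottleneck blocks does not by itself contradict the paper.

```latex
\begin{aligned}
\mathbf{y} &= \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}
  && \text{(1) identity shortcut: same dimensions} \\
\mathbf{y} &= \mathcal{F}(\mathbf{x}, \{W_i\}) + W_s \mathbf{x}
  && \text{(2) projection shortcut: dimensions differ}
\end{aligned}
```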
