
semantic-segmentation-level2-cv-18's Introduction

0. Quick Start (Get our result)

  1. Set up a virtual environment and install the requirements (recommended)

    $ conda create -n segmentation python=3.7.11
    $ conda activate segmentation
    $ pip install -r requirements
  2. Execute the following command to reproduce our result.

    $ sh run.sh

1. Prepare

1.1. Get the trash dataset.

$ wget https://aistages-prod-server-public.s3.amazonaws.com/app/Competitions/000078/data/data.zip
$ mv data.zip rawdata.zip
$ unzip rawdata.zip
$ rm rawdata.zip

Dataset copyright and license: Naver Connect, CC-BY-2.0
For more details about this dataset, see here.

1.2. Get Some Libraries.

  1. pytorch-toolbelt (for various loss functions)

    $ pip install pytorch_toolbelt

    For more detail, see here.

  2. MADGRAD (An optimizer)

    $ pip install madgrad

    For more detail, see here.

1.3. Completed Structure

./
├─rawdata/
|
├─config/
|    ├─fix_seed.py
|    ├─read_config.py
|    ├─wnb.py
|    └─default.yaml
├─data/
|    ├─dataset.py
|    ├─dataloader.py
|    └─augmentation.py
├─submission/
|    └─sample_submission.csv
├─util/
|    ├─eda.py
|    ├─ploting.py
|    ├─tta.py
|    └─utils.py
|
├─train.py
├─inference.py
├─pipeline.py
|
└─run.sh

2. Pipeline

  1. One Step Execution - From train to inference.

    $ python pipeline.py --cfg-yaml ./config/default.yaml
  2. Two Step Execution

    • Train Step

      $ python train.py --cfg-yaml ./config/default.yaml
    • Inference Step

      $ python inference.py --cfg-yaml ./config/default.yaml
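All three scripts take the same `--cfg-yaml` flag. As a rough sketch of how such a flag can be read with Python's standard `argparse` module (the function name `parse_args`, the default path, and `my_experiment.yaml` are assumptions for illustration, not taken from the repository):

```python
import argparse


def parse_args(argv=None):
    """Parse the command line; `argv=None` means "use sys.argv"."""
    parser = argparse.ArgumentParser(
        description="Run training and/or inference from a YAML config."
    )
    # argparse exposes "--cfg-yaml" as the attribute `cfg_yaml`.
    parser.add_argument(
        "--cfg-yaml",
        type=str,
        default="./config/default.yaml",  # assumed default, for illustration
        help="Path to the YAML configuration file.",
    )
    return parser.parse_args(argv)


# Hypothetical file name for an edited copy of the default config.
args = parse_args(["--cfg-yaml", "./config/my_experiment.yaml"])
print(args.cfg_yaml)
```

Passing the path explicitly keeps each experiment tied to its own configuration file.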

3. Configurations

3.1. Configuration file usage

  1. You can see all configurations in the YAML file (./config/default.yaml).

  2. Copy the YAML file and edit the copy.

  3. Then run with the edited file to get the corresponding result.

3.2. About configurations

  1. Listing the supported frameworks

    FRAMEWORKS_AVAILABLE: ["torchvision", "segmentation_models_pytorch"]

    Only these two frameworks are supported.

  2. Listing the supported models (including encoders and decoders)

    MODELS_AVAILABLE:
        torchvision: ["fcn_resnet50", ... , "lraspp_mobilenet_v3_large"]
                      
    DECODER_AVAILABLE: ["unet", ... , "pan"] # smp decoder
    
    ENCODER_AVAILABLE: ['resnet18', ... , 'timm-gernet_l'] # smp encoder

    smp : segmentation_models_pytorch

  3. Listing the supported criteria

    CRITERION_AVAILABLE:
        # available_framework: [available criteria]
        torch.nn: ["CrossEntropy"]
        pytorch_toolbelt: ["BalancedBCEWithLogitsLoss", ... , "WingLoss"]
  4. Listing the supported K-Fold types.

    KFOLD_TYPE_AVAILABLE: ["KFold", "MultilabelStratifiedKFold"]

    For K-Fold reference, see here.

    For MLSK-Fold reference, see here.

  5. Model Selection

    • Torchvision

      SELECTED:
          # 1. IF you use torchvision model
          FRAMEWORK: "torchvision"
          MODEL: "lraspp_mobilenet_v3_large" # also used for submission save.
          MODEL_CFG:
              pretrained: True
    • smp

      SELECTED:
          # 2. IF you use smp model
          # MODEL_CFG is passed as smp.create_model(**cfg["SELECTED"]["MODEL_CFG"]),
          # so keys under MODEL_CFG should be lowercase. (PRETRAINED -> pretrained)
          FRAMEWORK: "segmentation_models_pytorch"
          MODEL_CFG:
              arch: "fpn"                          # DECODER
              encoder_name: "timm-efficientnet-b6" # ENCODER (https://smp.readthedocs.io/en/latest/encoders.html) 
              encoder_weights: "noisy-student"     # available pretraining datasets differ per ENCODER ("imagenet", "advprop", "noisy-student", etc.)
              in_channels: 3 # fixed
              classes: 11    # fixed
  6. Criterion Selection

    SELECTED:
      # ...
    
      CRITERION:
        FRAMEWORK: "pytorch_toolbelt"
        USE: "SoftCrossEntropyLoss"
        CFG:
  7. Experiment configurations

    • seed, epochs, batch size, learning rate, the number of workers, validation period config

      EXPERIMENTS:
          SEED: 21
          NUM_EPOCHS: 30
          BATCH_SIZE: 16
          LEARNING_RATE: 1e-4
          NUM_WORKERS: 4
          VAL_EVERY: 5
          
          # ...
    • K-Fold config

      EXPERIMENTS:
          # ...
          
          KFOLD:
              TURN_ON: True
              TYPE: "MultilabelStratifiedKFold"
              NUM_FOLD: 5
          
          # ...
    • Autocast

      EXPERIMENTS:
          # ...
          
          AUTOCAST_TURN_ON: True
          
          # ...
    • wandb config

      EXPERIMENTS:
          # ...
          
          WNB:
              TURN_ON: True
              INIT:
                  entity: "ai_tech_level2-cv-18"
                  project: "seunghun_T2042"
                  name: "fpn_timm-efficientnet-b6" # changing this is recommended when WNB is turned on.
           
           # ...
    • Configure directories for best performance model saving and submission file saving

      EXPERIMENTS:
          # ...
          
          SAVED_DIR: 
              BEST_MODEL: "./saved"
              SUBMISSION: "./submission"
              
          # ...
    • Configure train transforms, which are combined with A.OneOf.

      EXPERIMENTS:
          # ...
          
          TRAIN_TRANS: # ToTensorV2 is included by default; list Albumentations augmentations here
              GridDistortion: 
                  p: 1.0
              RandomGridShuffle:
                  p: 1.0
              RandomResizedCrop:
                  height: 512
                  width: 512
                  p: 1.0
              HorizontalFlip:
                  p: 1.0
              VerticalFlip:
                  p: 1.0
              GridDropout:
                  p: 1.0
              ElasticTransform:
                  p: 1.0
                  
          # ...
    • TTA config

      EXPERIMENTS:
         # ...
          
          TTA:
              TURN_ON: True
              AVAILABLE_LIST: # only the two TTAs below are supported.
              VERTICAL_FLIP_TURN_ON: True
              HORIZONTAL_FLIP_TURN_ON: True

      Only vertical flip and horizontal flip are supported.

      (Each flip is its own inverse, so applying the same flip again maps the prediction back to the original orientation.)

  8. Dataset configurations

    DATASET:
        PATH: "./rawdata" # Config dataset root
        ANNS_FILE_NAME: "train_all.json"
        TRAIN_FILE_NAME: "train_all.json"
        VAL_FILE_NAME: "val.json" # not used when KFOLD TURN_ON is True.
        TEST_FILE_NAME: "test.json"
        NUM_CLASSES: 11
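A typo in SELECTED usually surfaces only once training starts, so it can help to check the selection against the `*_AVAILABLE` lists first. The following is a minimal sketch over a plain dict that mirrors the keys shown above; `validate_config` is a hypothetical helper, not part of the repository.

```python
# Hypothetical excerpt of a loaded config, mirroring the keys shown above.
cfg = {
    "FRAMEWORKS_AVAILABLE": ["torchvision", "segmentation_models_pytorch"],
    "KFOLD_TYPE_AVAILABLE": ["KFold", "MultilabelStratifiedKFold"],
    "SELECTED": {
        "FRAMEWORK": "segmentation_models_pytorch",
        "MODEL_CFG": {
            "arch": "fpn",
            "encoder_name": "timm-efficientnet-b6",
            "encoder_weights": "noisy-student",
            "in_channels": 3,
            "classes": 11,
        },
    },
    "EXPERIMENTS": {
        "KFOLD": {"TURN_ON": True, "TYPE": "MultilabelStratifiedKFold", "NUM_FOLD": 5},
    },
}


def validate_config(cfg):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    selected = cfg["SELECTED"]
    if selected["FRAMEWORK"] not in cfg["FRAMEWORKS_AVAILABLE"]:
        problems.append(f"unsupported FRAMEWORK: {selected['FRAMEWORK']}")
    if selected["FRAMEWORK"] == "segmentation_models_pytorch":
        # MODEL_CFG is unpacked as keyword arguments, so keys must be lowercase.
        bad_keys = [key for key in selected["MODEL_CFG"] if not key.islower()]
        if bad_keys:
            problems.append(f"MODEL_CFG keys should be lowercase: {bad_keys}")
    kfold = cfg["EXPERIMENTS"]["KFOLD"]
    if kfold["TURN_ON"] and kfold["TYPE"] not in cfg["KFOLD_TYPE_AVAILABLE"]:
        problems.append(f"unsupported KFOLD TYPE: {kfold['TYPE']}")
    return problems


print(validate_config(cfg))
```

The lowercase check reflects the note in the smp example: a key such as `PRETRAINED` would not match any keyword argument when MODEL_CFG is unpacked with `**`.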

4. Our Experiments

4.1. Model

  • Encoder : timm-efficientnet-b7
    • Weight : noisy-student
  • Decoder : FPN
    • In channel : 3
    • Classes : 11
  • Fold : KFold, MultilabelStratifiedKFold
    • Number of fold : 5
  • Learning rate : 2e-4
  • TTA : Horizontal flip

4.2. Loss

SoftCrossEntropyLoss in pytorch_toolbelt

SoftCE > CE > DiceCE > Dice

4.3. Optimizer

MADGRAD offers generalization performance comparable to SGD together with convergence as fast as Adam.

MADGRAD > Adam

4.4. Learning rate Scheduler

CosineAnnealingWarmRestarts in torch.optim.lr_scheduler
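For intuition, this scheduler follows a cosine rule with restarts: within each cycle the learning rate decays from its base value toward a minimum, then jumps back to the base value when the next cycle starts. The sketch below reproduces that rule in plain Python for integer epochs. `t_0`, `t_mult`, and `eta_min` are placeholder values (this README does not state the ones used); only the base learning rate 2e-4 comes from section 4.1.

```python
import math


def cosine_warm_restart_lr(epoch, base_lr=2e-4, eta_min=0.0, t_0=10, t_mult=1):
    """Learning rate at an integer epoch under cosine annealing with restarts."""
    t_i = t_0        # length of the current cycle
    t_cur = epoch    # epochs elapsed since the last restart
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult  # each new cycle is t_mult times longer
    cosine = math.cos(math.pi * t_cur / t_i)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


for epoch in (0, 5, 9, 10):
    print(epoch, cosine_warm_restart_lr(epoch))
```

With these placeholder settings the rate is back at its base value at epoch 10, which is the "warm restart".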

4.5. Scaler

Autocast and GradScaler were used to shorten training time.

4.6. Augmentations

Using a lightweight model, we could run various augmentation experiments quickly.

Selected model : LRASPP mobilenetv3 Large in torchvision - for more detail, see here

  1. Hyperparameter(Epochs) Tuning for the LRASPP mobilenet v3 large model.

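    | Epochs    | mIoU  | mIoU change (vs. 6 epochs) |
    |-----------|-------|----------------------------|
    | 6 epochs  | 0.510 | 0.0                        |
    | 12 epochs | 0.553 | +0.042                     |
    | 24 epochs | 0.571 | +0.061                     |

    To keep experiments fast, we did not try 48 epochs.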

  2. Single Augmentation Observation

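    | Augmentation (fixed at 24 epochs) | mIoU  | mIoU change (vs. None) |
    |-----------------------------------|-------|------------------------|
    | None                              | 0.571 | 0.0                    |
    | Blur                              | 0.572 | +0.001                 |
    | GridDistortion                    | 0.583 | +0.012                 |
    | RandomGridShuffle                 | 0.585 | +0.014                 |
    | GridDropout                       | 0.587 | +0.016                 |
    | ElasticTransform                  | 0.598 | +0.027                 |
    | RandomResizedCrop                 | 0.619 | +0.048                 |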
  3. Compound augmentation test using albumentations.core.composition.OneOf (see here)

    • Use 5 augmentations

      • GridDistortion, RandomGridShuffle, GridDropout, ElasticTransform, RandomResizedCrop
    • Result

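      | Epochs (fixed augmentation) | mIoU  | mIoU change (vs. no augmentation, 24 epochs) |
      |-----------------------------|-------|----------------------------------------------|
      | 24 epochs                   | 0.609 | +0.038                                       |
      | 48 epochs                   | 0.631 | +0.060                                       |
      | 96 epochs                   | 0.653 | +0.082                                       |

      mIoU increases as the number of epochs increases.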
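Albumentations' `OneOf` applies exactly one of its child transforms per call and normalizes the children's `p` values into selection weights. The plain-Python sketch below imitates only that selection step for the five augmentations listed above (it does not transform any image). `pick_one` is a hypothetical helper, and the `p` values of 1.0 mirror the TRAIN_TRANS example in the configuration section.

```python
import random

# Augmentation name -> p, as in the TRAIN_TRANS example (all 1.0).
train_trans = {
    "GridDistortion": 1.0,
    "RandomGridShuffle": 1.0,
    "GridDropout": 1.0,
    "ElasticTransform": 1.0,
    "RandomResizedCrop": 1.0,
}


def pick_one(transforms, rng):
    """Choose a single transform name, using the p values as weights."""
    names = list(transforms)
    weights = [transforms[name] for name in names]
    return rng.choices(names, weights=weights, k=1)[0]


rng = random.Random(21)  # SEED: 21 from the example config
counts = {name: 0 for name in train_trans}
for _ in range(5000):
    counts[pick_one(train_trans, rng)] += 1

for name, count in counts.items():
    print(f"{name}: {count}")
```

Because every `p` is equal, each augmentation is picked roughly one time in five, so a given image sees a different single augmentation from epoch to epoch. That is consistent with the observation that longer training keeps helping.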

4.7. K-Fold Ensemble
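The configuration section sets `NUM_FOLD: 5`. As a minimal sketch of the idea (not the repository's implementation), the code below splits sample indices into five shuffled folds in the style of plain KFold and shows one common way to ensemble fold models, averaging their per-class scores. Both helper names are hypothetical, and averaging is an assumption about how the folds might be combined.

```python
import random


def kfold_indices(num_samples, num_fold=5, seed=21):
    """Yield (train_indices, val_indices) for each fold of a shuffled K-fold."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices into num_fold nearly equal folds.
    folds = [indices[i::num_fold] for i in range(num_fold)]
    for i, val_idx in enumerate(folds):
        train_idx = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train_idx, val_idx


def average_scores(per_fold_scores):
    """Element-wise mean of the score vectors produced by each fold's model."""
    num_models = len(per_fold_scores)
    return [sum(column) / num_models for column in zip(*per_fold_scores)]


for fold, (train_idx, val_idx) in enumerate(kfold_indices(23)):
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} val")
```

Each sample lands in exactly one validation fold, so every fold's model is validated on data it never trained on.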

4.8. TTA

We tried to use the ttach library but could not get it to work, so we apply only flip TTA, where each augmentation is its own inverse.

To add other augmentations, add their functions to ./util/tta.py, then update ./config/default.yaml and the get_tta_list function in ./util/tta.py.
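The reason flips are convenient for TTA is that the same operation both augments the input and restores the prediction. The sketch below shows this on plain nested lists with a toy single-channel "model"; it is an illustration under those assumptions, not the code in ./util/tta.py, and all function names are hypothetical.

```python
def hflip(grid):
    """Mirror a 2-D grid left to right."""
    return [row[::-1] for row in grid]


def vflip(grid):
    """Mirror a 2-D grid top to bottom."""
    return [row[:] for row in grid[::-1]]


def flip_tta(predict, image, flips):
    """Average per-pixel scores over the original view and each flipped view."""
    outputs = [predict(image)]
    for flip in flips:
        # Flip the input, predict, then apply the same flip to undo it.
        outputs.append(flip(predict(flip(image))))
    height, width = len(image), len(image[0])
    return [
        [sum(out[r][c] for out in outputs) / len(outputs) for c in range(width)]
        for r in range(height)
    ]


def toy_predict(image):
    """Toy stand-in for a model: the score also depends on the column index."""
    return [[value + 0.5 * c for c, value in enumerate(row)] for row in image]


image = [[1, 2, 3], [4, 5, 6]]
# Applying a flip twice returns the original grid.
assert hflip(hflip(image)) == image and vflip(vflip(image)) == image
averaged = flip_tta(toy_predict, image, [hflip, vflip])
print(averaged)
```

An augmentation without this property (for example a random crop) would need a separate inverse function before its prediction could be averaged with the others.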

4.9. Pseudo labeling

We convert the resulting CSV file into a COCO-format dataset to apply pseudo labeling.

To change the paths, edit this part of the code.

    # config
    cfg = {
        "csv_file_path" : "", # csv file you want to convert
        "test_file_path" : "", # test_json path
        "result_file_path" : "", # json file you want to save result
        "maxWidth" : 256, # test image width
        "maxHeight" : 256, # test image height
    }

This module is available at ./util/pseudo.py.
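As an illustration of the kind of conversion involved, the sketch below decodes one flattened prediction into a `maxHeight` x `maxWidth` mask and derives COCO-style `[x, y, width, height]` boxes per class. It assumes each CSV row stores space-separated class indices in row-major order and that class 0 is background. Both are assumptions, and the helper names are hypothetical rather than the contents of ./util/pseudo.py.

```python
def decode_mask(prediction_string, width=256, height=256):
    """Turn a flat, space-separated label string into a height x width grid."""
    values = [int(token) for token in prediction_string.split()]
    if len(values) != width * height:
        raise ValueError(f"expected {width * height} labels, got {len(values)}")
    return [values[row * width:(row + 1) * width] for row in range(height)]


def class_bboxes(mask, background=0):
    """Return {class_id: [x, y, width, height]} for every non-background class."""
    extents = {}  # class_id -> [x_min, y_min, x_max, y_max]
    for y, row in enumerate(mask):
        for x, class_id in enumerate(row):
            if class_id == background:
                continue
            box = extents.setdefault(class_id, [x, y, x, y])
            box[0] = min(box[0], x)
            box[1] = min(box[1], y)
            box[2] = max(box[2], x)
            box[3] = max(box[3], y)
    return {
        class_id: [x0, y0, x1 - x0 + 1, y1 - y0 + 1]
        for class_id, (x0, y0, x1, y1) in extents.items()
    }


# Tiny 3 x 2 example instead of a full 256 x 256 prediction.
mask = decode_mask("0 1 1 0 1 0", width=3, height=2)
print(class_bboxes(mask))  # {1: [1, 0, 2, 2]}
```

A full COCO annotation would also need a segmentation (polygon or RLE) and an area per object, which this sketch leaves out.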

5. Result

5.1. Leader Board in Competition

| Leaderboard | mIoU  |
|-------------|-------|
| Public LB   | 0.781 |
| Private LB  | 0.717 |

5.2. Images after model inference

[Images: image2, image3]

6. Participants

| Name           | Github | Role                                                       |
|----------------|--------|------------------------------------------------------------|
| 김서기 (T2035) | Link   | Research(HRNet, MMSeg library), Pseudo Labeling, TTA       |
| 김승훈 (T2042) | Link   | Find Augmentations, Code Refactoring                       |
| 배민한 (T2260) | Link   | Research(smp library, loss), Model Enhancement, K-Fold     |
| 손지아 (T2113) | Link   | Research(smp library, loss), Model Enhancement, MLSK-Fold  |
| 이상은 (T2157) | Link   | Research(HRNet, optimizer, loss), Pseudo Labeling, Augmix  |
| 조익수 (T2213) | Link   | Research(MMseg library)                                    |
