
Comments (5)

mahmoudnafifi commented on July 4, 2024

The generator should work with 512 images without changing its code. You need only to change the code in main_training.m script (check this switch statement)

from exposure_correction.

hermosayhl commented on July 4, 2024

The generator should work with 512 images without changing its code. You need only to change the code in main_training.m script (check this switch statement)

Thanks for your reply!

First, I encountered this error:

Preparing training data ...
Creating the generator model ...
Error using dlnetwork (line 81)
Invalid network.

Error in create_generator (line 85)
net  = dlnetwork(net);

Error in main_training (line 132)
        net = create_generator(patchSize, encoderDecoderDepth, chnls, convvfilter);

Caused by:
    Layer 'level_4-Encoder-Stage-4-MaxPool': Input size mismatch. Size of input to this layer is different from the expected
    input size.
    Inputs to this layer:
        from layer 'level_4-Encoder-Stage-4-L-ReLU-2' (1×1×192 output)
 

I found that the layer "level_4-Encoder-Stage-4-L-ReLU-2" is defined at lines 29-30 of src/addSubNet.m, but I don't know how to modify the code.

I changed the code as follows:

%% training code
% Author: Mahmoud Afifi
% Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
% Please cite our paper:
% Mahmoud Afifi, Konstantinos G. Derpanis, Björn Ommer, and Michael S.
% Brown. Learning Multi-Scale Photo Exposure Correction, In CVPR 2021.
%%

clc
clear;
close all;

lR = 10^-4; % initial learning rate

chnls = 16; % number of channels of the 1st layer of the encoder for the highest pyramid level

convvfilter = 3; % conv kernel size

encoderDecoderDepth = 3; % number of layers (i.e., levels) for the highest pyramid level

trainingImgsNum = 0; %if 0, then load all training images

withDiscriminator = 1; % include discriminator loss term?

for ps = [64, 128, 256] % for each patch size, do
    
    % please, update training/validation directories accordingly
    
    In_Tr_datasetDir = fullfile('exposure_dataset','training',sprintf('INPUT_IMAGES_P_%d',ps)); % input training patches with size ps
    
    GT_Tr_datasetDir = fullfile('exposure_dataset','training',sprintf('GT_IMAGES_P_%d',ps)); % ground truth training patches with size ps
    
    In_Vl_datasetDir = fullfile('exposure_dataset','validation',sprintf('INPUT_IMAGES_P_%d',ps)); % validation
    
    GT_Vl_datasetDir = fullfile('exposure_dataset','validation',sprintf('GT_IMAGES_P_%d',ps));
     
    patchSize = [ps, ps, 12]; % 3 color channels x 4 pyramid levels

    switch ps
        
        case 64
            
            dropRate = 20; % drop learning rate
            
            checkpoint_period = 10; % bkup every checkpoint_period
            
            epochs = 40; % number of epochs
            
            miniBatch = 32; % mini-batch size
            
            chkpoint = ''; % start training from scratch -- no chkpoint
            
             if withDiscriminator == 1
                 
                 chkpoint_d = '';
             
             end
            
            validationImgsNum = 2000; % number of validation patches
            
            vlFreq = 5612 *2; % every vlFreq iterations, do validation
            
        case 128
            
            dropRate = 10;
        
            checkpoint_period = 5;
            
            epochs = 30;
            
            miniBatch = 8;
            
            chkpoint = sprintf('model_%d.mat',ps/2);
            
            if withDiscriminator == 1
                chkpoint_d =  '';
            end
            
            validationImgsNum = 1000;
            
            vlFreq = 13230 *2;
        
        case 256
        
            dropRate = 5;
            
            checkpoint_period = 5;
            
            epochs = 20;
            
            miniBatch = 4;
            
            chkpoint = sprintf('model_%d.mat',ps/2);
            
            if withDiscriminator == 1
                chkpoint_d =  sprintf('D_model_%d.mat',ps/2);
            end
            
            validationImgsNum = 500;
            
            vlFreq = 17378 *2;
        
        otherwise
            
            error('wrong ps value');
    end
    checkpoint_dir = sprintf('%dx%d_reports_and_backup_%s',ps,ps,date);
    
    GPUDevice = 1;
    
  
    modelName = sprintf('model_%d.mat',ps);
    
    if withDiscriminator == 1
        D_modelName = sprintf('D_model_%d.mat',ps);
    end
    
    fprintf('Preparing training data ...\n');
    
    [Trdata,Vldata] = getTr_Vl_data(In_Tr_datasetDir, GT_Tr_datasetDir, ...
        In_Vl_datasetDir, GT_Vl_datasetDir, trainingImgsNum, ...
        validationImgsNum, patchSize(1:2),...
        miniBatch);
    
    options = get_trainingOptions(epochs,miniBatch,lR,...
        checkpoint_dir,Vldata,GPUDevice, checkpoint_period, ...
        vlFreq, dropRate);
    
    
    if strcmp(chkpoint,'')
        fprintf('Creating the generator model ...\n');
        net = create_generator(patchSize, encoderDecoderDepth, chnls, convvfilter);
    else
        fprintf('Loading the generator model ...\n');
        load(chkpoint);
        inLayer = imageInputLayer(patchSize,'Name','InputLayer',...
            'Normalization','none');
        net = layerGraph(net);
        net=replaceLayer(net,'InputLayer',inLayer);
        net = dlnetwork(net);
    end
    
    %define/load the discriminator
    if withDiscriminator == 1
        if strcmp(chkpoint_d,'')
            fprintf('Creating the discriminator model ...\n');
            [D] = createDiscriminator();
        else
            fprintf('Loading the discriminator model ...\n');
            load(chkpoint_d);
        end
    end
    
    
    fprintf('Starting training ...\n');
    
    if withDiscriminator == 1
        switch ps
            case 64
                [net, D] = train_network(Trdata,net,[], options);
            case 128
                [net, D] = train_network(Trdata,net,D, options,15);
            case 256
                [net, D] = train_network(Trdata,net,D, options,5);
        end
    else
        [net, ~] = train_network(Trdata,net,[], options);
    end
    
    disp('Done!');
    
    disp('Saving model!');
    
    save(modelName,'net','-v7.3');
    if withDiscriminator == 1
        save(D_modelName,'D','-v7.3');
    end
end


mahmoudnafifi commented on July 4, 2024

To make sure I correctly understand your issue, do you want to train on the same 64,128,256 but the original images are 512, or you would like to train directly on 512?


hermosayhl commented on July 4, 2024

To make sure I correctly understand your issue, do you want to train on the same 64,128,256 but the original images are 512, or you would like to train directly on 512?

You are right, I want to train on the same 64,128,256 but the original images are 512. In the code, 64 is not supported.


mahmoudnafifi commented on July 4, 2024

Oh right, 64 needs a shallower net than the current architecture. As a rule of thumb, use an image size (m) that is divisible by 2^n (where n is the number of encoder layers), with m/(2^n) >= 1. Keep in mind that input images are first processed by the pyramid decomposition step, which reduces the dimensions by a factor of 2^(k - 1) at the last layer (where k is the number of pyramid layers).
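
The rule of thumb above can be checked with a few lines of MATLAB before building the network. This is only a sketch, not code from the repository: coarsestSize and isValidDepth are hypothetical helper names, k = 4 is taken from the 12-channel patchSize (3 color channels x 4 pyramid levels), and n = 4 is an assumption read off the layer name 'level_4-Encoder-Stage-4-MaxPool' in the error message.

```matlab
% Sketch of the size check described above (hypothetical helpers, not part
% of the exposure_correction repository).
k = 4;  % number of pyramid levels (patchSize = [ps, ps, 12] = 3 channels x 4 levels)
n = 4;  % ASSUMPTION: encoder stages at the coarsest level, inferred from the
        % layer name 'level_4-Encoder-Stage-4-MaxPool' in the error message

% Spatial size of the coarsest pyramid level for a patch of size ps.
coarsestSize = @(ps, k) ps / 2^(k - 1);

% Rule of thumb: m must be divisible by 2^n and m / 2^n must be >= 1.
isValidDepth = @(ps, k, n) mod(coarsestSize(ps, k), 2^n) == 0 && ...
    coarsestSize(ps, k) / 2^n >= 1;

for ps = [64, 128, 256]
    fprintf('ps = %d, coarsest level size = %d, valid = %d\n', ...
        ps, coarsestSize(ps, k), isValidDepth(ps, k, n));
end
```

Under these assumptions the check fails for ps = 64, because the coarsest level is 8x8 and 8 is not divisible by 2^4 = 16. It passes for 128 and 256. With n = 3 the 64 case would pass, which is one way to read "64 needs a shallower net".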

