In SR, applying CutBlur is straightforward since we have both LR and HR images.
In the classification problem, however, we only have a single image, which can be treated as the HR image, so we need to generate an LR image from it.
I haven't experimented with this, but below is one possible approach.
import torch
import torch.nn.functional as F

for inputs, labels in data_loader:
    inputs_HQ = inputs.to(device, dtype=torch.float)
    labels = labels.to(device, dtype=torch.float)
    # or you can apply random noise, jittering, etc.
    inputs_LQ = F.interpolate(inputs_HQ, scale_factor=1/4, mode="bilinear")
    inputs = apply_cutblur(inputs_HQ, inputs_LQ)  # just pseudo-code
    outputs = model(inputs)
I'm not sure whether this improves classification accuracy, but it might improve the model's robustness.
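A minimal, self-contained sketch of the idea above, written in NumPy rather than PyTorch so it runs without a GPU or dataset. It is not the repository's cutblur function: make_lr, cut_blur and the ratio parameter are hypothetical names, and block averaging followed by nearest-neighbor repetition stands in for the bilinear F.interpolate calls used in this thread. The point it illustrates is that the simulated LR image is brought back to the HR size before any patch is swapped, because CutBlur copies a region from one image into the same location of the other.

```python
import numpy as np

# Hypothetical helpers for illustration; not part of the cutblur package.


def make_lr(hr: np.ndarray, scale: int) -> np.ndarray:
    """Simulate an LR image at the HR resolution.

    hr has shape (N, C, H, W) with H and W divisible by `scale`.
    Downsample by averaging scale x scale blocks, then upsample back by
    nearest-neighbor repetition so the result has the same shape as hr.
    """
    n, c, h, w = hr.shape
    if h % scale or w % scale:
        raise ValueError("H and W must be divisible by scale")
    blocks = hr.reshape(n, c, h // scale, scale, w // scale, scale)
    lr_small = blocks.mean(axis=(3, 5))
    return lr_small.repeat(scale, axis=2).repeat(scale, axis=3)


def cut_blur(hr: np.ndarray, lr: np.ndarray, ratio: float,
             rng: np.random.Generator) -> np.ndarray:
    """Paste an HR box into the LR image, or an LR box into the HR image.

    `ratio` is the box side length as a fraction of the image side.
    """
    if hr.shape != lr.shape:
        raise ValueError("hr and lr must have the same resolution")
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    h, w = hr.shape[-2:]
    ch, cw = max(1, int(h * ratio)), max(1, int(w * ratio))
    cy = rng.integers(0, h - ch + 1)  # top-left corner of the box
    cx = rng.integers(0, w - cw + 1)
    if rng.random() > 0.5:
        out = lr.copy()  # HR box pasted inside the LR image
        out[..., cy:cy + ch, cx:cx + cw] = hr[..., cy:cy + ch, cx:cx + cw]
    else:
        out = hr.copy()  # LR box pasted inside the HR image
        out[..., cy:cy + ch, cx:cx + cw] = lr[..., cy:cy + ch, cx:cx + cw]
    return out


rng = np.random.default_rng(0)
hr = rng.random((2, 3, 32, 32)).astype(np.float32)
lr = make_lr(hr, scale=4)
mixed = cut_blur(hr, lr, ratio=0.5, rng=rng)
print(mixed.shape)  # (2, 3, 32, 32)
```

The same slicing works on PyTorch tensors of shape (N, C, H, W) if .copy() is replaced with .clone().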
from cutblur.
I tried to use CutBlur in my training loop, but it throws an error. Please help me fix it.
# Define the training loop
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm  # TPU helpers: xm.optimizer_step, xm.master_print
from cutblur.augments import cutblur

# loss_fn and train_data are assumed to be defined elsewhere in the notebook
def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    running_loss = 0.0
    running_corrects = 0
    model.train()
    for inputs, labels in data_loader:
        inputs_HQ = inputs.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        # or you can apply random noise, jittering, etc.
        inputs_LQ = F.interpolate(inputs_HQ, scale_factor=1/4, mode="bilinear")
        inputs = cutblur(inputs_HQ, inputs_LQ)
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = loss_fn(outputs, labels)
        loss.backward()
        xm.optimizer_step(optimizer)
        running_loss += loss.item()
        running_corrects += torch.sum(preds == labels.data)
    train_loss = running_loss / float(len(train_data))
    train_acc = running_corrects.double() / float(len(train_data))
    if scheduler is not None:
        scheduler.step(train_loss)
    xm.master_print('training Loss: {:.4f} & training accuracy : {:.4f}'.format(train_loss, train_acc))
Here's the link to my notebook
Add inputs_LQ = F.interpolate(inputs_LQ, scale_factor=4, mode="bilinear") right after the downsampling line so that the LR and HR images have the same resolution.
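Upsampling with scale_factor=4 restores the original size only when H and W are multiples of 4. The sketch below assumes the output-size rule documented for torch.nn.Upsample, H_out = floor(H_in × scale_factor), and reproduces it in plain Python; interp_out_size and round_trip are hypothetical helper names.

```python
import math


def interp_out_size(size: int, scale_factor: float) -> int:
    """Output length along one axis when resizing with a scale factor
    (assumes the floor rule used when `size` is not given)."""
    return math.floor(size * scale_factor)


def round_trip(size: int, factor: int = 4) -> int:
    """Downsample by 1/factor, then upsample by factor."""
    return interp_out_size(interp_out_size(size, 1 / factor), factor)


for h in (224, 32, 30, 227):
    print(h, "->", round_trip(h))
# 224 -> 224
# 32 -> 32
# 30 -> 28
# 227 -> 224
```

For image sizes that are not multiples of 4, passing size=inputs_HQ.shape[-2:] to F.interpolate instead of scale_factor=4 should make the LR and HR tensors match exactly.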
@nmhkahn Which downsampling line are you referring to? I don't understand. Please help!
I meant it like this:
import torch
import torch.nn.functional as F
from cutblur.augments import cutblur

def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    running_loss = 0.0
    running_corrects = 0
    model.train()
    for inputs, labels in data_loader:
        inputs_HQ = inputs.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        # or you can apply random noise, jittering, etc.
        inputs_LQ = F.interpolate(inputs_HQ, scale_factor=1/4, mode="bilinear")
        inputs_LQ = F.interpolate(inputs_LQ, scale_factor=4, mode="bilinear")  # ADD HERE
        inputs = cutblur(inputs_HQ, inputs_LQ)
        # ... rest of the loop unchanged