
drrn-pytorch's People

Contributors

jt827859032


drrn-pytorch's Issues

Arguments for Training Command

Hello! Could you tell me the argument values you used with the training command:
main.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
[--step STEP] [--cuda] [--resume RESUME]
[--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS]
[--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
[--pretrained PRETRAINED]

to obtain the PSNRs in the following table?

Scale | DRRN_B1U25, paper (PSNR, dB) | DRRN_B1U25, PyTorch (PSNR, dB)
x2    | 37.74                        | 37.69
x3    | 34.03                        | 34.02
x4    | 31.68                        | 31.70

For example, the values of 'nEpochs', 'step', 'batchSize' and so on...
Many thanks in advance!
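Whatever the actual values turn out to be, one way to make them visible is to print every parsed argument at the start of a run. The sketch below rebuilds a parser with exactly the flags in the usage string above; the argument types are assumptions and no defaults are filled in, because the repository's real defaults are not shown here.

```python
import argparse


def build_parser():
    # Flags copied from the usage string; types are assumptions, defaults left unset
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--batchSize", type=int)
    parser.add_argument("--nEpochs", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--step", type=int)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--resume", type=str)
    parser.add_argument("--start-epoch", type=int)  # stored as opt.start_epoch
    parser.add_argument("--clip", type=float)
    parser.add_argument("--threads", type=int)
    parser.add_argument("--momentum", type=float)
    parser.add_argument("--weight-decay", type=float)  # stored as opt.weight_decay
    parser.add_argument("--pretrained", type=str)
    return parser


if __name__ == "__main__":
    opt = build_parser().parse_args()
    # Log every resolved value so the exact configuration of a run can be reported
    for name, value in sorted(vars(opt).items()):
        print(f"{name} = {value}")
```

Printing `vars(opt)` once per run is cheap and makes it straightforward to attach the exact configuration to reported PSNR numbers.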

Cannot reproduce the results reported in the paper

Thanks for open-sourcing this! I ran your code for 50 epochs without changing anything, and at epoch 50 I got the following results:
Processing Set5_mat/head_GT_x2.mat
Processing Set5_mat/butterfly_GT_x2.mat
Processing Set5_mat/baby_GT_x2.mat
Processing Set5_mat/woman_GT_x2.mat
Scale= 2
Dataset= Set5
PSNR_predicted= 35.98152927120101
PSNR_bicubic= 33.69039381292539
It takes average 5.726621294021607s for processing
Processing Set5_mat/woman_GT_x3.mat
Processing Set5_mat/bird_GT_x3.mat
Processing Set5_mat/baby_GT_x3.mat
Processing Set5_mat/head_GT_x3.mat
Processing Set5_mat/butterfly_GT_x3.mat
Scale= 3
Dataset= Set5
PSNR_predicted= 32.590871083509164
PSNR_bicubic= 30.407692343235453
It takes average 5.687644147872925s for processing
Processing Set5_mat/bird_GT_x4.mat
Processing Set5_mat/woman_GT_x4.mat
Processing Set5_mat/butterfly_GT_x4.mat
Processing Set5_mat/head_GT_x4.mat
Processing Set5_mat/baby_GT_x4.mat
Scale= 4
Dataset= Set5
PSNR_predicted= 30.483599758580993
PSNR_bicubic= 28.41454827257395
It takes average 5.374003458023071s for processing

By contrast, I got better results at epoch 20. If you have time, could you please help explain this?
Processing Set5_mat/head_GT_x2.mat
Processing Set5_mat/butterfly_GT_x2.mat
Processing Set5_mat/baby_GT_x2.mat
Processing Set5_mat/woman_GT_x2.mat
Scale= 2
Dataset= Set5
PSNR_predicted= 36.17091792433612
PSNR_bicubic= 33.69039381292539
It takes average 5.751824855804443s for processing
Processing Set5_mat/woman_GT_x3.mat
Processing Set5_mat/bird_GT_x3.mat
Processing Set5_mat/baby_GT_x3.mat
Processing Set5_mat/head_GT_x3.mat
Processing Set5_mat/butterfly_GT_x3.mat
Scale= 3
Dataset= Set5
PSNR_predicted= 32.7447910082183
PSNR_bicubic= 30.407692343235453
It takes average 5.653450155258179s for processing
Processing Set5_mat/bird_GT_x4.mat
Processing Set5_mat/woman_GT_x4.mat
Processing Set5_mat/butterfly_GT_x4.mat
Processing Set5_mat/head_GT_x4.mat
Processing Set5_mat/baby_GT_x4.mat
Scale= 4
Dataset= Set5
PSNR_predicted= 30.44373425939143
PSNR_bicubic= 28.41454827257395
It takes average 5.370740699768066s for processing
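When comparing numbers like PSNR_predicted across epochs, it helps to be sure how PSNR is computed. The sketch below is a common definition for images on a [0, 255] scale, with an optional border crop (shaving `scale` pixels is a frequent super-resolution convention); it is an illustration and is not necessarily identical to the repository's evaluation script.

```python
import math

import numpy as np


def psnr(pred, gt, shave_border=0, peak=255.0):
    """PSNR in dB between two same-shaped 2-D arrays on a [0, peak] scale."""
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if shave_border > 0:
        # Ignore a border of `shave_border` pixels on every side
        pred = pred[shave_border:-shave_border, shave_border:-shave_border]
        gt = gt[shave_border:-shave_border, shave_border:-shave_border]
    mse = np.mean((pred - gt) ** 2)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(peak ** 2 / mse)
```

An error of exactly 1 gray level at every pixel gives 20·log10(255) ≈ 48.13 dB, so differences of 0.1 to 0.2 dB, as between the epoch-20 and epoch-50 runs above, correspond to small changes in mean squared error.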

training problem

Thank you for your wonderful work!
I want to train a model on my own dataset, but something goes wrong during the process.
The error is shown below:

===> Loading datasets
===> Building model
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/nn/_reduction.py:49: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
===> Setting GPU
===> load model model/model_epoch_28.pth
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
File "main.py", line 128, in <module>
main()
File "main.py", line 69, in main
model.load_state_dict(weights['model'].state_dict())
File "/home/luomeilu/anaconda3/envs/py2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DRRN:
Missing key(s) in state_dict: "input.weight", "conv1.weight", "conv2.weight", "output.weight".
Unexpected key(s) in state_dict: "module.input.weight", "module.conv1.weight", "module.conv2.weight", "module.output.weight".

It seems that some parameters in the model are missing? I don't understand the error and hope you can give me some suggestions. I sincerely appreciate your reply.
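The two key lists in the error differ only by a leading `module.`. `nn.DataParallel` stores the wrapped network as its `module` attribute, so a state dict taken from a DataParallel object carries that prefix; the SourceChangeWarning lines also show that the checkpoint pickled a `DataParallel` object. A minimal sketch that strips the prefix before loading (the helper name is hypothetical):

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from every key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # "module.conv1.weight" -> "conv1.weight"
        cleaned[key] = value
    return cleaned
```

Usage in the failing line would look like `model.load_state_dict(strip_module_prefix(weights['model'].state_dict()))`. Since the saved object appears to be a `DataParallel`, `weights['model'].module.state_dict()` should give the unprefixed keys as well; alternatively, wrap `model` in `nn.DataParallel` before calling `load_state_dict`.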

Not about this work

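Hello,
Sorry to ask an off-topic question here: have you also read the paper "A Deep-Reinforcement Learning Approach for Software-Defined Networking Routing Optimization"? Did you manage to reproduce its results? When I train, the reward stays essentially unchanged however long I train, as if nothing were being learned at all, and the same happens at test time. I don't know whether I did something wrong; any advice would be much appreciated. Thanks!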

It does not work on windows 10

Hi, I have a problem running this code because multiprocessing works differently on Windows. As far as I know, I need to add the statement if __name__ == "__main__":
but I don't know where exactly in the code I should add it.

Thanks a lot in advance.
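On Windows, `multiprocessing` starts child processes with the spawn method, which re-imports the main module in every child, so any code that launches worker processes must only run under the guard. In main.py the workers are presumably the DataLoader workers controlled by `--threads`; the traceback in the "training problem" issue shows main.py already defines `main()` and calls it at module level (line 128), so that call is what the guard should wrap. A minimal, self-contained sketch of the pattern (`square` stands in for real work):

```python
import multiprocessing as mp


def square(x):
    return x * x


def main(values=range(4)):
    # Child processes start here; under spawn each child re-imports this module
    with mp.Pool(processes=2) as pool:
        return pool.map(square, values)


if __name__ == "__main__":
    # Only the parent process executes this block; re-imported children skip it
    print(main())
```

If changing the file is not an option, running with `--threads 0` (assuming it maps to the DataLoader's `num_workers`) avoids spawning worker processes altogether, at the cost of slower data loading.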

h5py objects cannot be pickled

Hello, after generating the .h5 file in MATLAB, I ran main.py for training and got the error TypeError: h5py objects cannot be pickled. What could be the reason?
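h5py raises this error when an object holding an open h5py File or Dataset has to be pickled, which happens when a DataLoader with more than zero workers sends the dataset to worker processes under the spawn start method (the default on Windows). Two common fixes are setting the number of workers to 0, or storing only the file path and opening the file lazily inside each process. A minimal sketch of the second approach; the class name and the `data`/`label` keys are assumptions to check against the MATLAB script that wrote the file:

```python
import pickle

import h5py
import numpy as np


class LazyH5Dataset:
    """Map-style dataset (usable with torch DataLoader) that opens HDF5 lazily per process."""

    def __init__(self, path, data_key="data", label_key="label"):
        self.path = path
        self.data_key = data_key
        self.label_key = label_key
        self._file = None  # opened on first access, never pickled
        with h5py.File(path, "r") as hf:
            self._length = hf[data_key].shape[0]

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        if self._file is None:
            self._file = h5py.File(self.path, "r")
        data = self._file[self.data_key][index].astype(np.float32)
        label = self._file[self.label_key][index].astype(np.float32)
        return data, label

    def __getstate__(self):
        # Drop the open h5py handle so the dataset can be sent to worker processes
        state = self.__dict__.copy()
        state["_file"] = None
        return state
```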

Debugging problem

Can you teach me how to debug this? It is for my graduation project, and after a long time of debugging there are still many bugs. Thank you.

How to test with my own dataset

I want to test with my own dataset. First, I need to convert the images to .mat files, such as a_x2.mat and a_x4.mat.
How were the files in the project, for example baby_GT_x2.mat, generated?

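How baby_GT_x2.mat was produced is not documented here. The sketch below writes a comparable file under stated assumptions: the evaluation script reads a ground-truth luma channel and a bicubic-upscaled luma channel, stored as doubles in [0, 255] under the hypothetical keys `im_gt_y` and `im_b_y` (check which keys eval.py passes to `scipy.io.loadmat`). Luma uses MATLAB's `rgb2ycbcr` coefficients; Pillow's bicubic resize differs from MATLAB's `imresize`, so PSNRs will not exactly match files generated in MATLAB.

```python
import numpy as np
import scipy.io as sio
from PIL import Image


def rgb_to_y(rgb):
    """Luma as in MATLAB rgb2ycbcr, for 0-255 RGB input; float64 result in [16, 235]."""
    rgb = rgb.astype(np.float64)
    return 16.0 + (65.481 * rgb[..., 0] + 128.553 * rgb[..., 1] + 24.966 * rgb[..., 2]) / 255.0


def modcrop(img, scale):
    # Crop so that height and width are divisible by the scale factor
    h, w = img.shape[:2]
    return img[: h - h % scale, : w - w % scale]


def make_eval_mat(rgb, scale, out_path):
    """Write ground-truth Y and bicubic-degraded Y (hypothetical keys) to a .mat file."""
    gt_y = rgb_to_y(modcrop(rgb, scale))
    h, w = gt_y.shape
    y_img = Image.fromarray(gt_y.astype(np.float32))  # 2-D float32 -> mode "F"
    small = y_img.resize((w // scale, h // scale), Image.BICUBIC)
    im_b_y = np.asarray(small.resize((w, h), Image.BICUBIC), dtype=np.float64)
    sio.savemat(out_path, {"im_gt_y": gt_y, "im_b_y": im_b_y})
    return gt_y, im_b_y
```

Typical use would be `make_eval_mat(np.asarray(Image.open("a.png").convert("RGB")), 2, "a_GT_x2.mat")`, repeated for each scale.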

How to evaluate on my own dataset?

The demo uses files named ***.mat, each containing several versions of an image.
But how can I evaluate directly on my own picture?
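To run a trained model directly on an ordinary image, one approach is to convert to YCbCr, bicubic-upscale, pass only the Y channel through the network and reuse the bicubic chroma. A minimal sketch; it assumes the model takes and returns a `(1, 1, H, W)` Y tensor scaled to [0, 1], which should be checked against the evaluation script. Pillow's YCbCr conversion uses a full-range (JPEG-style) formula rather than MATLAB's `rgb2ycbcr`, so the result is for visual inspection, not for comparison with the published PSNRs.

```python
import numpy as np
import torch
from PIL import Image


def upscale_image(model, lr_img, scale, device="cpu"):
    """Super-resolve a PIL image: model on Y, bicubic Cb/Cr, returns an RGB PIL image."""
    ycbcr = lr_img.convert("YCbCr")
    w, h = ycbcr.size
    y, cb, cr = ycbcr.resize((w * scale, h * scale), Image.BICUBIC).split()

    # Assumed input convention: Y scaled to [0, 1], shape (1, 1, H, W)
    y_in = torch.from_numpy(np.asarray(y, dtype=np.float32) / 255.0)
    y_in = y_in.unsqueeze(0).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():  # inference only, no activations kept for backprop
        y_out = model(y_in)

    y_out = y_out.squeeze().cpu().numpy().clip(0.0, 1.0) * 255.0
    y_sr = Image.fromarray(np.round(y_out).astype(np.uint8))  # mode "L"
    return Image.merge("YCbCr", (y_sr, cb, cr)).convert("RGB")
```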

How does your implementation share the weights?

Hi there,
I can't find any code showing that the parameters are shared. Maybe that's because I don't understand how to use PyTorch's weight sharing? Can you help me?
Thanks.

from math import sqrt

import torch
import torch.nn as nn


class DRRN(nn.Module):
    def __init__(self):
        super(DRRN, self).__init__()
        self.input = nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # weights initialization: normal with std sqrt(2 / n), n = k * k * out_channels
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))

    def forward(self, x):
        residual = x  # global residual: the network predicts a correction to x
        inputs = self.input(self.relu(x))
        out = inputs
        # the same conv1/conv2 modules are applied in all 25 recursions,
        # so their weights are shared across recursions
        # (with inplace=True, the first relu(out) also modifies `inputs`, since out is inputs then)
        for _ in range(25):
            out = self.conv2(self.relu(self.conv1(self.relu(out))))
            out = torch.add(out, inputs)  # local residual back to the block input

        out = self.output(self.relu(out))
        out = torch.add(out, residual)
        return out
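In PyTorch, weight sharing needs no special function: calling the same module object several times in `forward` reuses its parameters, and during backprop the gradients from every call accumulate into the same `.weight.grad`. The sketch below isolates the recursive block pattern from the code above (the class and helper names are hypothetical) and shows that the parameter count does not depend on the number of recursions.

```python
import torch
import torch.nn as nn


class SharedRecursiveBlock(nn.Module):
    """Applies the same two conv layers `num_recursions` times, sharing their weights."""

    def __init__(self, channels=128, num_recursions=25):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.num_recursions = num_recursions

    def forward(self, inputs):
        out = inputs
        for _ in range(self.num_recursions):
            # Same module objects on every pass -> same parameters
            out = self.conv2(self.relu(self.conv1(self.relu(out))))
            out = out + inputs
        return out


def count_params(module):
    return sum(p.numel() for p in module.parameters())


# One recursion or 25: identical parameter count (2 * 128 * 128 * 3 * 3 = 294912)
print(count_params(SharedRecursiveBlock(num_recursions=1)))
print(count_params(SharedRecursiveBlock(num_recursions=25)))
```

For the full model above, the four conv layers give 1152 + 147456 + 147456 + 1152 = 297216 weights, regardless of the 25 recursions.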

Cuda Runtime Error: out of memory

When I was testing the model on Set14, the GPU ran out of memory. How can I solve this problem? My GPU is a 1080 Ti.
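Two things usually help. First, run inference under `torch.no_grad()`: older scripts used `Variable(..., volatile=True)`, which PyTorch 0.4 and later ignore, so activations from all 25 recursions would be kept for backprop (whether this repository's eval script does that is an assumption worth checking). Second, if one image still does not fit, process it in overlapping tiles. A minimal sketch for a fully convolutional model whose output has the input's shape; with the DRRN above (52 stacked 3×3 convs) each output pixel depends only on inputs within 52 pixels, so `pad=52` should reproduce the full-image output up to floating-point differences. Tile size is an arbitrary choice.

```python
import torch


def tiled_forward(model, x, tile=128, pad=52):
    """Run a fully convolutional model on x (N, C, H, W) tile by tile to limit memory."""
    _, _, h, w = x.shape
    out = torch.zeros_like(x)  # assumes output shape == input shape
    with torch.no_grad():
        for top in range(0, h, tile):
            for left in range(0, w, tile):
                bottom, right = min(top + tile, h), min(left + tile, w)
                # Extend the tile by `pad` pixels of context, clipped to the image
                t0, l0 = max(top - pad, 0), max(left - pad, 0)
                b0, r0 = min(bottom + pad, h), min(right + pad, w)
                patch_out = model(x[:, :, t0:b0, l0:r0])
                # Keep only the central region, discarding the context margin
                out[:, :, top:bottom, left:right] = patch_out[
                    :, :, top - t0 : bottom - t0, left - l0 : right - l0
                ]
    return out
```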

Why does Set_5.mat use double values as ground truth?

Hi, I wonder how you generated Set_5.mat. The ground truth is stored as double values that are not rounded to integers; if I round them to integers, I get a different PSNR. However, I found that if I apply rgb2ycbcr directly in MATLAB, it gives integer values.
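For reference, MATLAB's `rgb2ycbcr` computes luma as below (with R, G, B scaled to [0, 1]) and returns output of the same class as its input: uint8 input gives rounded uint8 output, while double input gives unrounded double output. Ground truth stored as non-integer doubles would therefore be consistent with converting the image to double before calling `rgb2ycbcr`, although how Set_5.mat was actually generated is not stated. A pure red pixel shows the difference:

```latex
\begin{aligned}
Y &= 16 + 65.481\,R + 128.553\,G + 24.966\,B, \qquad R, G, B \in [0, 1] \\
(R, G, B) = (1, 0, 0) &\implies Y = 16 + 65.481 = 81.481 \ \text{(double)}, \qquad \operatorname{round}(81.481) = 81 \ \text{(uint8)}
\end{aligned}
```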
