avivnavon / nash-mtl
Official implementation of "Multi-Task Learning as a Bargaining Game" [ICML 2022]
Thank you for the great work.
When I tried the method, I found that the computed weights are much larger than those of other MTL methods, e.g. [47.24413258, 732.26542343] vs. [0.4, 0.6].
Is this normal? Should I rescale the weights? I think such large loss weights may affect the regularization.
Thank you!
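For readers wondering what rescaling would look like, here is a minimal sketch. The helper name `rescale_weights` and the `target_sum` parameter are hypothetical and not part of the repository. Dividing by the sum keeps the ratio between the task weights and only changes the overall magnitude of the combined loss. Whether this is appropriate for Nash-MTL is not answered in the thread.

```python
def rescale_weights(weights, target_sum=1.0):
    """Scale task weights by a common factor so that they sum to target_sum.

    Hypothetical helper for illustration; the ratios between weights are unchanged.
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    return [w * target_sum / total for w in weights]


# The weights reported in the question above.
rescaled = rescale_weights([47.24413258, 732.26542343])
```

Because a common factor is applied to every task, this has the same effect on the gradient as changing the scale of the combined loss.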
I would like to express my sincere gratitude for the excellent research you have conducted. Your work has significantly helped me in defining the direction of my own research and initiating it. Specifically, I am working on solving problems related to multi-task learning in 3D object detection and BEV segmentation for autonomous driving.
I have a question regarding the repository you provided and have left an issue for discussion. Does your repository include code for torch.distributed to facilitate multi-GPU learning? I am inquiring because using torch.distributed allows averaging gradients across GPUs, which alters the computations. I am facing some challenges in this area and would greatly appreciate your guidance or insights.
Thank you for your support and looking forward to your response.
Junghokim
Hi there. When I train the model, the code sometimes raises the error "ValueError: Parameter value must be real". The call path goes from [self.get_weighted_loss] to [self.solve_optimization(GTG.cpu().detach().numpy())] to [self.G_param.value = gtg] to [self._value = self._validate_value(val)].
Could you help me deal with the error? Thanks a lot!
Traceback (most recent call last):
File "xxx.py", line 435, in <module>
main()
File "xxx.py", line 427, in main
trainer.training(epoch)
File "xxx.py", line 158, in training
loss1, extra_outputs = self.weight_method.backward(
File "methods\weight_methods.py", line 810, in backward
return self.method.backward(losses, **kwargs)
File "methods\weight_methods.py", line 263, in backward
loss, extra_outputs = self.get_weighted_loss(
File "methods\weight_methods.py", line 237, in get_weighted_loss
alpha = self.solve_optimization(GTG.cpu().detach().numpy())
File "methods\weight_methods.py", line 134, in solve_optimization
self.G_param.value = gtg
File "C:\Users\xxx\.conda\envs\xxx\lib\site-packages\cvxpy\expressions\constants\parameter.py", line 87, in value
self._value = self._validate_value(val)
File "C:\Users\xxx\.conda\envs\xxx\lib\site-packages\cvxpy\expressions\leaf.py", line 442, in _validate_value
raise ValueError(
ValueError: Parameter value must be real.
Process finished with exit code 1
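One possible cause, stated here as an assumption rather than something confirmed in the thread, is that the matrix passed to the cvxpy parameter contains NaN or infinite entries (for example after a loss diverges). A minimal sketch of a guard follows. The helper name `gram_is_finite` is hypothetical and not part of the repository.

```python
import numpy as np


def gram_is_finite(gtg):
    """Return True when every entry of the matrix is a finite real number.

    Hypothetical helper: intended to be called on the NumPy array
    before it is assigned to a cvxpy Parameter.
    """
    gtg = np.asarray(gtg)
    return bool(np.isrealobj(gtg) and np.isfinite(gtg).all())


ok = gram_is_finite(np.eye(2))                          # well-formed matrix
bad = gram_is_finite(np.array([[1.0, np.nan],
                               [np.nan, 1.0]]))         # contains NaN
```

When the check fails, one option is to skip the weight update for that step and inspect the individual task losses and gradients to find where the non-finite values appear.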
Thank you so much for the great work and for putting the code for all MTL algorithms together in a unified manner.
I found a minor error in your implementation of CAGrad. In lines 563-567 of methods/weight_methods.py, I suppose the intent is to retain the computation graph for all tasks except the last one so that some memory can be saved. However, in the current implementation the code always enters the first if statement, so the computation graph is retained for all tasks. I found this issue because I ran into out-of-memory errors when using your code. Hope this helps!
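The intended behavior described above can be sketched as follows. This is an illustration of the idea, not the repository's code. The function name `per_task_gradients` is hypothetical, and it assumes the per-task gradients are taken with `torch.autograd.grad`.

```python
import torch


def per_task_gradients(losses, params):
    """Stack one flattened gradient per task, freeing the graph on the last pass."""
    n_tasks = len(losses)
    rows = []
    for i, loss in enumerate(losses):
        # Keep the graph only while later tasks still need to backpropagate through it.
        retain = i < n_tasks - 1
        grads = torch.autograd.grad(loss, params, retain_graph=retain)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)


# Toy example: two losses over one shared parameter vector.
w = torch.tensor([1.0, 2.0], requires_grad=True)
G = per_task_gradients([(w ** 2).sum(), (3 * w).sum()], [w])
```

If a later backward call over the same graph is still needed (for example on the weighted loss), the last pass has to retain the graph as well.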
Hello, my network is a 2-task network based on ResNet-18. I have tried PCGrad and am currently trying Nash-MTL. However, multi-task performance dropped after adding PCGrad, and the network trained with Nash-MTL did not converge at all. Are there any conditions under which Nash-MTL is applicable? For example, are there requirements on the loss of each task or on the multi-task network architecture? In addition, how should the Nash-MTL hyperparameters 'update_weights_every', 'optim_niter', and 'max_norm' be tuned?
Could you implement GradNorm for comparison with the other algorithms, if it's convenient?
Hi, experts
Thanks for sharing the excellent work on MTL.
Thanks for your awesome work, it's really cool and helpful!
My question is about deploying the MTL methods on multiple NN modules. For example,
import torch
import torch.nn as nn
import torch.nn.functional as F


class First_model(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(First_model, self).__init__()
        # Three-layer MLP; the output width is taken from out_dim.
        self.nn = nn.Sequential(
            torch.nn.Linear(input_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, out_dim),
        )

    def forward(self, inputs):
        return self.nn(inputs)


class Second_model(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(Second_model, self).__init__()
        self.nn = nn.Sequential(
            torch.nn.Linear(input_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, out_dim),
        )

    def forward(self, inputs):
        return self.nn(inputs)


class Third_model(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(Third_model, self).__init__()
        self.nn = nn.Sequential(
            torch.nn.Linear(input_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, out_dim),
        )

    def forward(self, inputs):
        return self.nn(inputs)


first_model = First_model(25, 1)
second_model = Second_model(25, 1)
third_model = Third_model(25, 1)

# Random inputs and targets for the four losses.
x = torch.randn(10, 25)
y = torch.randn(10, 1)
m = torch.randn(10, 1)
n = torch.randn(10, 1)

losses = []
losses.append(F.mse_loss(first_model(x), y))
losses.append(F.mse_loss(second_model(x), m))
losses.append(F.mse_loss(third_model(x), n))
losses.append(F.mse_loss(third_model(x), y))
Now I want to use MTL methods to train these models, but I don't know whether I am using them correctly. Here is my template code:
weight_method = WeightMethods(
method,
n_tasks=4,
device=device,
**weight_method_params[method],
)
loss, extra_outputs = weight_method.backward(
losses=losses,
shared_parameters=,
task_specific_parameters=,
last_shared_parameters=,
representation=features,
)
I'm not sure what kind of variable structure I should pass to 'weight_method.backward' in this case.
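Before choosing what to pass, it can help to check which parameters each loss actually depends on. The sketch below is not the repository's API. The helper `params_used_by` is hypothetical, and the single linear layers stand in for the larger models in the question. It relies on `torch.autograd.grad` with `allow_unused=True`, which returns `None` for parameters a loss does not touch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def params_used_by(loss, named_params):
    """Return the names of parameters that receive a gradient from `loss`."""
    names, params = zip(*named_params)
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return {name for name, g in zip(names, grads) if g is not None}


torch.manual_seed(0)
first = nn.Linear(25, 1)   # stand-in for a model used by one loss
third = nn.Linear(25, 1)   # stand-in for a model used by two losses
x, y, n = torch.randn(10, 25), torch.randn(10, 1), torch.randn(10, 1)

named = [(f"first.{k}", p) for k, p in first.named_parameters()]
named += [(f"third.{k}", p) for k, p in third.named_parameters()]

loss_a = F.mse_loss(first(x), y)
loss_b = F.mse_loss(third(x), n)
loss_c = F.mse_loss(third(x), y)

used_a = params_used_by(loss_a, named)
# Parameters touched by both loss_b and loss_c are shared between those tasks.
shared_bc = params_used_by(loss_b, named) & params_used_by(loss_c, named)
```

In this toy setup only the parameters of `third` are shared between two losses, which mirrors the question: the first and second models each serve a single loss, so no parameter is common to all four tasks.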
Hi @AvivNavon, thank you for your great work. I'm experimenting with something new, so could you please share the log files for the NYUv2 experiments so that I can benchmark my own runs?
First of all, thank you for the great repository :)
Could you also upload the RL training code for meta-world MT10 tasks?
Thank you!
I have run your NYUv2 experiment on my own computer, but I cannot find the result for the Mean Rank metric that you reported in the paper. How can I compute the metric myself?
Thank you very much
Hi there. I get two warnings when training the Nash-MTL model. The first is "the problem is not DPP", and the second is "Solution may be inaccurate"; when I looked into the latter in detail, the solver status was "OPTIMAL_INACCURATE". Do you see these warnings as well? Do you have any suggestions for fixing them? Thanks a lot.
Here is the detailed warning output from the terminal.
\.conda\envs\xxx\lib\site-packages\cvxpy\reductions\solvers\solving_chain.py:213: UserWarning: You are solving a parameterized problem that is not DPP. Because the problem is not DPP, subsequent solves will not be faster than the first one. For more information, see the documentation on Discplined Parametrized Programming, at
https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming
warnings.warn(dpp_error_msg)
.conda\envs\xxx\lib\site-packages\cvxpy\problems\problem.py:1387: UserWarning: Solution may be inaccurate. Try another solver, adjusting the solver settings, or solve with verbose=True for more information.
warnings.warn(