Comments (14)
Thanks for reviewing the code. You are correct; I have fixed the code and will update the results after the re-run.
from pointer_summarizer.
@atulkum
One more question: in your code, step_coverage_loss is the sum of the element-wise minimum of attn_dist and coverage.
And coverage is updated as coverage + attn_dist.
pointer_summarizer/training_ptr_gen/model.py
Line 124 in 5e51169
So, since coverage already includes attn_dist when the minimum is taken, step_coverage_loss would always be the sum of attn_dist (1.0)?
I think you are right. The order of updating the coverage is not correct: the coverage should be updated after the coverage loss has been calculated. I think I made a mistake here while refactoring the code. I will fix it.
if config.is_coverage:
    coverage = coverage.view(-1, t_k)
    coverage = coverage + attn_dist
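The effect of the update order discussed above can be illustrated with a minimal sketch. It uses plain Python lists in place of tensors, and the function coverage_step and its update_first flag are hypothetical names introduced only for this illustration; they are not part of the repository.

```python
# Plain-Python sketch: lists stand in for the attn_dist and coverage tensors.
# coverage_step and update_first are hypothetical names for illustration only.

def coverage_step(attn_dist, coverage, update_first):
    """Return (step_coverage_loss, new_coverage) for one decoder step."""
    if update_first:
        # Order questioned in the thread: coverage already contains attn_dist,
        # so min(attn_dist, coverage) is just attn_dist for every element.
        coverage = [c + a for c, a in zip(coverage, attn_dist)]
        loss = sum(min(a, c) for a, c in zip(attn_dist, coverage))
        return loss, coverage
    # Corrected order: compare against coverage from earlier steps, then update.
    loss = sum(min(a, c) for a, c in zip(attn_dist, coverage))
    coverage = [c + a for c, a in zip(coverage, attn_dist)]
    return loss, coverage


attn_dist = [0.5, 0.25, 0.25]   # attention distribution, sums to 1.0
coverage = [0.0, 0.0, 0.0]      # nothing attended to yet (first step)

buggy_loss, _ = coverage_step(attn_dist, coverage, update_first=True)
fixed_loss, new_coverage = coverage_step(attn_dist, coverage, update_first=False)
```

With the update applied first, the loss is the sum of attn_dist regardless of what was attended to earlier, so it carries no signal about repeated attention. With the corrected order, the first step contributes no coverage loss because nothing has been covered yet.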
Thanks for reviewing the code. I have fixed the bug.
https://github.com/atulkum/pointer_summarizer/blob/master/training_ptr_gen/train.py#L91
https://github.com/atulkum/pointer_summarizer/blob/master/training_ptr_gen/train.py#L100
@atulkum have you ever tried setting is_coverage to True? It very easily causes the loss to become NaN, and lowering the learning rate does not help.
I turned on is_coverage=True after training for 500k iterations. Setting is_coverage=True from the beginning makes the training unstable.
@atulkum I think this operation may cause NaN
Calculating the attention memory at every decoder step may create many branches in the computation graph in the torch backend, but in fact we only need to calculate it once, after getting encoder_outputs.
You are right about the extra branches in the computation graph, but that won't cause NaN. If you are getting NaN, the cause might be somewhere else. I tested it on PyTorch 0.4 and it does not give NaN.
I am changing this to be called only once (thanks for catching this) and will test it again.
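The change described above (computing the encoder-side attention term once instead of at every decoder step) can be sketched as follows. This is a plain-Python illustration only: project, attention_scores, decode, and the call counter are hypothetical stand-ins, and the arithmetic is deliberately trivial rather than a real attention computation.

```python
# Hypothetical stand-ins: project() plays the role of the layer applied to
# encoder_outputs; the counter records how many times it runs.
calls = {"project": 0}


def project(encoder_outputs, weight):
    calls["project"] += 1
    return [h * weight for h in encoder_outputs]


def attention_scores(encoder_feature, decoder_state):
    # Toy score: the precomputed encoder term plus the current decoder state.
    return [f + decoder_state for f in encoder_feature]


def decode(encoder_outputs, decoder_states, weight):
    # Computed once, right after the encoder has run, then reused at every step.
    encoder_feature = project(encoder_outputs, weight)
    return [attention_scores(encoder_feature, s) for s in decoder_states]


scores = decode([1.0, 2.0, 3.0], decoder_states=[0.5, 1.5], weight=2.0)
```

Because the encoder outputs do not change during decoding, hoisting the projection out of the per-step loop gives the same scores while the projection runs once per sequence instead of once per step.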
@atulkum I set is_coverage to True after 500k steps, but at the beginning of retraining it always gives NaN. I'll continue to test. Thank you again.
After how many iterations (with is_coverage = True) are you getting NaN?
Did you initialize the model_file_path in the code?
https://github.com/atulkum/pointer_summarizer/blob/master/training_ptr_gen/train.py#L141
You can try to debug it on CPU.
My GPU is busy right now, but I just tested it on CPU for around 20 epochs and did not get NaN or any exploding gradients.
Thanks for the suggestion. I initialized model_file_path, but after no more than 100 iterations it gives NaN :(
I have uploaded a model here. I retrained it with is_coverage = True for 200k iterations and did not get NaN.
For retraining you should do 3 things:
- in config.py, set is_coverage = True
- in config.py, set max_iterations = 600000
- in train.py, make sure you call trainIters with the full absolute path as model_file_path
I am also curious why you get NaN. Can you debug it and pinpoint the code that is causing it?
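One way to pinpoint where NaN first appears is to check intermediate values in the order they are computed. The sketch below is a plain-Python illustration under stated assumptions: first_non_finite is a hypothetical helper, and the logged names and numbers are made up for the example rather than taken from an actual run.

```python
import math


def first_non_finite(named_values):
    """Return the name of the first value that is NaN or infinite, else None."""
    for name, value in named_values:
        if not math.isfinite(value):
            return name
    return None


# Hypothetical per-step quantities, listed in the order they are computed.
step_log = [
    ("attn_dist_sum", 1.0),
    ("step_coverage_loss", 0.3),
    ("step_loss", float("nan")),
    ("total_loss", float("nan")),
]

culprit = first_non_finite(step_log)
```

In a real training loop the values would be scalars pulled out of the tensors at each step. The earliest non-finite entry narrows down which computation to inspect, since everything after it is usually contaminated.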
Thanks for your code, but I have a question: when I run train.py, I get the error AttributeError: 'generator' object has no attribute 'next'. I don't understand it; the traceback points to batcher.py, line 209.
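This error message is what Python 3 raises when code written for Python 2 calls the .next() method on a generator; in Python 3 that method was renamed to __next__ and the built-in next() is the portable spelling. A minimal sketch, with a hypothetical generator standing in for the one in batcher.py (the actual code at that line is not shown in this thread):

```python
def batches():
    # Hypothetical generator standing in for the one used in batcher.py.
    yield "batch-0"
    yield "batch-1"


gen = batches()

# gen.next() is Python 2 only; the built-in next() works on both 2.6+ and 3.
first = next(gen)

# On Python 3, generators have __next__ but no .next attribute.
has_next_method = hasattr(gen, "next")
```

If the repository targets Python 2, running it under Python 2 avoids the error; otherwise replacing calls of the form gen.next() with next(gen) is the usual fix.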
@jamestang0219 were you able to pinpoint the NaN problem?
I trained without coverage up to 200k iterations and have been using coverage beyond that, but I run into NaN soon after.