Comments (4)
Typically, the checkpoint needs to be saved when the lowest loss (PPL) is achieved. But during my experiments, I found that the model's performance can be further improved by training for more epochs. In my opinion, a generative task is a little different from a classification task, and you can try training longer even after the lowest loss has been reached.
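One way to act on this advice is a patience-based stopping rule: keep saving checkpoints, but only stop after the validation loss has failed to improve for several epochs, rather than at the first local minimum. A minimal sketch (the function name and loop are invented for illustration, not part of this repo):

```python
def should_stop(val_losses, patience=5):
    """Stop only after `patience` epochs pass without a new best
    validation loss, instead of stopping at the first local minimum."""
    best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i])
    return len(val_losses) - 1 - best_epoch >= patience

# the loss dips, rises briefly, then keeps improving: don't stop yet
print(should_stop([3.2, 3.0, 3.1, 3.05, 2.9], patience=3))  # False
```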
from multiturndialogzoo.
Thank you very much for your project. The code is clean, the logic is clear, and I've learned a lot from it. Will you continue to open-source GAN-based dialog models, RL-based conversation models, or some transformer-based dialog models? I'm looking forward to them.
Besides, I see that in data_loader.py the order of the loaded data is fixed first (sorted by turn length). Although randomness is added within each batch afterwards, global randomness is not achieved.
```python
turns = [len(dialog) for dialog in src_dataset]
turnidx = np.argsort(turns)
# sort by the length of the turns
src_dataset = [src_dataset[idx] for idx in turnidx]
tgt_dataset = [tgt_dataset[idx] for idx in turnidx]
...
shuffleidx = np.arange(0, len(sbatch))
np.random.shuffle(shuffleidx)
sbatch = [sbatch[idx] for idx in shuffleidx]
tbatch = [tbatch[idx] for idx in shuffleidx]
```
Does that make a difference? Could the model memorize the order of the data, leading to overfitting?
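To make the concern concrete, here is a toy demonstration (data and names invented) showing that sorting by turn count and shuffling only *within* each batch leaves the batch-level order identical across epochs:

```python
import numpy as np

np.random.seed(0)  # seed only so this demo is reproducible

# toy "dialogs": each inner list's length is its turn count
src_dataset = [[0] * t for t in [3, 1, 2, 1, 3, 2]]

# sort by turn count, as data_loader.py does
turnidx = np.argsort([len(d) for d in src_dataset])
src_sorted = [src_dataset[i] for i in turnidx]

def epoch_batches(dataset, batch_size=2):
    """Cut consecutive batches, shuffling only *within* each batch."""
    batches = []
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]
        idx = np.arange(len(batch))
        np.random.shuffle(idx)
        batches.append([batch[j] for j in idx])
    return batches

e1 = epoch_batches(src_sorted)
e2 = epoch_batches(src_sorted)
# batch k contains exactly the same examples in every epoch; only their
# order inside the batch changes, so the batch-level order is fixed
print([sorted(map(len, b)) for b in e1] == [sorted(map(len, b)) for b in e2])  # True
```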
After I added global shuffling, the MReCoSa validation loss curve improved, and the test PPL also decreased.
Some of the code changes are as follows:
```python
turns = [len(dialog) for dialog in src_dataset]
fidx, bidx = 0, 0
fidx_bidx_list = []
while fidx < len(src_dataset):
    bidx = fidx + batch_size
    head = turns[fidx]
    cidx = 10000
    for p, i in enumerate(turns[fidx:bidx]):
        if i != head:
            cidx = p
            break
    cidx = fidx + cidx
    bidx = min(bidx, cidx)
    # print(fidx, bidx)
    # batch, [batch, turns, lengths], [batch, lengths]
    # shuffle
    # sbatch = src_dataset[fidx:bidx]
    if bidx - fidx <= plus:
        fidx = bidx
        continue
    fidx_bidx_list.append([fidx, bidx])
    fidx = bidx

shuffleidx = np.arange(0, len(fidx_bidx_list))
np.random.shuffle(shuffleidx)
fidx_bidx_list_ = [fidx_bidx_list[i] for i in shuffleidx]

for fidx, bidx in fidx_bidx_list_:
    sbatch, tbatch = src_dataset[fidx:bidx], tgt_dataset[fidx:bidx]
    shuffleidx = np.arange(0, len(sbatch))
    np.random.shuffle(shuffleidx)
    sbatch = [sbatch[idx] for idx in shuffleidx]
    tbatch = [tbatch[idx] for idx in shuffleidx]
```
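The same idea can be condensed into a self-contained, runnable sketch (simplified: `plus` defaults to 0 here, the input is assumed already sorted, and the function only returns the shuffled batch spans):

```python
import numpy as np

def global_shuffle_batches(turns, batch_size, plus=0):
    """Bucket consecutive examples with equal turn counts into batches,
    then shuffle the *order of the batches* globally."""
    spans, fidx = [], 0
    while fidx < len(turns):
        bidx = min(fidx + batch_size, len(turns))
        head = turns[fidx]
        for p in range(fidx, bidx):
            if turns[p] != head:  # cut the batch at the first length change
                bidx = p
                break
        if bidx - fidx > plus:    # drop batches that are too small
            spans.append((fidx, bidx))
        fidx = bidx
    order = np.arange(len(spans))
    np.random.shuffle(order)      # global shuffle of the batch order
    return [spans[i] for i in order]

turns = sorted([1, 1, 2, 2, 2, 3])
spans = global_shuffle_batches(turns, batch_size=2)
# every batch is still homogeneous in turn count, but arrives in random order
print(all(len({turns[i] for i in range(f, b)}) == 1 for f, b in spans))  # True
```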
Thank you so much for your attention to this repo.
- As for the GAN-based model, it will take me some time to implement it, probably about a month.
- As for the transformer-based models, you can check my other repo, OpenDialog, which contains some transformer-based retrieval and generative dialog models.
- Thank you for your improvements to this repo. I think you are doing a great job, and I will consider adopting your suggestions.
I'm a little confused about some code in DSHRED.py:
```python
def forward(self, inpt, hidden=None):
    # inpt: [turn_len, batch, input_size]
    # hidden
    # ALSO RETURN THE STATIC ATTENTION
    if not hidden:
        hidden = torch.randn(2, inpt.shape[1], self.hidden_size)
        if torch.cuda.is_available():
            hidden = hidden.cuda()
    # inpt = self.drop(inpt)
    # output: [seq, batch, 2 * hidden_size]
    output, hidden = self.gru(inpt, hidden)
    # output: [seq, batch, hidden_size]
    output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]
    # static attention
    static_attn = self.attn(output[0].unsqueeze(0), output)  # <-- the line in question
    static_attn = static_attn.bmm(output.transpose(0, 1))
    static_attn = static_attn.transpose(0, 1)  # [1, batch, hidden]
    # hidden: [1, batch, hidden_size]
    # hidden = hidden.squeeze(0)  # [batch, hidden_size]
    hidden = torch.tanh(hidden)
    return static_attn, output, hidden
```
Why does the marked line use `self.attn(output[0].unsqueeze(0), output)` and not `self.attn(output[-1].unsqueeze(0), output)`?
In the DSHRED paper, the static attention mechanism calculates the importance of each utterance as

e_i = V tanh(W h_i + U h_s),

where h_i and h_s denote the hidden-state representations of the i-th and the last utterance in a conversation.
I think h_s should be output[-1] instead of output[0]. Is that right? Thanks.
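For reference, here is a minimal sketch of the paper's equation with h_s taken as the last utterance state, as the question suggests (module and parameter names are illustrative, not the repo's actual API):

```python
import torch
import torch.nn as nn

class StaticAttention(nn.Module):
    """Static attention per the DSHRED equation e_i = V tanh(W h_i + U h_s),
    scoring each utterance state h_i against the *last* state h_s."""
    def __init__(self, hidden_size):
        super().__init__()
        self.W = nn.Linear(hidden_size, hidden_size, bias=False)
        self.U = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, output):
        # output: [turn_len, batch, hidden_size]
        h_s = output[-1]  # last utterance state, matching the paper's h_s
        e = self.v(torch.tanh(self.W(output) + self.U(h_s)))  # [turn_len, batch, 1]
        alpha = torch.softmax(e, dim=0)          # weights over utterances
        context = (alpha * output).sum(dim=0)    # [batch, hidden_size]
        return context, alpha

attn = StaticAttention(8)
ctx, alpha = attn(torch.randn(5, 2, 8))
print(ctx.shape, alpha.shape)  # torch.Size([2, 8]) torch.Size([5, 2, 1])
```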
Related Issues (16)
- baseline models HOT 29
- The performances when using ReCoSa HOT 2
- output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:],why? HOT 1
- RuntimeError: The size of tensor a (31) must match the size of tensor b (30) at non-singleton dimension 0 HOT 1
- About the Ubuntu dataset: Seq2SeqNoAttention performs too well
- Using a Chinese dataset
- An inquiry about ReCoSa model HOT 3
- training epochs HOT 10
- The result of Seq2Seq HOT 4
- The dataset only has 3 files instead of 6. HOT 4
- Which one is the best one? HOT 5
- The data/data_process folder seems to have no preprocessing script for the Ubuntu corpus? HOT 1
- Performance on DailyDialog dataset HOT 1