kg_one2set's People

Contributors

jiacheng-ye

kg_one2set's Issues

About the result.

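Are the parameters in the submitted sh files the ones used to produce the results in the paper?
I have recently been trying to reproduce the results reported in the paper, but the provided one click.sh file does not reach the expected numbers.

Also, are the prediction.txt files the final prediction outputs?

I look forward to your answer, and thank you in advance.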

a bug about [dec_layers]

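Hello, I am a beginner.
The code uses the -dec_layers parameter to specify the number of decoder layers, but in practice the parameter has no effect.
This is because the Decoder code file contains: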

self.input_fc = nn.Linear(self.embed.embedding_dim, d_model)
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx,  fix_kp_num_len, max_kp_num)  for layer_idx in range(6)])
self.embed_scale = math.sqrt(d_model)

The layer count is hard-coded there (for layer_idx in range(6)), which leads to at least the following bugs:

  1. The state [initialized by def init_state(self, encoder_output, encoder_mask)] is sized strictly by num_layers, i.e. -dec_layers, but the actual number of layers is 6. When dec_layers is less than 6, accessing the state goes out of range:
    File "d:\code\【study-see】\【源码】kg_one2set-master\pykp\modules\multi_head_attn.py", line 57, in forward
      prev_k = state.decoder_prev_key[self.layer_idx]
    IndexError: list index out of range
  2. When -dec_layers is not equal to 6, the parameter has no effect.

In addition, the decoder code uses num_layers=opt.enc_layers, which I think is inappropriate. You could try extracting a base class that both the decoder and the encoder inherit from.

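[This project, including your paper, has helped me a great deal. Thank you for your work. I am still reading through your paper and code. Many thanks.]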
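
The mismatch described in this issue can be reproduced without PyTorch. The sketch below is only an illustration: DecoderLayer, State, and Decoder are hypothetical plain-Python stand-ins, not the repository's classes, and the hardcoded flag exists solely to switch between the reported behavior (range(6)) and one possible fix (range(num_layers)).

```python
class DecoderLayer:
    """Hypothetical stand-in for a decoder layer; it only remembers its index."""

    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

    def forward(self, state):
        # Same access pattern as in the reported traceback:
        # each layer reads its own slot of the per-layer cache.
        return state.decoder_prev_key[self.layer_idx]


class State:
    """Hypothetical stand-in for the state built by init_state."""

    def __init__(self, num_layers):
        self.decoder_prev_key = [None] * num_layers


class Decoder:
    def __init__(self, num_layers, hardcoded=False):
        self.num_layers = num_layers
        # hardcoded=True mimics range(6); otherwise the layer count
        # follows num_layers, which is the fix suggested by the issue.
        n = 6 if hardcoded else num_layers
        self.layer_stacks = [DecoderLayer(i) for i in range(n)]

    def init_state(self):
        # The state is always sized by num_layers.
        return State(self.num_layers)

    def forward(self):
        state = self.init_state()
        for layer in self.layer_stacks:
            layer.forward(state)
        return len(self.layer_stacks)
```

With these stand-ins, Decoder(3, hardcoded=True).forward() raises IndexError because layers 3 to 5 read past a three-slot state, Decoder(8, hardcoded=True).forward() silently runs only 6 layers, and Decoder(3).forward() runs exactly 3.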

RuntimeError: CUDA error: device-side assert triggered

Hi,

Thanks for the nice repo.

I am facing the following error while training the model on the kp20k dataset. FYI, I am training with batch_size=2.

08/30/2021 23:41:03 [INFO] train_ml: Epoch 1; batch: 90000; total batch: 90000,avg training ppl: 5.333, loss: 1.674                                                              
08/30/2021 23:43:40 [INFO] train_ml: Epoch 1; batch: 91000; total batch: 91000,avg training ppl: 5.328, loss: 1.673                                                              
08/30/2021 23:46:18 [INFO] train_ml: Epoch 1; batch: 92000; total batch: 92000,avg training ppl: 5.322, loss: 1.672                                                              
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:662: indexSelectLargeIndex: block: [148,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                     
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:662: indexSelectLargeIndex: block: [148,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:662: indexSelectLargeIndex: block: [130,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                     
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:662: indexSelectLargeIndex: block: [130,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                     
Traceback (most recent call last):                                                                                                                                                
  File "train.py", line 103, in <module>                                                                                                                                          
    main(opt)                                                                                                                                                                     
  File "train.py", line 85, in main                                                                                                                                               
    train_ml.train_model(model, optimizer, train_data_loader, valid_data_loader, opt)                                                                                             
  File "/home/ubuntu/kg_one2set/train_ml.py", line 44, in train_model
    batch_loss_stat = train_one_batch(batch, model, optimizer, opt)
  File "/home/ubuntu/kg_one2set/train_ml.py", line 146, in train_one_batch
    control_embed = model.decoder.forward_seg(state)
  File "/home/ubuntu/kg_one2set/pykp/decoder/transformer.py", line 153, in forward_seg
    control_idx = torch.arange(0, self.max_kp_num).long().to(device).reshape(1, -1).repeat(batch_size, 1)
RuntimeError: CUDA error: device-side assert triggered

Any suggestions would be appreciated.
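
One way to narrow this down, offered as a sketch rather than a diagnosis: the assertion srcIndex < srcSelectDimSize comes from an index-select kernel and fires when an index is not smaller than the size of the dimension being indexed, which often points to a token id outside an embedding table. CUDA reports such errors asynchronously, so the Python line in the traceback may not be the one that failed; rerunning with the environment variable CUDA_LAUNCH_BLOCKING=1, or on CPU, usually gives a more accurate location. The helper below is hypothetical (find_out_of_range is not part of the repository) and uses plain lists with made-up values to show the kind of check that could be run on a batch before the embedding lookup.

```python
def find_out_of_range(indices, num_embeddings):
    """Return ((row, col), value) for every index the table cannot look up.

    indices: a batch of token-id sequences (list of lists of int).
    num_embeddings: number of rows in the embedding table.
    """
    bad = []
    for row, seq in enumerate(indices):
        for col, idx in enumerate(seq):
            # Valid ids are 0 .. num_embeddings - 1.
            if idx < 0 or idx >= num_embeddings:
                bad.append(((row, col), idx))
    return bad


# Made-up batch: the id 50002 is one past the end of a 50002-row table.
batch = [[4, 17, 2], [50002, 9, 1]]
print(find_out_of_range(batch, num_embeddings=50002))
# prints [((1, 0), 50002)]
```

If such a check flags any ids, comparing the vocabulary size used to build the data with the size of the model's embedding layer would be a natural next step.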

Question about the set part

The SetGenerator file does not seem to use beam search; it only does _, tokens = decoder_dist.max(-1). Is there a reason for this?
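
For readers unfamiliar with the quoted call: taking max over the last dimension and keeping the second return value selects the highest-probability token id at each position, which is greedy decoding. The sketch below mimics that on nested lists; greedy_tokens and the sample distribution are hypothetical and only illustrate the operation, not the repository's SetGenerator.

```python
def greedy_tokens(decoder_dist):
    """Argmax over the last axis of a [batch][position][vocab] nested list."""
    return [
        [max(range(len(probs)), key=probs.__getitem__) for probs in seq]
        for seq in decoder_dist
    ]


# One sequence, two positions, vocabulary of three tokens (made-up values).
dist = [[[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]]
print(greedy_tokens(dist))
# prints [[1, 0]]
```

Beam search would instead keep several candidate sequences per step and rank them by cumulative score; the issue asks why that is not done here.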
