Comments (2)

tghong commented on May 26, 2024

Hi, thank you for your interest.
As you understood, the model can be fine-tuned for EE and EL tasks at the same time.
To do this, sum the EE loss (itc_loss + stc_loss) computed in model/bros_spade.py with the EL loss computed in model/bros_spade_rel.py, and train on the total.

bros/model/bros_spade.py

Lines 101 to 110 in 55c52d0

    def _get_loss(self, head_outputs, batch):
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        loss = itc_loss + stc_loss
        return loss

bros/model/bros_spade_rel.py

    def _get_loss(self, head_outputs, batch):
        el_outputs = head_outputs["el_outputs"]
        bsz, max_seq_length = batch["attention_mask"].shape
        device = batch["attention_mask"].device
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        box_first_token_mask = torch.cat(
            [
                (batch["are_box_first_tokens"] == False),
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            axis=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
        mask = batch["are_box_first_tokens"].view(-1)
        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]
        labels = batch["el_labels"].view(-1)
        labels = labels[mask]
        loss = self.loss_func(logits, labels)
        return loss
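
The idea of summing the two task losses can be illustrated with a small, self-contained sketch. This is not code from the BROS repository: `JointHeadsSketch`, its `nn.Linear` layers, and the tensor shapes are hypothetical stand-ins for the shared backbone, the EE heads, and the EL head, and a plain cross-entropy replaces the masking logic shown above. It only demonstrates that a single summed loss sends gradients from both tasks into the shared encoder.

```python
import torch
import torch.nn as nn


class JointHeadsSketch(nn.Module):
    """Hypothetical toy model: one shared encoder, one EE head, one EL head."""

    def __init__(self, in_dim=8, hidden=16, n_classes=3, seq_len=5):
        super().__init__()
        self.encoder = nn.Linear(in_dim, hidden)        # stand-in for the backbone
        self.ee_head = nn.Linear(hidden, n_classes)     # stand-in for the EE heads
        self.el_head = nn.Linear(hidden, seq_len + 1)   # stand-in for the EL head
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, x, ee_labels, el_labels):
        h = torch.relu(self.encoder(x))                 # [bsz, seq_len, hidden]
        ee_logits = self.ee_head(h)                     # [bsz, seq_len, n_classes]
        el_logits = self.el_head(h)                     # [bsz, seq_len, seq_len + 1]
        ee_loss = self.loss_func(ee_logits.flatten(0, 1), ee_labels.flatten())
        el_loss = self.loss_func(el_logits.flatten(0, 1), el_labels.flatten())
        # Joint objective: the sum of both task losses, so one backward pass
        # updates the shared encoder with gradients from both tasks.
        return ee_loss + el_loss, ee_loss, el_loss


torch.manual_seed(0)
model = JointHeadsSketch()
x = torch.randn(2, 5, 8)                    # dummy token features
ee_labels = torch.randint(0, 3, (2, 5))     # dummy EE class per token
el_labels = torch.randint(0, 6, (2, 5))     # dummy link target per token
loss, ee_loss, el_loss = model(x, ee_labels, el_labels)
loss.backward()
```

In an actual setup the two `_get_loss` bodies quoted above would take the place of the two cross-entropy calls, and the batch would need to carry the labels for both tasks. Whether the two terms need different weights is an open choice that this sketch does not address.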

unleft commented on May 26, 2024

Hello, I had the same question, and I would like to ask how exactly I can implement this.

Should I create a completely new Model and Module class from scratch... Or is there another way to make it work?
