
rinalmo's Issues

How to understand the meaning of RNA representation

Hello, thank you for your reply. I have solved that problem, but now I have a new one:
I modified my test.py file as follows:

import torch
from rinalmo.pretrained import get_pretrained_model
DEVICE = "cuda:0"
model, alphabet = get_pretrained_model(model_name="rinalmo_giga_pretrained")
model.eval()
model = model.to(device=DEVICE)
seqs = ["CCCGGU"]
tokens = torch.tensor(alphabet.batch_tokenize(seqs), dtype=torch.int64, device=DEVICE)
with torch.no_grad(), torch.cuda.amp.autocast():
    outputs = model(tokens)
for rep in outputs["representation"]:
    print(rep.shape)

Output:
torch.Size([8, 1280])

With seqs = ["ACUUUGGCCA"], the output is:
torch.Size([12, 1280])

With seqs = ["ACUUUGGCCA", "CCCGGU"], the output is:
torch.Size([12, 1280])
torch.Size([12, 1280])

It seems that the output length is determined by the longest sequence in the input (every sequence begins with a [CLS] token and ends with an [EOS] token), and shorter sequences are padded according to your rules.
Does that mean each 1280-dimensional vector represents one nucleotide?
But according to your paper:
an RNA sequence is tokenized and turned into a 1280-dimensional vector using a learned input embedding model.
How should I interpret this output, and how can I get fixed-size sequence representations for downstream tasks such as predicting interactions between RNAs?
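The shapes above (8 rows for the 6-nt CCCGGU, 12 rows for the 10-nt ACUUUGGCCA) are consistent with one 1280-dimensional contextual vector per token, i.e. one per nucleotide plus [CLS] and [EOS]. If a downstream task needs one fixed-size vector per sequence, a common option is to average the nucleotide vectors while skipping special and padding positions. A minimal sketch, assuming you look up the real [CLS]/[EOS]/padding ids in `alphabet` (the ids 0, 1, 2 below are made up, and `mean_pool` is a hypothetical helper, not part of rinalmo):

```python
import torch


def mean_pool(representation: torch.Tensor, tokens: torch.Tensor, special_idxs) -> torch.Tensor:
    """Average per-token vectors over real nucleotides only.

    representation: [batch, seq_len, dim] per-token output of the model
    tokens:         [batch, seq_len] token ids that were fed to the model
    special_idxs:   token ids to ignore, e.g. [CLS], [EOS] and padding (assumed values)
    """
    keep = torch.ones_like(tokens, dtype=torch.bool)
    for idx in special_idxs:
        keep &= tokens != idx
    mask = keep.unsqueeze(-1).to(representation.dtype)  # [batch, seq_len, 1]
    summed = (representation * mask).sum(dim=1)          # [batch, dim]
    counts = mask.sum(dim=1).clamp(min=1)                # guard against all-special rows
    return summed / counts


# Toy example with hypothetical ids: 0 = [CLS], 1 = [EOS], 2 = padding.
tokens = torch.tensor([[0, 5, 6, 1],
                       [0, 7, 1, 2]])
representation = torch.randn(2, 4, 1280)
pooled = mean_pool(representation, tokens, special_idxs=[0, 1, 2])
print(pooled.shape)  # torch.Size([2, 1280])
```

With the real model, `representation` would be `outputs["representation"]` and `tokens` the tensor passed to `model`. Mean pooling discards positional detail, so for locating interaction sites you may prefer to keep the per-nucleotide vectors.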

Train database usage

Hi,

I noticed you are using a combination of databases including RNAcentral, Rfam, Ensembl and nt.

May I ask why you chose these databases?

Specifically, RNAcentral should be a superset of Rfam and Ensembl. And while nt is not part of RNAcentral, it should be very similar to the ENA database, which is also a subset of RNAcentral.

Also, which deduplication pipeline was applied to remove redundancy?
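For reference, exact-duplicate removal across pooled databases can be as simple as hashing a normalized form of each sequence. This is a hypothetical sketch, not RiNALMo's pipeline; the record IDs are invented, and it only catches identical sequences, so near-duplicates would still need sequence clustering (for example with MMseqs2).

```python
import hashlib


def normalize(seq: str) -> str:
    """Uppercase and map DNA thymine to RNA uracil so equivalent entries compare equal."""
    return seq.strip().upper().replace("T", "U")


def deduplicate(records):
    """Keep the first record seen for each exact (normalized) sequence.

    records: iterable of (source_db, seq_id, sequence) tuples pooled from several databases.
    """
    seen = set()
    kept = []
    for source_db, seq_id, seq in records:
        # A fixed-size digest keeps memory bounded even for very long sequences.
        digest = hashlib.sha256(normalize(seq).encode()).hexdigest()
        if digest not in seen:
            seen.add(digest)
            kept.append((source_db, seq_id, seq))
    return kept


pooled = [
    ("rnacentral", "URS_A", "ACGUACGU"),
    ("rfam", "RF_A", "acgtacgt"),  # same molecule written as lowercase DNA
    ("nt", "NT_B", "GGGAAACCC"),
]
print([seq_id for _, seq_id, _ in deduplicate(pooled)])  # ['URS_A', 'NT_B']
```

The order of `records` decides which database "wins" for a duplicated sequence, so listing the preferred source first matters.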

cluster

Hello,
I would like to ask how you used MMseqs2 to cluster the RNA sequences, for example which values you set for parameters such as sequence identity and coverage.
Thanks!
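For context, these are the MMseqs2 options the question refers to. A small sketch that builds an `mmseqs easy-cluster` command line; the thresholds are placeholders chosen for illustration and are not the values used for RiNALMo, and `mmseqs_easy_cluster_cmd` is a hypothetical helper.

```python
import shlex


def mmseqs_easy_cluster_cmd(fasta, out_prefix, tmp_dir, min_seq_id, coverage, cov_mode=0):
    """Build an `mmseqs easy-cluster` command line as an argument list."""
    return [
        "mmseqs", "easy-cluster", fasta, out_prefix, tmp_dir,
        "--min-seq-id", str(min_seq_id),  # minimum sequence identity of cluster members
        "-c", str(coverage),              # minimum fraction of residues covered by the alignment
        "--cov-mode", str(cov_mode),      # 0: coverage of query and target, 1: target, 2: query
    ]


# Placeholder thresholds for illustration only.
cmd = mmseqs_easy_cluster_cmd("rna.fasta", "clusterRes", "tmp", min_seq_id=0.7, coverage=0.8)
print(shlex.join(cmd))
# mmseqs easy-cluster rna.fasta clusterRes tmp --min-seq-id 0.7 -c 0.8 --cov-mode 0
# With MMseqs2 installed, run it with: subprocess.run(cmd, check=True)
```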

Pre-trained weights for RiNALMo-135M?

Hello! Are there any plans to release the pre-trained weights for the smaller RiNALMo-135M model and/or any of the other configurations available (nano, micro)?

Same RNA, different representation

Hello, is this normal?
My test.py:
import torch
from rinalmo.pretrained import get_pretrained_model

DEVICE = "cuda:0"

model, alphabet = get_pretrained_model(model_name="rinalmo_giga_pretrained")
model = model.to(device=DEVICE)
seqs = ["CCCGGU","CCCGGU"]

tokens = torch.tensor(alphabet.batch_tokenize(seqs), dtype=torch.int64, device=DEVICE)
with torch.no_grad(), torch.cuda.amp.autocast():
    outputs = model(tokens)

print(outputs["representation"])
But the outputs for the two identical sequences are different:
python test.py
tensor([[[ 0.0209, -0.3792, -0.9592,  ..., -0.3661, -0.4986, -1.0630],
         [-0.1543, -0.8713, -0.6534,  ..., -0.7442, -0.4688,  0.0491],
         [-0.1923, -1.0140, -1.3560,  ..., -2.0971, -1.1946, -0.7145],
         ...,
         [ 1.5532, -1.9415, -1.3395,  ..., -1.3404, -1.1100,  0.9047],
         [-0.1968, -0.5992,  0.3608,  ..., -1.4525, -0.8330,  0.4122],
         [-1.1677,  0.0836, -0.1704,  ..., -0.8856, -0.8993, -0.1143]],

        [[ 0.1576, -0.0849, -1.1658,  ..., -0.1120, -0.8494, -0.4571],
         [ 0.2583, -0.0431, -0.1226,  ..., -1.9443, -0.7913,  0.4501],
         [ 0.0782, -0.8882, -0.7555,  ..., -0.7302, -1.6658,  0.0445],
         ...,
         [ 1.3045, -1.9552, -2.3737,  ..., -0.5877, -1.6685,  0.6632],
         [ 0.5900, -0.9660, -0.0392,  ..., -1.1003, -2.0937,  1.4232],
         [-0.7117, -0.8371, -0.3525,  ..., -1.1058, -1.0734, -0.6338]]],
       device='cuda:0')

The pre-trained weights were downloaded from https://zenodo.org/records/10725749/files/rinalmo_giga_pretrained.pt
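One likely explanation: unlike the earlier test.py, this script never calls model.eval(). PyTorch modules start in training mode, so if `get_pretrained_model` returns the model in that mode and the network contains dropout (both are assumptions here), each row of the batch gets its own random dropout mask. The toy network below is a generic stand-in, not RiNALMo, and shows the effect with two identical inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a network with dropout; RiNALMo itself is not loaded here.
toy = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5), nn.Linear(64, 64))

x = torch.randn(1, 64)
batch = torch.cat([x, x])  # two identical inputs, like seqs = ["CCCGGU", "CCCGGU"]

with torch.no_grad():
    toy.train()             # training mode: dropout samples a fresh mask per element
    out_train = toy(batch)
    toy.eval()              # eval mode: dropout becomes the identity
    out_eval = toy(batch)

print(torch.allclose(out_train[0], out_train[1]))  # False
print(torch.allclose(out_eval[0], out_eval[1]))    # True
```

Note that torch.no_grad() only disables gradient tracking; it does not switch dropout off, which is why model.eval() is still needed.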

error when running small inference code: "list_to_cuuint64_array"

/tmp/tmp5oe3edsd/main.c: In function ‘list_to_cuuint64_array’:
/tmp/tmp5oe3edsd/main.c:354:3: error: ‘for’ loop initial declarations are only allowed in C99 mode
for (Py_ssize_t i = 0; i < len; i++) {
^
/tmp/tmp5oe3edsd/main.c:354:3: note: use option -std=c99 or -std=gnu99 to compile your code
/tmp/tmp5oe3edsd/main.c: In function ‘list_to_cuuint32_array’:
/tmp/tmp5oe3edsd/main.c:365:3: error: ‘for’ loop initial declarations are only allowed in C99 mode
for (Py_ssize_t i = 0; i < len; i++) {
^
Traceback (most recent call last):
File "/projects/p32327/RNAFOLD/RiNALMo-main/try.py", line 13, in <module>
outputs = model(tokens)
^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/p32327/RNAFOLD/RiNALMo-main/rinalmo/model/model.py", line 26, in forward
representation, attn_weights = self.transformer(
^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/p32327/RNAFOLD/RiNALMo-main/rinalmo/model/modules.py", line 58, in forward
x, attn = checkpoint.checkpoint(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
ret = function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/p32327/RNAFOLD/RiNALMo-main/rinalmo/model/modules.py", line 125, in forward
mh_out, attn = self.mh_attn(x, key_padding_mask=key_padding_mask, return_attn_probs=need_attn_weights)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/p32327/RNAFOLD/RiNALMo-main/rinalmo/model/attention.py", line 193, in forward
qkv = self.rotary_emb(qkv, seqlen_offset=0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/flash_attn/layers/rotary.py", line 438, in forward
return apply_rotary_emb_qkv_(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/flash_attn/layers/rotary.py", line 233, in apply_rotary_emb_qkv_
return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/flash_attn/layers/rotary.py", line 151, in forward
apply_rotary(
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py", line 213, in apply_rotary
rotary_kernel[grid](
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
bin.c_wrapper(
^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
self._init_handles()
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/compiler/compiler.py", line 670, in _init_handles
bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/runtime/driver.py", line 157, in __getattr__
self._initialize_obj()
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/runtime/driver.py", line 154, in _initialize_obj
self._obj = self._init_fn()
^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/runtime/driver.py", line 187, in initialize_driver
return CudaDriver()
^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/runtime/driver.py", line 77, in __init__
self.utils = CudaUtils()
^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/runtime/driver.py", line 47, in __init__
so = _build("cuda_utils", src_path, tmpdir)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/common/build.py", line 106, in _build
ret = subprocess.check_call(cc_cmd)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/vqc8153/miniconda3/envs/rna/lib/python3.11/subprocess.py", line 413, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmp5oe3edsd/main.c', '-O3', '-I/home/vqc8153/miniconda3/envs/rna/lib/python3.11/site-packages/triton/common/../third_party/cuda/include', '-I/home/vqc8153/miniconda3/envs/rna/include/python3.11', '-I/tmp/tmp5oe3edsd', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmp5oe3edsd/cuda_utils.cpython-311-x86_64-linux-gnu.so', '-L/.singularity.d/libs']' returned non-zero exit status 1.
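The compiler message "'for' loop initial declarations are only allowed in C99 mode" is produced by gcc releases older than 5, which default to the gnu89 dialect (GCC 5 switched the default to gnu11). The traceback shows Triton compiling its cuda_utils helper with /usr/bin/gcc at runtime, so the system compiler is the likely culprit. A small hypothetical check (helper names are illustrative):

```python
import re
import subprocess


def gcc_major_version(version_string: str) -> int:
    """Parse `gcc -dumpversion` output ("4.8.5" on old releases, "11" on newer ones)."""
    match = re.match(r"\s*(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized gcc version: {version_string!r}")
    return int(match.group(1))


def defaults_to_c99_or_later(version_string: str) -> bool:
    """GCC 5 changed the default C dialect from gnu89 to gnu11.

    Older releases reject `for (Py_ssize_t i = 0; ...)` unless -std=c99/gnu99 is given,
    which matches the error in the log above.
    """
    return gcc_major_version(version_string) >= 5


def check_system_gcc(gcc: str = "gcc") -> bool:
    """Run `<gcc> -dumpversion` and report whether its default dialect accepts C99 loops."""
    out = subprocess.run([gcc, "-dumpversion"], capture_output=True, text=True, check=True).stdout
    ok = defaults_to_c99_or_later(out)
    if not ok:
        print(f"{gcc} {out.strip()} defaults to gnu89; use a newer compiler.")
    return ok
```

If check_system_gcc() returns False, a newer gcc is needed. In the Triton versions I am aware of, the build step in triton/common/build.py consults the CC environment variable before falling back to gcc, so exporting CC=/path/to/newer/gcc (hypothetical path) or loading a newer gcc module inside the container may resolve it; verify this against your installed Triton.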

mRNA representation

After reading the paper, I am unsure whether this model can be used directly for mRNA representation without fine-tuning.
As I understand it, the pre-training data contains no mRNAs, yet the model was then fine-tuned on mRNA tasks.
1.) Why were mRNAs not used in pre-training?
2.) Did you check masked-token prediction performance on different sequence types, including mRNAs? This would be similar to the analysis for structure prediction, although mRNAs were missing there as well.
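Question 2 amounts to breaking masked-token accuracy down by sequence type. A hypothetical helper for that breakdown; the example sequences and predictions are made up, and in practice the predictions would be the model's most likely tokens at the masked positions:

```python
from collections import defaultdict


def masked_accuracy_by_type(examples):
    """Fraction of masked positions predicted correctly, grouped by RNA type.

    examples: iterable of (rna_type, true_tokens, predicted_tokens, masked_positions)
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for rna_type, true_tokens, pred_tokens, positions in examples:
        for pos in positions:
            total[rna_type] += 1
            correct[rna_type] += int(pred_tokens[pos] == true_tokens[pos])
    return {rna_type: correct[rna_type] / total[rna_type] for rna_type in total}


examples = [
    ("mRNA", "AUGGCC", "AUGGAC", [1, 4]),  # position 4 is predicted wrongly
    ("tRNA", "GCGGAU", "GCGGAU", [0, 3]),
]
print(masked_accuracy_by_type(examples))  # {'mRNA': 0.5, 'tRNA': 1.0}
```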

A small issue in layer_norm

Hey,

Thank you for your great work; it looks awesome.

When I tried to reproduce your work, I noticed a small issue:

In lines 122-128 of modules.py, you use the layer-normalized hidden states as the residual path around attention, whereas the standard transformer implementation uses the hidden states before layer norm as the residual path.

This shouldn't be a big problem, but if anyone is unable to reproduce the reported results, this may be the cause.
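The two wirings can be compared directly. The toy block below uses generic PyTorch modules with hypothetical names (it is not the repository's code) and shows that with identical weights the outputs differ by exactly LN(x) - x, so the difference is architectural rather than numerical:

```python
import torch
import torch.nn as nn


class PreNormAttention(nn.Module):
    """Toy attention sub-block with a switchable residual source (not RiNALMo's code)."""

    def __init__(self, dim: int, heads: int, residual_from_normed: bool):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.residual_from_normed = residual_from_normed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # Standard pre-LN: x + Attn(LN(x)). Variant from the issue: LN(x) + Attn(LN(x)).
        residual = h if self.residual_from_normed else x
        return residual + attn_out


torch.manual_seed(0)
standard = PreNormAttention(32, 4, residual_from_normed=False).eval()
variant = PreNormAttention(32, 4, residual_from_normed=True).eval()
variant.load_state_dict(standard.state_dict())  # identical weights

x = 3 * torch.randn(2, 5, 32) + 1
with torch.no_grad():
    gap = variant(x) - standard(x)  # equals LN(x) - x

print(gap.abs().max().item() > 0)  # True: the two wirings are not equivalent
```

Since the released checkpoint was trained with whichever wiring the repository uses, switching it at inference time would change the model's outputs rather than restore the standard behavior.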

Inference

Hi,
I've had no issues installing and running your code. I am interested in using your model for inference, for example to predict the MRL (mean ribosome load) of various RNA sequences. I've tried modifying your code to do so, without much success so far.

Any help with this would be greatly appreciated. The idea is to pass in a new CSV dataset of RNA sequences and output the fine-tuned model's predictions (e.g., MRL).

Thanks!
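A minimal sketch of the CSV side of this, assuming a `sequence` column and a user-supplied `predict_batch` function (both hypothetical; loading the fine-tuned MRL model is not shown):

```python
import csv
import io


def predict_csv(infile, outfile, predict_batch, seq_column="sequence", batch_size=32):
    """Copy a CSV of sequences to outfile, appending a `prediction` column.

    predict_batch: callable mapping a list of sequences to a list of numbers. With RiNALMo it
    would wrap tokenization and a forward pass of the fine-tuned model (hypothetical here).
    """
    reader = csv.DictReader(infile)
    writer = csv.DictWriter(outfile, fieldnames=list(reader.fieldnames or []) + ["prediction"])
    writer.writeheader()

    batch = []

    def flush():
        preds = predict_batch([row[seq_column] for row in batch])
        for row, pred in zip(batch, preds):
            row["prediction"] = pred
            writer.writerow(row)
        batch.clear()

    for row in reader:
        batch.append(row)
        if len(batch) == batch_size:
            flush()
    if batch:  # last, possibly smaller, batch
        flush()


# Demo with an in-memory CSV and a dummy predictor (sequence length) in place of a model.
src = io.StringIO("id,sequence\ns1,ACGUAC\ns2,GGC\ns3,AUG\n")
dst = io.StringIO()
predict_csv(src, dst, lambda seqs: [float(len(s)) for s in seqs], batch_size=2)
print(dst.getvalue())
```

For a real `predict_batch`, the model should be in eval mode and called under torch.no_grad(), with tokens built via alphabet.batch_tokenize as in the other scripts on this page.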

Cannot access pretrained weights

Thank you for open-sourcing the project. When I try to download the weights with wget https://zenodo.org/records/10725749/files/rinalmo_giga_pretrained.pt, it returns an error, and I cannot open the link in a browser either. Could you please check the files?

No support for flash-attn transformer model

Hi there,
Have you ever pre-trained the model built without flash-attn, the variant your code can also construct? My CUDA version is 11.4, which flash-attn does not support, so the weights in "rinalmo_giga_pretrained.pt", which were trained with the flash-attn model, are not compatible with the model built without it.
I'm trying to modify the plain model so that its structure matches the flash-attn model and I can load the flash-attn weights into it, but this is nontrivial for me and I can't be sure it will work. Any suggestions?
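One common approach is to rename checkpoint keys to match the target model and, where the source uses a fused query/key/value projection, split it. A sketch under stated assumptions: all key names below are hypothetical (read the real ones from each model's state_dict()), and the q|k|v order along dim 0 is assumed, not confirmed for this checkpoint.

```python
import torch


def remap_state_dict(state_dict, rename_rules):
    """Return a copy of state_dict with key substrings rewritten by (old, new) rules, in order."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename_rules:
            new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped


def split_fused_qkv(weight: torch.Tensor):
    """Split a fused [3 * dim, ...] projection into (q, k, v), assuming q|k|v order along dim 0."""
    return torch.chunk(weight, 3, dim=0)


# Hypothetical key names on both sides.
dim = 8
flash_sd = {
    "blocks.0.mh_attn.Wqkv.weight": torch.randn(3 * dim, dim),
    "blocks.0.mh_attn.out_proj.weight": torch.randn(dim, dim),
}
rules = [("mh_attn.Wqkv.weight", "attn.in_proj_weight"), ("mh_attn.out_proj", "attn.out_proj")]
plain_sd = remap_state_dict(flash_sd, rules)
print(sorted(plain_sd))  # ['blocks.0.attn.in_proj_weight', 'blocks.0.attn.out_proj.weight']
q, k, v = split_fused_qkv(flash_sd["blocks.0.mh_attn.Wqkv.weight"])
print(q.shape)  # torch.Size([8, 8])
```

Loading with `result = plain_model.load_state_dict(plain_sd, strict=False)` and inspecting `result.missing_keys` and `result.unexpected_keys` shows what still fails to match. Matching keys is not sufficient on its own: the plain model must also apply the rotary embeddings exactly as the flash-attn path does, and since such embeddings typically carry no learned weights, a mismatch there would not appear as a missing key.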
