
flash-rwkv's Introduction

FlashRWKV

Motivation

There are several reasons for creating the standalone FlashRWKV package:

  • While adding support for the RWKV5 model to transformers, ArthurZucker suggested that RWKV5's custom CUDA kernel should be implemented as an independent package, so the CUDA kernel does not need to be compiled and installed inside the transformers library itself.
  • When implementing custom RWKV5 and RWKV6 models in the Hugging Face community, I found that .cu and .cpp files uploaded to a Hugging Face model repository are not downloaded automatically when using trust_remote_code=True. To compile and use the CUDA kernels, the missing files had to be copied over manually, which was very cumbersome; the same problem also affects other repositories, such as Qwen's chat models.
  • The rwkv package shipped with the official ChatRWKV library can compile the kernels, but it is hard to use from outside that library: it cannot be used, for example, when implementing RWKV5 on Hugging Face, and its compiled kernels do not support backward propagation. In contrast, the FlashRWKV package can be dropped into any model implementation such as RWKV5/RWKV6 completely independently and conveniently, and it supports backward propagation.
  • Use state-of-the-art RWKV CUDA kernels on NVIDIA GPUs to accelerate fine-tuning and inference and improve model efficiency.
  • I previously had no experience building a library with Torch's C++ extension mechanism, so this project also serves as a skill-building exercise.

Installation and features

Simply run pip install flash-rwkv.

How to use FlashRWKV

RWKV5

from flash_rwkv import rwkv5_cuda_linear_attention

out, state = rwkv5_cuda_linear_attention(receptance, key, value, time_decay, time_first, state)

Here, receptance, key, value, time_decay, time_first, and state are intermediate results produced by the RWKV linear attention module. The shapes of these tensors and their equivalent naive Python computation can be seen in the rwkv5 complete test file.
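
For orientation, the following is a minimal sketch of the naive per-time-step recurrence that the kernel accelerates. The tensor shapes and the exact form of time_decay and time_first here are assumptions (they loosely follow the naive Hugging Face RWKV5 path), so treat it as illustrative rather than the library's reference implementation; the test file above remains authoritative.

import torch

def naive_rwkv5_linear_attention(receptance, key, value, time_decay, time_first, state):
    # Assumed shapes (hypothetical; check the test file for the real ones):
    #   receptance, key, value: (batch, heads, seq_len, head_dim)
    #   time_decay, time_first: (heads, head_dim), already in decayed/exponentiated form
    #   state:                  (batch, heads, head_dim, head_dim)
    batch, heads, seq_len, head_dim = key.shape
    decay = time_decay.reshape(heads, head_dim, 1)
    first = time_first.reshape(heads, head_dim, 1)
    out = torch.zeros_like(receptance)
    for t in range(seq_len):
        r_t = receptance[:, :, t:t + 1, :]        # (batch, heads, 1, head_dim)
        k_t = key[:, :, t, :].unsqueeze(-1)       # (batch, heads, head_dim, 1)
        v_t = value[:, :, t:t + 1, :]             # (batch, heads, 1, head_dim)
        a_t = k_t @ v_t                           # outer product k_t v_t^T
        out[:, :, t:t + 1, :] = r_t @ (first * a_t + state)
        state = a_t + decay * state               # per-channel decay of the running state
    return out, state

The CUDA kernel performs this whole sequential scan on the GPU instead of looping step by step in Python, which is where the speedup comes from.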

RWKV6

from flash_rwkv import rwkv6_cuda_linear_attention

out, state = rwkv6_cuda_linear_attention(receptance, key, value, time_decay, time_first, state)

Here, receptance, key, value, time_decay, time_first, and state are intermediate results produced by the RWKV linear attention module. The shapes of these tensors and their equivalent naive Python computation can be seen in the rwkv5 complete test file.
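
The main difference from RWKV5 is that the decay is data-dependent and varies per token. Under the same assumed shapes as the RWKV5 sketch above (again an assumption, not the library's reference code), the naive loop changes only in how the running state is decayed:

import torch

def naive_rwkv6_linear_attention(receptance, key, value, time_decay, time_first, state):
    # Same assumed shapes as the RWKV5 sketch, except that time_decay is now
    # data-dependent: (batch, heads, seq_len, head_dim), one decay per token.
    batch, heads, seq_len, head_dim = key.shape
    first = time_first.reshape(heads, head_dim, 1)
    out = torch.zeros_like(receptance)
    for t in range(seq_len):
        r_t = receptance[:, :, t:t + 1, :]
        k_t = key[:, :, t, :].unsqueeze(-1)
        v_t = value[:, :, t:t + 1, :]
        w_t = time_decay[:, :, t, :].unsqueeze(-1)   # per-token decay, (batch, heads, head_dim, 1)
        a_t = k_t @ v_t
        out[:, :, t:t + 1, :] = r_t @ (first * a_t + state)
        state = a_t + w_t * state                    # the decay now changes at every step
    return out, state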

Why Flash

The CUDA kernel used here is the best-performing version we implemented by hand in RWKV-CUDA. Compared with the plain Hugging Face implementation or naive CUDA kernels, it offers a significant speedup in both the forward and backward passes. Detailed benchmarks will be posted here later, and we will continue to explore new optimization opportunities.
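
Until those numbers are published, a rough micro-benchmark along the following lines can be used to time the kernel's forward pass on your own GPU. The problem size and tensor layouts are placeholders based on the assumptions in the usage sketches above, not an official configuration, and may need adjusting to what the kernel actually expects:

import torch
from flash_rwkv import rwkv5_cuda_linear_attention

# Placeholder problem size; adjust to the model you care about.
batch, heads, seq_len, head_dim = 1, 32, 1024, 64
device = "cuda"
r, k, v = (torch.randn(batch, heads, seq_len, head_dim, device=device) for _ in range(3))
w = torch.rand(heads, head_dim, device=device)    # assumed time_decay layout
u = torch.randn(heads, head_dim, device=device)   # assumed time_first layout
state = torch.zeros(batch, heads, head_dim, head_dim, device=device)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(3):                                 # warm-up runs
    rwkv5_cuda_linear_attention(r, k, v, w, u, state)
torch.cuda.synchronize()
start.record()
for _ in range(20):
    rwkv5_cuda_linear_attention(r, k, v, w, u, state)
end.record()
torch.cuda.synchronize()
print(f"average forward time: {start.elapsed_time(end) / 20:.3f} ms")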

Changelog

  • Released v0.2.0 on 2024.4.6: added support for the RWKV5 model and provided the rwkv5_cuda_linear_attention API.
  • Released v0.2.1 on 2024.4.22: applied rwkv5_cuda_linear_attention to rwkv-5-world-1b5 and rwkv-5-world-3b.
  • Released v0.3.0 on 2024.5.3: added support for the RWKV6 model and provided the rwkv6_cuda_linear_attention API.
  • Applied rwkv6_cuda_linear_attention to rwkv-6-world-3b and rwkv-6-world-7b.

Plan

  • Operator and end-to-end model benchmarking.
  • Integration of this library's operators with Hugging Face's RWKV models.
  • Support for RWKV6.
  • Continue optimizing the kernels.


flash-rwkv's Issues

Question about RWKV inference speed

[Figure: execution_times_comparison]

Hello, I used the Hugging Face transformers library to download llama2-7b and rwkv-7b from the Hugging Face Hub and compared them on an A100-40GB across different output lengths; the results are shown in the figure above. The RWKV paper highlights fast inference as one of RWKV's major advantages, yet in this test RWKV appears to be much slower than Transformer-based large models. Have you and your team measured this, or do you know the reason? I am very willing to keep investigating why the figure looks this way and to continue optimizing.
I would be very grateful for any advice and guidance you can offer.

Triton support

If you're planning to make this API somehow standardized, it would be great to integrate Songlin Yang's excellent new Triton RWKV-6 implementation from FLA:
https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/rwkv6/chunk.py

Another note is that the RWKV-6 kernel can be used for RWKV-5 with the proper expansion of the w tensor, so it's possible to support just that one if desired for simplicity and maintenance.
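
As a purely illustrative sketch of that expansion (using the same assumed tensor layouts as the examples earlier on this page, so the shapes are hypothetical), one could repeat RWKV-5's static decay across the time dimension before calling the RWKV-6 kernel:

import torch

def expand_rwkv5_decay_for_rwkv6(time_decay, batch, seq_len):
    # Hypothetical helper, layouts assumed: RWKV-5 keeps one static decay per
    # head/channel, while the RWKV-6 kernel is assumed to take a per-token decay.
    heads, head_dim = time_decay.shape
    return time_decay.reshape(1, heads, 1, head_dim).expand(batch, heads, seq_len, head_dim).contiguous()

# e.g. w6 = expand_rwkv5_decay_for_rwkv6(time_decay, batch, seq_len)
#      out, state = rwkv6_cuda_linear_attention(receptance, key, value, w6, time_first, state)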

Also, I don't know if your CUDA version supports true full backprop of the state gradients (none of Bo Peng's do fully yet). I believe hers does.
