
FlagGems Introduction

Chinese version

Introduction

FlagGems is a high-performance general operator library implemented in OpenAI Triton. It aims to provide a suite of kernel functions to accelerate LLM training and inference.

By registering with PyTorch's ATen backend, FlagGems enables a seamless transition: users can switch to the Triton operator library without modifying their model code, keep using the ATen backend as usual, and still gain significant performance improvements. The Triton language is readable and easy to use, with performance comparable to CUDA, so developers can contribute to FlagGems with minimal learning investment.

Feature

Automatic Codegen

FlagGems provides automatic code generation that developers can use to conveniently generate single pointwise operators and fused pointwise operators. The code generator handles a range of needs, such as ordinary pointwise computation, non-tensor arguments, and specifying output data types.

Normal Pointwise Operator

Decorating a pointwise operator function with pointwise_dynamic saves the manual handling of tensor addressing, tensor reads/writes, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, and so on. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code.

@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")])
@triton.jit
def abs_func(x):
    return tl.abs(x)
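
For illustration, the decorated function can then be called directly on CUDA tensors, which is how the library's own ops use it; the comparison against torch.abs below is a sanity-check sketch, not an official example.

import torch

x = torch.randn(1024, 1024, device="cuda")
y = abs_func(x)  # runs the generated Triton kernel
torch.testing.assert_close(y, torch.abs(x))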

Non-Tensor Argument

By default, pointwise_dynamic treats all parameters as tensors. By passing a list of boolean values via the is_tensor parameter, developers can specify which parameters are tensors and which are not. In addition, developers may pass dtypes to indicate the data types of the non-tensor parameters, although this is not required. For example, in the following code, the alpha parameter is defined as a non-tensor floating-point number, while x and y are defined as tensors.

@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0,"DEFAULT")]
)
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha
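
As a small follow-up sketch (under the same assumption as above that the decorated function is called directly), the non-tensor alpha is passed as a plain Python float:

import torch

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = add_func(x, y, 2.0)  # alpha is a scalar, not a tensor
torch.testing.assert_close(out, x + y * 2.0)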

Output Data Type

Furthermore, developers must provide promotion_methods to specify how type promotion should be handled so that the operation produces the correct output data type. For example, the comparison below always returns a boolean tensor:

@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def ge(x, y):
    return x >= y

In promotion_methods, an int is used to indicate the position of the parameter requiring type promotion, while a str denotes the method of type promotion. The str corresponds to the following enumerated types:

class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
    DEFAULT = (0,)
    NO_OPMATH = (1,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
    BOOL_TO_LONG = (5,)

Examples:

  • DEFAULT: add
  • NO_OPMATH: where, nextafter, cat
  • INT_TO_FLOAT: sin
  • ALWAYS_BOOL: eq
  • COMPLEX_TO_FLOAT: abs
  • BOOL_TO_LONG: pow
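
For instance, a hypothetical sine wrapper (written here only to illustrate the (position, kind) tuple in the same setting as the examples above; it is not taken from the library) would request integer-to-float promotion like this:

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sin_func(x):
    return tl.sin(x)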

Changelog

v1.0

  • support BLAS operators: addmm, bmm, mm
  • support pointwise operators: abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu
  • support reduction operators: cumsum, layernorm, mean, softmax

v2.0

  • support BLAS operators: mv, outer
  • support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid
  • support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm
  • support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding

Quick Start

Requirements

  1. Triton >= 2.2.0, <3.0.0
  2. PyTorch >= 2.1.2
  3. Transformers >= 4.40.2

Installation

git clone https://github.com/FlagOpen/FlagGems.git
cd FlagGems
pip install .

Usage

Import

  1. Enable permanently

    import flag_gems
    flag_gems.enable()
  2. Enable temporarily

    import flag_gems
    with flag_gems.use_gems():
        pass
  3. Example

    import torch
    import flag_gems
    
    M, N, K = 1024, 1024, 1024
    A = torch.randn((M, K), dtype=torch.float16, device="cuda")
    B = torch.randn((K, N), dtype=torch.float16, device="cuda")
    with flag_gems.use_gems():
        C = torch.mm(A, B)
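
To sanity-check that a call made inside use_gems() matches the ATen result, a quick comparison can be added. This is an illustrative sketch rather than part of the official examples, and the fp16 tolerances are assumptions:

import torch
import flag_gems

A = torch.randn((1024, 1024), dtype=torch.float16, device="cuda")
B = torch.randn((1024, 1024), dtype=torch.float16, device="cuda")

C_ref = torch.mm(A, B)           # ATen reference
with flag_gems.use_gems():
    C_gems = torch.mm(A, B)      # dispatched to the Triton kernel

torch.testing.assert_close(C_gems, C_ref, rtol=1e-3, atol=1e-3)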

Execute

  1. Test Operator Accuracy

    • Run reference on cuda
      cd tests
      pytest test_xx_ops.py
    • Run reference on cpu
      cd tests
      pytest test_xx_ops.py --device cpu
  2. Test Model Accuracy

    cd examples
    pytest model_xx_test.py
  3. Test Operator Performance

    • Test CUDA performance
      cd benchmark
      pytest test_xx_perf.py -s
    • Test end-to-end performance
      cd benchmark
      pytest test_xx_perf.py -s --mode cpu
  4. Run tests with logging information

    pytest program.py --log-cli-level debug

    Not recommended in performance testing.

Supported Operators

Operators will be implemented according to OperatorList.md.

Supported Models

  • Bert-base-uncased
  • Llama-2-7b
  • Llava-1.5-7b

Supported Platforms

Platform        float16   float32   bfloat16
Nvidia A100     ✓         ✓         ✓

Performance

The following chart shows the speedup of FlagGems compared with PyTorch ATen library in eager mode. The speedup is calculated by averaging the speedup on each shape, representing the overall performance of the operator.
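
Concretely, the per-operator speedup is the mean over shapes of the ATen latency divided by the FlagGems latency. A minimal sketch of that calculation, using made-up latency values:

# Illustrative latencies in milliseconds, keyed by input shape (values are made up).
torch_latency = {(1024, 1024): 0.0118, (2048, 2048): 0.0410}
gems_latency = {(1024, 1024): 0.0142, (2048, 2048): 0.0331}

speedups = [torch_latency[s] / gems_latency[s] for s in torch_latency]
avg_speedup = sum(speedups) / len(speedups)
print(f"average speedup over shapes: {avg_speedup:.2f}x")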

Operator speedup (chart not reproduced here)

Contributions

If you are interested in contributing to the FlagGems project, please refer to CONTRIBUTING.md. Any contributions would be highly appreciated.

Contact us

If you have any questions about our project, please submit an issue, or contact us through [email protected].

License

The FlagGems project is licensed under the Apache License 2.0.


FlagGems Issues

Missing ast.unparse method in Python 3.8

Issue

In Python 3.8, the ast module does not include the unparse method (unparse was added in Python 3.9), yet the package declares requires-python = ">=3.8".

Env

Python 3.8.18
Torch: 2.3.1
Triton: 2.3.1
Pytest: 8.2.2

Command

python -m pytest -svvv tests/test_unary_pointwise_ops.py::test_accuracy_abs

Result

(Screenshots of the failing test output from the original issue are not reproduced here.)
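
One possible workaround on Python 3.8 (a sketch, not an official fix from the project) is to fall back to a third-party unparser such as astor where ast.unparse is missing:

import ast
import sys

if sys.version_info >= (3, 9):
    unparse = ast.unparse
else:
    # ast.unparse was added in Python 3.9; astor offers a compatible fallback.
    import astor  # pip install astor

    def unparse(node):
        return astor.to_source(node)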

AttributeError: module has no attribute '_wrapper' and OSError: could not get source code when using FlagGems

When I enabled FlagGems while training a model to measure how much speedup it brings, the following two errors were raised. How can they be resolved?

File "/usr/local/lib/python3.10/dist-packages/swift/llm/utils/argument.py", line 234, in __post_init__
    self.output_dir = add_version_to_work_dir(self.output_dir)
File "/usr/local/lib/python3.10/dist-packages/swift/utils/utils.py", line 59, in add_version_to_work_dir
    sub_folder = broadcast_string(sub_folder)
File "/usr/local/lib/python3.10/dist-packages/swift/utils/torch_utils.py", line 176, in broadcast_string
    first_zero = (tensor == 0).nonzero()[0].item()
File "/usr/local/lib/python3.10/dist-packages/flag_gems/ops/eq.py", line 28, in eq_scalar
    O = eq_func_scalar(A, B)
File "/usr/local/lib/python3.10/dist-packages/flag_gems/utils/pointwise_dynamic.py", line 625, in __call__
    overload = getattr(m, "_wrapper")
AttributeError: module '_gen_module_fa4c6a4330508d6b25e4d8648dd7f4f58dcfec5f20' has no attribute '_wrapper'

File "/root/.flaggems/pointwise_dynamic_fa4c6a4330508d6b25e4d8648dd7f4f58dcfec5f20_rank_1.py", line 40, in <module>
    def _jit_function(
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 561, in decorator
    return JITFunction(
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 440, in __init__
    self.starting_line_number = inspect.getsourcelines(fn)[1]
File "/usr/lib/python3.10/inspect.py", line 1121, in getsourcelines
    lines, lnum = findsource(object)
File "/usr/lib/python3.10/inspect.py", line 958, in findsource
    raise OSError('could not get source code')
OSError: could not get source code

failed to run reference on cpu without CUDA

Thanks for the help in issue #126.
I ran into a problem when I tried to run the reference on CPU without CUDA. The steps to reproduce are as follows:

Requirements

pip install triton==2.2 (Requires Triton >= 2.2.0, <3.0.0)
pip install torch==2.1.2 (Requires PyTorch >= 2.1.2)
pip install transformers==4.42.3 (Requires Transformers >= 4.40.2)

Codebase

commit 95f5afaf0219c2085d1717e8cd85dff5cc7e3cdd (HEAD -> master)
Author: Clement Chan <[email protected]>
Date:   Thu Jul 4 15:11:54 2024 +0800

    [codegen] generate gsl(grid-stride-loop) style pointwise kernel  (#91)
    
    * generate gsl(grid-stride-loop) style pointwise kernel to avoid grid_size exceeding the max grid size
    * add device guard around kernel launch
    * avoid assign to a constexpr since we are inlined into a loop
    * remove redundant code for rank-0 case

Installation:

git clone https://github.com/FlagOpen/FlagGems.git
cd FlagGems
pip install .

Run reference on cpu

cd tests
pytest test_unary_pointwise_ops.py::test_accuracy_abs[dtype0-shape0] --device cpu

Results

tests/test_unary_pointwise_ops.py F                                                                                                                                                                                                          [100%]

===================================================================================================================== FAILURES =====================================================================================================================
_________________________________________________________________________________________________________ test_accuracy_abs[dtype0-shape0] _________________________________________________________________________________________________________

shape = (1024, 1024), dtype = torch.float16

    @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
    @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
    def test_accuracy_abs(shape, dtype):
>       inp = torch.randn(shape, dtype=dtype, device="cuda")

tests/test_unary_pointwise_ops.py:19: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    def _lazy_init():
        global _initialized, _queued_calls
        if is_initialized() or hasattr(_tls, "is_initializing"):
            return
        with _initialization_lock:
            # We be double-checked locking, boys!  This is OK because
            # the above test was GIL protected anyway.  The inner test
            # is for when a thread blocked on some other thread which was
            # doing the initialization; when they get the lock, they will
            # find there is nothing left to do.
            if is_initialized():
                return
            # It is important to prevent other threads from entering _lazy_init
            # immediately, while we are still guaranteed to have the GIL, because some
            # of the C calls we make below will release the GIL
            if _is_in_bad_fork():
                raise RuntimeError(
                    "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
                    "multiprocessing, you must use the 'spawn' start method"
                )
            if not hasattr(torch._C, "_cuda_getDeviceCount"):
                raise AssertionError("Torch not compiled with CUDA enabled")
            if _cudart is None:
                raise AssertionError(
                    "libcudart functions unavailable. It looks like you have a broken build?"
                )
            # This function throws if there's a driver initialization error, no GPUs
            # are found or any other error occurs
            if "CUDA_MODULE_LOADING" not in os.environ:
                os.environ["CUDA_MODULE_LOADING"] = "LAZY"
>           torch._C._cuda_init()
E           RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:298: RuntimeError
================================================================================================================= warnings summary =================================================================================================================
../../../../../../../../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1233
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1233: PytestConfigWarning: Unknown config option: pythonpath
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/warnings.html
============================================================================================================= short test summary info ==============================================================================================================
FAILED tests/test_unary_pointwise_ops.py::test_accuracy_abs[dtype0-shape0] - RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
=========================================================================================================== 1 failed, 1 warning in 8.21s ===========================================================================================================

What confuses me is that running the reference on CPU should not require a CUDA device at all.
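
For reference, the test could honor the --device flag instead of hardcoding device="cuda". The sketch below is hypothetical: the option, the fixture name, and the test body are assumptions, not the repository's actual conftest or tests.

# conftest.py (hypothetical sketch)
import pytest

def pytest_addoption(parser):
    parser.addoption("--device", default="cuda",
                     help="device on which reference inputs are created")

@pytest.fixture
def ref_device(request):
    return request.config.getoption("--device")

# test file (hypothetical sketch)
import torch

def test_accuracy_abs(ref_device):
    # Create the input on the requested device instead of hardcoding "cuda".
    inp = torch.randn((1024, 1024), dtype=torch.float32, device=ref_device)
    assert torch.allclose(torch.abs(inp), inp.abs())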

error: failed to run reference on cpu

An issue occurred when I tried to run the reference on CPU following the instructions in README.md.

The steps to reproduce are as follows:
Installation:

git clone https://github.com/FlagOpen/FlagGems.git
cd FlagGems
pip install .

Codebase:

commit 95f5afaf0219c2085d1717e8cd85dff5cc7e3cdd (HEAD -> master)
Author: Clement Chan <[email protected]>
Date:   Thu Jul 4 15:11:54 2024 +0800

    [codegen] generate gsl(grid-stride-loop) style pointwise kernel  (#91)
    
    * generate gsl(grid-stride-loop) style pointwise kernel to avoid grid_size exceeding the max grid size
    * add device guard around kernel launch
    * avoid assign to a constexpr since we are inlined into a loop
    * remove redundant code for rank-0 case

Execute:
Run reference on cpu

cd tests
pytest test_unary_pointwise_ops.py::test_accuracy_abs[dtype0-shape0] --device cpu

Result:

collected 1 item                                                                                                                                                                                                                                   

test_unary_pointwise_ops.py F                                                                                                                                                                                                                [100%]

===================================================================================================================== FAILURES =====================================================================================================================
_________________________________________________________________________________________________________ test_accuracy_abs[dtype0-shape0] _________________________________________________________________________________________________________

shape = (1024, 1024), dtype = torch.float16

    @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
    @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
    def test_accuracy_abs(shape, dtype):
>       inp = torch.randn(shape, dtype=dtype, device="cuda")

test_unary_pointwise_ops.py:19: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    def _lazy_init():
        global _initialized, _queued_calls
        if is_initialized() or hasattr(_tls, "is_initializing"):
            return
        with _initialization_lock:
            # We be double-checked locking, boys!  This is OK because
            # the above test was GIL protected anyway.  The inner test
            # is for when a thread blocked on some other thread which was
            # doing the initialization; when they get the lock, they will
            # find there is nothing left to do.
            if is_initialized():
                return
            # It is important to prevent other threads from entering _lazy_init
            # immediately, while we are still guaranteed to have the GIL, because some
            # of the C calls we make below will release the GIL
            if _is_in_bad_fork():
                raise RuntimeError(
                    "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
                    "multiprocessing, you must use the 'spawn' start method"
                )
            if not hasattr(torch._C, "_cuda_getDeviceCount"):
>               raise AssertionError("Torch not compiled with CUDA enabled")
E               AssertionError: Torch not compiled with CUDA enabled

/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:289: AssertionError
================================================================================================================= warnings summary =================================================================================================================
../../../../../../../../../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1233
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1233: PytestConfigWarning: Unknown config option: pythonpath
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/warnings.html
============================================================================================================= short test summary info ==============================================================================================================
FAILED test_unary_pointwise_ops.py::test_accuracy_abs[dtype0-shape0] - AssertionError: Torch not compiled with CUDA enabled
=========================================================================================================== 1 failed, 1 warning in 1.36s ===========================================================================================================

NameError: name '_seed' is not defined

Reproduce steps:

git clone https://github.com/FlagOpen/FlagGems.git
cd FlagGems
pip install .
cd tests
pytest test_xx_ops.py
Error:
tests/test_binary_pointwise_ops.py:3: in <module>
    import flag_gems
/usr/local/lib/python3.10/dist-packages/flag_gems/__init__.py:4: in <module>
    from .ops import *  # noqa: F403
/usr/local/lib/python3.10/dist-packages/flag_gems/ops/__init__.py:21: in <module>
    from .dropout import native_dropout
/usr/local/lib/python3.10/dist-packages/flag_gems/ops/dropout.py:24: in <module>
    del _seed
E   NameError: name '_seed' is not defined

Solution:
I worked around this problem by pre-defining _seed and _offset, changing FlagGems/src/flag_gems/ops/dropout.py from

try:
    tl_rand_dtype = tl.int64

    @triton.jit
    def _rand(seed, offset):
        offset = offset.to(tl_rand_dtype)

    _grid = (1,)
    _seed, _offset = philox_cuda_seed_offset(0)
    _rand[_grid](_seed, _offset)
except Exception:
    tl_rand_dtype = tl.int32

to

try:
    tl_rand_dtype = tl.int64

    @triton.jit
    def _rand(seed, offset):
        offset = offset.to(tl_rand_dtype)

    _seed = 0
    _offset = 0
    _grid = (1,)
    _seed, _offset = philox_cuda_seed_offset(0)
    _rand[_grid](_seed, _offset)
except Exception:
    tl_rand_dtype = tl.int32

With this change it works properly.

Is there a proper fix for this?

Infinite Recursion in triton.compile() due to flag_gems.use_gems()

Issue

There is an identified issue in the triton.compile() pipeline when flag_gems.use_gems() is kept active all the time: it leads to infinite recursion when certain functions are compiled.

Specifically, if torch.ne.Scalar is invoked during the triton.compile() pipeline, it dispatches to ne_scalar (registered by lib.impl("ne.Scalar", ne_scalar, "CUDA") in FlagGems/src/flag_gems/__init__.py::enable()), which triggers another call to triton.compile(), causing an infinite loop and eventually a stack overflow.
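
A generic mitigation pattern for this kind of problem (a sketch only, not how FlagGems actually resolves it) is a reentrancy guard that falls back to the original ATen implementation while a gems-triggered compilation is still in flight:

import threading

_tls = threading.local()

def guard_reentrancy(gems_impl, aten_fallback):
    """Wrap a gems op so that re-entrant calls (e.g. from inside triton.compile)
    use the ATen fallback instead of recursing into another compilation."""
    def wrapper(*args, **kwargs):
        if getattr(_tls, "in_gems", False):
            return aten_fallback(*args, **kwargs)
        _tls.in_gems = True
        try:
            return gems_impl(*args, **kwargs)
        finally:
            _tls.in_gems = False
    return wrapper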

Where is the rms_norm backward?

rms_norm only implements the forward pass:

class RmsNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        logging.debug("GEMS LAYERNORM FORWARD")
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        BLOCK_SIZE = triton.next_power_of_2(N)
        x = x.contiguous()
        weight = weight.contiguous()
        y = torch.empty_like(x)

        rms_norm_kernel[M,](y, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)
        return y


def rms_norm(x, normalized_shape, weight, eps=1e-5):
    return RmsNorm.apply(x, normalized_shape, weight, eps)
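
For comparison, a pure-PyTorch backward for RMSNorm can be derived from y = weight * x / rms(x) with rms(x) = sqrt(mean(x^2) + eps). The sketch below is a math reference only, not FlagGems' Triton implementation, and it assumes normalization over the last dimension:

import torch

def rms_norm_backward(dy, x, weight, eps=1e-5):
    """Reference gradients for y = x / rms(x) * weight (rms over the last dim).
    Returns (dx, dweight); intended as a correctness reference only."""
    r = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)   # rms, shape (..., 1)
    g = dy * weight                                             # upstream grad scaled by weight
    # The derivative of rms w.r.t. x adds a correction proportional to x.
    dx = (g - x * (g * x).mean(dim=-1, keepdim=True) / (r * r)) / r
    dweight = (dy * x / r).reshape(-1, x.shape[-1]).sum(dim=0)
    return dx, dweight

Checking dx and dweight against torch.autograd on a naive forward is a quick way to validate a future Triton backward kernel.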

Some unit tests failed

When I run all the unit tests, many of them fail. Is this a known bug that has not been fixed yet? The frequently occurring errors are shown in the screenshots attached to the original issue.

FlagGems performance data

As the title says, could FlagGems' performance data be published, the way FlagAttention's is?
Also, how do the operators perform on an A10 compared with an A100?

[Performance] tile_size 8192 doesn't work well on A100 platform

Some pointwise operators with tile_size equal to 8192 perform worse than torch eager. For example, for cos the average speedup over the benchmark test cases is lower than 90%, while its performance was almost on par with torch eager before #91 increased tile_size to 8192.
The latency data are as follows.

Operator cos Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
64                     0.00752            0.009472
128                   0.007648             0.01008
192                    0.00768            0.009728
256                   0.009216            0.010752
320                   0.008416            0.010144
384                    0.00928            0.011136
448                   0.009024             0.01088
512                   0.009376             0.01056
576                   0.009344              0.0112
640                   0.009632            0.011904
704                    0.00992            0.011008
768                   0.010976              0.0112
832                   0.010368             0.01152
896                    0.01152            0.012608
960                   0.011744             0.01312
1024                  0.011808            0.014176
1088                  0.011008            0.013632
1152                  0.011264            0.013632
1216                  0.012256            0.014688
1280                  0.011584            0.013856
Operator cos Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
64                    0.008608            0.010848
128                   0.008512            0.011264
192                   0.008256            0.011392
256                     0.0088            0.011232
320                   0.009376            0.010912
384                   0.009728            0.011264
448                   0.011072            0.012544
512                   0.011584            0.012864
576                   0.011168            0.012608
640                   0.011328            0.012608
704                    0.01264             0.01392
768                   0.011968             0.01312
832                   0.012992            0.013472
896                   0.012384             0.01616
960                   0.013632            0.015584
1024                   0.01296               0.016
1088                  0.013408            0.017024
1152                   0.01456            0.017152
1216                  0.015136            0.017248
1280                   0.01456            0.016608
Operator cos Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
64                    0.008416            0.009472
128                   0.008672            0.009632
192                   0.008256            0.010304
256                   0.008608             0.01088
320                   0.008544            0.011072
384                   0.009568            0.010112
448                   0.008992            0.010368
512                   0.009344            0.011456
576                    0.01024            0.010912
640                   0.010592            0.010976
704                   0.010016            0.011776
768                   0.011104            0.012224
832                   0.011328               0.012
896                   0.010624            0.014112
960                    0.01072            0.013376
1024                  0.011808            0.014656
1088                  0.011872            0.014464
1152                  0.012096            0.013952
1216                  0.012256            0.015104
1280                   0.01152            0.014944

Reduction Op Softmax, Cross Entropy Loss, and LogSoftmax Test Failed While Parsing JITFunc in Triton

Issue

Reduction Op Softmax, Cross Entropy Loss, and LogSoftmax Test Failed While Parsing JITFunc in Triton

Env

FlagGems Commit: 3c62c9c
Python 3.8.18
Torch: 2.3.1
Triton: 2.3.1
Pytest: 8.2.2

Command

python -m pytest -svvv tests/test_reduction_ops.py::test_accuracy_softmax'[0-dtype0-shape0]'
python -m pytest -svvv tests/test_reduction_ops.py::test_accuracy_cross_entropy_loss'[1-dtype0-shape0-mean-None-None]'
python -m pytest -svvv tests/test_reduction_ops.py::test_accuracy_log_softmax'[dtype0-shape0]'

Softmax Result

(Screenshots of the failing softmax test output and backtrace from the original issue are not reproduced here.)
