enzyme-jax's People

Contributors

ftynse, ingomueller-net, itf, ivanradanov, martinjm97, mofeing, vchuravy, wsmoses


enzyme-jax's Issues

Missing differentiation rules for `einsum`, `unary_einsum`

I'm getting the following error in EnzymeAD/Reactant.jl when trying to differentiate `stablehlo.einsum`:

julia> f = Reactant.compile(grad, (a′, b′))
error: could not compute the adjoint for this operation %2 = "stablehlo.einsum"(%1, %0) <{einsum_config = "ij,jk->ik"}> : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64>
Pipeline failed

I'm opening the issue here because I believe this is where the EnzymeMLIR rules for the HLO dialects are declared. Is that right?
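
For reference, the adjoint the missing rule needs to produce is just two more einsums. A minimal JAX sketch (not Enzyme code) for the `ij,jk->ik` config from the error above, checked against JAX's own VJP:

```python
# For C = einsum("ij,jk->ik", A, B), the cotangents are
#   dA = einsum("ik,jk->ij", dC, B)   (i.e. dC @ B^T)
#   dB = einsum("ij,ik->jk", A, dC)   (i.e. A^T @ dC)
import jax
import jax.numpy as jnp

A = jnp.arange(4.0).reshape(2, 2)
B = jnp.arange(4.0, 8.0).reshape(2, 2)
dC = jnp.ones((2, 2))

_, vjp = jax.vjp(lambda a, b: jnp.einsum("ij,jk->ik", a, b), A, B)
dA, dB = vjp(dC)
assert jnp.allclose(dA, jnp.einsum("ik,jk->ij", dC, B))
assert jnp.allclose(dB, jnp.einsum("ij,ik->jk", A, dC))
```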

HLO Canonicalizations Todo list

A list to track which canonicalizations we consider worth doing, are working on, or still need to do. (Two of them are checked numerically in the sketch after the list.)

cc @ivanradanov @ftynse

  • iota reshape (becomes single iota)
    %195 = stablehlo.iota dim = 0 : tensor<1024xi32>
    %196 = stablehlo.reshape %195 : (tensor<1024xi32>) -> tensor<1x1x1024xi32>
  • reshape of pad (becomes diff pad)
    %175 = stablehlo.pad %174, %148, low = [0, 0, 1024, 0, 0], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x3x1024x1x1xf32>, tensor<f32>) -> tensor<1x3x2048x1x1xf32>
    %176 = stablehlo.reshape %175 : (tensor<1x3x2048x1x1xf32>) -> tensor<1x3x2048xf32>
  • mul of pad with 0 (becomes pad of mul) 44026d4
    %175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
    %177 = stablehlo.multiply %176, %112 : tensor<1x3x2048xf32>
  • broadcast of pad (becomes pad of broadcast)
    %175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
    %189 = stablehlo.broadcast_in_dim %177, dims = [0, 2, 4] : (tensor<1x3x2048xf32>) -> tensor<1x1x3x1024x2048xf32>
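
Two of the rewrites above can be checked numerically. A small JAX sketch (shapes shrunk from the IR snippets; these are the algebraic identities, not the MLIR patterns themselves) verifying the iota-reshape and mul-of-zero-pad rules:

```python
import jax.numpy as jnp
from jax import lax

# iota reshape == a single iota along the trailing dimension
a = lax.iota(jnp.int32, 1024).reshape(1, 1, 1024)
b = lax.broadcasted_iota(jnp.int32, (1, 1, 1024), 2)
assert (a == b).all()

# mul of zero-pad == zero-pad of mul: the padded region is 0 * y = 0,
# so y only matters on the slice covering the original operand.
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.arange(1.0, 21.0).reshape(4, 5)
zero = jnp.array(0.0, x.dtype)
cfg = ((1, 1, 0), (1, 1, 0))  # (low, high, interior) per dimension
lhs = lax.pad(x, zero, cfg) * y
rhs = lax.pad(x * lax.slice(y, (1, 1), (3, 4)), zero, cfg)
assert jnp.allclose(lhs, rhs)
```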

Not working for Python 3.11

ImportError: Python version mismatch: module was compiled for Python 3.10, but the interpreter version is incompatible: 3.11.3 (main, Apr 19 2023, 18:51:09) [Clang 14.0.6 ].

`test.py::EnzymePipeline` fails on `enzyme_call.optimize_module`

Seems like we are passing a Python object here:

    enzyme_call.optimize_module(mod, pipeline)
    return

instead of the `MlirModule` object from MLIR-C that `enzyme_call` is expecting:
m.def("optimize_module",
[](MlirModule cmod, const std::string &pass_pipeline) {
run_pass_pipeline(unwrap(cmod), pass_pipeline);
});
m.def("run_pass_pipeline",

Log

$> python test/test.py                                                                                                                                                                                                                                                             ✔   enzyme-jax 
Running tests under Python 3.12.3: /Users/mofeing/.pyenv/versions/enzyme-jax/bin/python
[ RUN      ] EnzymeJax.test_custom_cpp_kernel
I0529 13:40:36.518865 8541272768 xla_bridge.py:884] Unable to initialize backend 'cuda': 
I0529 13:40:36.518960 8541272768 xla_bridge.py:884] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0529 13:40:36.519386 8541272768 xla_bridge.py:884] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/mofeing/.pyenv/versions/3.12.3/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/mofeing/.pyenv/versions/3.12.3/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
[[43. 43. 43.]
 [43. 43. 43.]]
[[85. 85. 85.]
 [85. 85. 85.]]
[Array([[56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.]], dtype=float32)]
(Array([[43., 43., 43.],
       [43., 43., 43.]], dtype=float32), Array([[85., 85., 85.],
       [85., 85., 85.]], dtype=float32), [Array([[56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.]], dtype=float32)])
(Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32), Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32), [Array([[56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.]], dtype=float32)])
(Array([[43., 43., 43.],
       [43., 43., 43.]], dtype=float32), Array([[85., 85., 85.],
       [85., 85., 85.]], dtype=float32), [Array([[56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.],
       [56., 56., 56., 56.]], dtype=float32)])
[[128. 128. 128.]
 [128. 128. 128.]]
[       OK ] EnzymeJax.test_custom_cpp_kernel
[ RUN      ] EnzymeJax.test_enzyme_mlir_jit
[12. 23. 34.]
[ 50.1  70.2 110.3]
[12. 23. 34.]
(Array([500., 700., 110.], dtype=float32), Array([500., 700., 110.], dtype=float32))
[       OK ] EnzymeJax.test_enzyme_mlir_jit
[ RUN      ] EnzymePipeline.test_pipeline
[  FAILED  ] EnzymePipeline.test_pipeline
======================================================================
ERROR: test_pipeline (__main__.EnzymePipeline.test_pipeline)
EnzymePipeline.test_pipeline
----------------------------------------------------------------------
ValueError: PyCapsule_GetPointer called with incorrect name

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mofeing/Developer/Enzyme-JAX/test/test.py", line 16, in test_pipeline
    optimize_module(module)
  File "/Users/mofeing/.pyenv/versions/enzyme-jax/lib/python3.12/site-packages/enzyme_ad/jax/primitives.py", line 463, in optimize_module
    enzyme_call.optimize_module(mod, pipeline)
TypeError: optimize_module(): incompatible function arguments. The following argument types are supported:
    1. (arg0: MlirModule, arg1: str) -> None

Invoked with: <jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x116b147b0>, '\n            inline{default-pipeline=canonicalize max-iterations=4},\n            canonicalize,cse,\n            canonicalize,enzyme-hlo-generate-td{\n            patterns=compare_op_canon<16>;\nbroadcast_in_dim_op_canon<16>;\nconvert_op_canon<16>;\ndynamic_broadcast_in_dim_op_not_actually_dynamic<16>;\nchained_dynamic_broadcast_in_dim_canonicalization<16>;\ndynamic_broadcast_in_dim_all_dims_non_expanding<16>;\nnoop_reduce_op_canon<16>;\nempty_reduce_op_canon<16>;\ndynamic_reshape_op_canon<16>;\nget_tuple_element_op_canon<16>;\nreal_op_canon<16>;\nimag_op_canon<16>;\nget_dimension_size_op_canon<16>;\ngather_op_canon<16>;\nreshape_op_canon<16>;\nmerge_consecutive_reshapes<16>;\ntranspose_is_reshape<16>;\nzero_extent_tensor_canon<16>;\nreorder_elementwise_and_shape_op<16>;\n\ncse_broadcast_in_dim<16>;\ncse_slice<16>;\ncse_transpose<16>;\ncse_convert<16>;\ncse_pad<16>;\ncse_dot_general<16>;\ncse_reshape<16>;\ncse_mul<16>;\ncse_div<16>;\ncse_add<16>;\ncse_subtract<16>;\ncse_min<16>;\ncse_max<16>;\ncse_neg<16>;\ncse_concatenate<16>;\n\nconcatenate_op_canon<16>(1024);\nselect_op_canon<16>(1024);\nadd_simplify<16>;\nsub_simplify<16>;\nand_simplify<16>;\nmax_simplify<16>;\nmin_simplify<16>;\nor_simplify<16>;\nnegate_simplify<16>;\nmul_simplify<16>;\ndiv_simplify<16>;\nrem_simplify<16>;\npow_simplify<16>;\nsqrt_simplify<16>;\ncos_simplify<16>;\nsin_simplify<16>;\nnoop_slice<16>;\nconst_prop_through_barrier<16>;\nslice_slice<16>;\nshift_right_logical_simplify<16>;\npad_simplify<16>;\nnegative_pad_to_slice<16>;\ntanh_simplify<16>;\nexp_simplify<16>;\nslice_simplify<16>;\nconvert_simplify<16>;\nreshape_simplify<16>;\ndynamic_slice_to_static<16>;\ndynamic_update_slice_elim<16>;\nconcat_to_broadcast<16>;\nreduce_to_reshape<16>;\nbroadcast_to_reshape<16>;\ngather_simplify<16>;\niota_simplify<16>(1024);\nbroadcast_in_dim_simplify<16>(1024);\nconvert_concat<1>;\ndynamic_update_to_concat<1>;\nslice_of_dynamic_update<1>;\nslice_elementwise<1>;\nslice_pad<1>;\ndot_reshape_dot<1>;\nconcat_const_prop<1>;\nconcat_fuse<1>;\npad_reshape_pad<1>;\npad_pad<1>;\nconcat_push_binop_add<1>;\nconcat_push_binop_mul<1>;\nscatter_to_dynamic_update_slice<1>;\nreduce_concat<1>;\nslice_concat<1>;\n\nbin_broadcast_splat_add<1>;\nbin_broadcast_splat_subtract<1>;\nbin_broadcast_splat_div<1>;\nbin_broadcast_splat_mul<1>;\nreshape_iota<16>;\nslice_reshape_slice<1>;\ndot_general_simplify<16>;\ntranspose_simplify<16>;\nreshape_empty_broadcast<1>;\nadd_pad_pad_to_concat<1>;\nbroadcast_reshape<1>;\n\nslice_reshape_concat<1>;\nslice_reshape_elementwise<1>;\nslice_reshape_transpose<1>;\nslice_reshape_dot_general<1>;\nconcat_pad<1>;\n\nreduce_pad<1>;\nbroadcast_pad<1>;\n\nzero_product_reshape_pad<1>;\nmul_zero_pad<1>;\ndiv_zero_pad<1>;\n\nbinop_const_reshape_pad<1>;\nbinop_const_pad_add<1>;\nbinop_const_pad_subtract<1>;\nbinop_const_pad_mul<1>;\nbinop_const_pad_div<1>;\n\nslice_reshape_pad<1>;\nbinop_binop_pad_pad_add<1>;\nbinop_binop_pad_pad_mul<1>;\nbinop_pad_pad_add<1>;\nbinop_pad_pad_subtract<1>;\nbinop_pad_pad_mul<1>;\nbinop_pad_pad_div<1>;\nbinop_pad_pad_min<1>;\nbinop_pad_pad_max<1>;\n\nunary_pad_push_convert<1>;\nunary_pad_push_tanh<1>;\nunary_pad_push_exp<1>;\n\ntranspose_pad<1>;\n\ntranspose_dot_reorder<1>;\ndot_transpose<1>;\nconvert_convert_float<1>;\nconcat_to_pad<1>;\nconcat_appending_reshape<1>;\nreshape_iota<1>;\n\nbroadcast_reduce<1>;\nslice_dot_general<1>;\n\ndot_reshape_pad<1>;\npad_dot_general<1>(0);\n\ndot_reshape_pad<1>;\npad_dot_general<1>
(1);\n            },\n            transform-interpreter,\n            enzyme-hlo-remove-transform\n        '

----------------------------------------------------------------------
Ran 3 tests in 27.736s

FAILED (errors=1)

Add support for `jacrev`, `jacfwd`, `hessian`, `vmap`

Extending the tests (at the `print(grads)` line in test/test.py) with the following calls currently fails:

> x = jax.jacrev(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Batching rule for 'enzyme_rev' not implemented
> x = jax.jacfwd(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Batching rule for 'enzyme_fwd' not implemented
> x = jax.hessian(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Differentiation rule for 'enzyme_aug' not implemented
> x = jax.jit(jax.vmap(lambda x: add_one(x, jnp.array([1., 2., 3.]))))(jnp.array([jnp.array([1., 2., 3.])]*5))
NotImplementedError: Batching rule for 'enzyme_primal' not implemented
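
All four failures point at the same gap: `jacrev`, `jacfwd`, `hessian`, and `vmap` all go through JAX's batching (and, for `hessian`, partial-eval) machinery, and the `enzyme_*` primitives register no rules for it. A self-contained toy sketch of the mechanism, with a hypothetical `toy_p` standing in for `enzyme_primal` (the real rule would call back into Enzyme rather than compute elementwise):

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

# A stand-in custom primitive, like enzyme_primal but trivially elementwise.
toy_p = core.Primitive("toy_add_one")
toy_p.def_impl(lambda x: x + 1.0)
toy_p.def_abstract_eval(lambda aval: aval)  # shape/dtype are preserved

def toy_batch(batched_args, batch_dims):
    (x,), (d,) = batched_args, batch_dims
    # An elementwise op commutes with the mapped axis: run on the whole
    # batch and report the batch dimension unchanged.
    return toy_p.bind(x), d

batching.primitive_batchers[toy_p] = toy_batch

# Without the registration above, this raises
# NotImplementedError: Batching rule for 'toy_add_one' not implemented.
print(jax.vmap(toy_p.bind)(jnp.ones((5, 3))))
```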

Tracking issue for missing HLO derivatives

NOTE: Strikethrough ops are deliberately not annotated.

  • StableHLO
    • AddOp
    • AfterAllOp
    • AllGatherOp
    • AllReduceOp
    • AllToAllOp
    • AndOp
    • Atan2Op #90
    • BatchNormGradOp
    • BatchNormInferenceOp
    • BatchNormTrainingOp
    • BitcastConvertOp
    • BroadcastInDimOp
    • CaseOp
    • CbrtOp #90
    • CeilOp #90
    • CholeskyOp
    • ClampOp
    • CollectiveBroadcastOp
    • CollectivePermuteOp
    • CompareOp
    • ComplexOp #90
    • CompositeOp
    • ConcatenateOp
    • ConstantOp
    • ConvertOp
    • ConvolutionOp
    • CosineOp
    • ClzOp
    • CustomCallOp
    • DivOp
    • DotGeneralOp
    • DynamicBroadcastInDimOp
    • DynamicConvOp
    • DynamicGatherOp
    • DynamicIotaOp
    • DynamicPadOp
    • DynamicReshapeOp
    • DynamicSliceOp
    • DynamicUpdateSliceOp
    • ExpOp
    • Expm1Op #90
    • FftOp #90
    • FloorOp #90
    • GatherOp
    • GetDimensionSizeOp
    • GetTupleElementOp
    • IfOp
    • ImagOp
    • InfeedOp
    • IotaOp #90
    • IsFiniteOp #90
    • LogOp
    • Log1pOp #90
    • LogisticOp #90
    • MapOp
    • MaxOp
    • MinOp #90
    • MulOp
    • NegateOp
    • NotOp
    • OptimizationBarrierOp
    • OrOp
    • OutfeedOp
    • PadOp
    • PartitionIdOp
    • PopcntOp
    • PowOp
    • RealOp
    • RecvOp
    • ReduceOp
    • ReducePrecisionOp
    • ReduceScatterOp
    • ReduceWindowOp
    • RemainderOp
    • ReplicaIdOp
    • ReshapeOp
    • ReverseOp #90
    • RngOp #90
    • RngBitGeneratorOp #90
    • RoundOp #90
    • RoundNearestEvenOp #90
    • RsqrtOp
    • ScatterOp
    • SelectOp
    • SelectAndScatterOp
    • SendOp
    • ShiftLeftOp
    • ShiftRightArithmeticOp
    • ShiftRightLogicalOp
    • SignOp #90
    • SineOp
    • SliceOp
    • SortOp
    • SqrtOp
    • SubtractOp
    • TanhOp
    • TransposeOp
    • TriangularSolveOp
    • TupleOp
    • UniformDequantizeOp
    • UniformQuantizeOp
    • WhileOp
    • XorOp
    • Deprecated operations in StableHLO
      • BroadcastOp
      • CreateTokenOp
      • CrossReplicaSumOp
      • DotOp
      • EinsumOp
      • TorchIndexSelectOp
      • UnaryEinsumOp
  • CHLO
    • Binary Element-wise Operations
      • BroadcastAddOp
      • BroadcastAtan2Op
      • BroadcastDivOp
      • BroadcastMaxOp
      • BroadcastMinOp
      • BroadcastMulOp
      • BroadcastNextAfterOp
      • BroadcastPolygammaOp
      • BroadcastPowOp
      • BroadcastRemOp
      • BroadcastShiftLeftOp
      • BroadcastShiftRightArithmeticOp
      • BroadcastShiftRightLogicalOp
      • BroadcastSubOp
      • BroadcastZetaOp
    • Binary Logical Element-wise Operations
      • BroadcastAndOp
      • BroadcastOrOp
      • BroadcastXorOp
    • Non-broadcasting Binary Operations
      • NextAfterOp
      • PolygammaOp #90
      • ZetaOp
    • ComplexOp
    • Unary Element-wise Operations
      • AcosOp #90
      • AcoshOp #90
      • AsinOp #90
      • AsinhOp #90
      • AtanOp #90
      • AtanhOp #90
      • BesselI1eOp
      • ConjOp #90
      • CoshOp #90
      • SinhOp #90
      • TanOp #90
      • ConstantOp (shared with StableHLO_ConstantOp)
      • ConstantLikeOp
      • DigammaOp #90
      • ErfOp
      • ErfInvOp
      • ErfcOp
      • IsInfOp #90
      • IsNegInfOp #90
      • IsPosInfOp #90
      • LgammaOp
    • BroadcastCompareOp
    • BroadcastSelectOp
    • TopKOp
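
While writing the rules above, the expected adjoints can be sanity-checked numerically against JAX's own autodiff. A tiny harness for two of the `#90` ops (an aid for rule authors, not Enzyme code):

```python
import jax
import jax.numpy as jnp

x, y = 0.7, -1.3

# Atan2Op: d/dx atan2(x, y) = y / (x^2 + y^2), d/dy = -x / (x^2 + y^2)
gx, gy = jax.grad(jnp.arctan2, argnums=(0, 1))(x, y)
assert jnp.allclose(gx, y / (x**2 + y**2))
assert jnp.allclose(gy, -x / (x**2 + y**2))

# Expm1Op: d/dx expm1(x) = exp(x)
assert jnp.allclose(jax.grad(jnp.expm1)(x), jnp.exp(x))
```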
