
candle's Introduction

candle


Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: whisper, LLaMA2, T5, yolo, Segment Anything.

Get started

Make sure that you have candle-core correctly installed as described in Installation.

Let's see how to run a simple matrix multiplication. Write the following to your myapp/src/main.rs file:

use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}

cargo run should display a tensor of shape Tensor[[2, 4], f32].

Having installed candle with CUDA support, simply change the device to be on the GPU:

- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;
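
Alternatively, here is a minimal sketch that should pick the GPU when one is available and otherwise fall back to the CPU, relying on the Device::cuda_if_available helper:

use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Use the first CUDA device if one is available, otherwise stay on the CPU.
    let device = Device::cuda_if_available(0)?;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
    println!("{}", a.matmul(&b)?);
    Ok(())
}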

For more advanced examples, please have a look at the following section.

Check out our examples

These online demos run entirely in your browser: whisper, LLaMA2, T5, yolo, Segment Anything.

We also provide some command-line examples using state-of-the-art models:

  • LLaMA v1, v2, and v3: general LLM, includes the SOLAR-10.7B variant.
  • Falcon: general LLM.
  • Codegeex4: code completion, code interpreter, web search, function calling, and repository-level tasks.
  • GLM4: Open Multilingual Multimodal Chat LMs by THUDM
  • Gemma v1 and v2: 2b and 7b+/9b general LLMs from Google Deepmind.
  • RecurrentGemma: 2b and 7b Griffin based models from Google that mix attention with an RNN-like state.
  • Phi-1, Phi-1.5, Phi-2, and Phi-3: 1.3b, 2.7b, and 3.8b general LLMs with performance on par with 7b models.
  • StableLM-3B-4E1T: a 3b general LLM pre-trained on 1T tokens of English and code datasets. Also supports StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
  • Mamba: an inference only implementation of the Mamba state space model.
  • Mistral7b-v0.1: a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28.
  • Mixtral8x7b-v0.1: a sparse mixture of experts 8x7b general LLM with better performance than a Llama 2 70B model with much faster inference.
  • StarCoder and StarCoder2: LLM specialized to code generation.
  • Qwen1.5: Bilingual (English/Chinese) LLMs.
  • RWKV v5 and v6: an RNN with transformer-level LLM performance.
  • Replit-code-v1.5: a 3.3b LLM specialized for code completion.
  • Yi-6B / Yi-34B: two bilingual (English/Chinese) general LLMs with 6b and 34b parameters.
  • Quantized LLaMA: quantized version of the LLaMA model using the same quantization techniques as llama.cpp.

  • Stable Diffusion: text to image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.

  • Wuerstchen: another text to image generative model.

  • SegFormer: transformer based semantic segmentation model.
  • Whisper: speech recognition model.
  • EnCodec: high-quality audio compression model using residual vector quantization.
  • MetaVoice: foundational model for text-to-speech.
  • Parler-TTS: large text-to-speech model.
  • T5, Bert, JinaBert: useful for sentence embeddings.
  • DINOv2: computer vision model trained using self-supervision (can be used for imagenet classification, depth evaluation, segmentation).
  • VGG, RepVGG: computer vision models.
  • BLIP: image to text model, can be used to generate captions for an image.
  • CLIP: multi-modal vision and language model.
  • TrOCR: a transformer OCR model, with dedicated submodels for handwriting and printed text recognition.
  • Marian-MT: neural machine translation model, generates the translated text from the input text.
  • Moondream: tiny computer-vision model that can answer real-world questions about images.

Run them using commands like:

cargo run --example quantized --release

In order to use CUDA, add --features cuda to the example command line. If you have cuDNN installed, use --features cudnn for even more speedups.

There are also some wasm examples for whisper and llama2.c. You can either build them with trunk or try them online: whisper, llama2, T5, Phi-1.5 and Phi-2, and Segment Anything Model.

For LLaMA2, run the following command to retrieve the weight files and start a test server:

cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081

And then head over to http://localhost:8081/.

Useful External Resources

  • candle-tutorial: A very detailed tutorial showing how to convert a PyTorch model to Candle.
  • candle-lora: Efficient and ergonomic LoRA implementation for Candle. candle-lora has
    out-of-the-box LoRA support for many models from Candle, which can be found here.
  • optimisers: A collection of optimisers including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
  • candle-vllm: Efficient platform for inference and serving local LLMs including an OpenAI compatible API server.
  • candle-ext: An extension library to Candle that provides PyTorch functions not currently available in Candle.
  • candle-coursera-ml: Implementation of ML algorithms from Coursera's Machine Learning Specialization course.
  • kalosm: A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
  • candle-sampling: Sampling techniques for Candle.
  • gpt-from-scratch-rs: A port of Andrej Karpathy's Let's build GPT tutorial on YouTube showcasing the Candle API on a toy problem.
  • candle-einops: A pure rust implementation of the python einops library.

If you have an addition to this list, please submit a pull request.

Features

  • Simple syntax, looks and feels like PyTorch.
  • Backends.
    • Optimized CPU backend with optional MKL support for x86 and Accelerate for Macs.
    • CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
    • WASM support, run your models in a browser.
  • Included models.
    • Language Models.
      • LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
      • Falcon.
      • StarCoder, StarCoder2.
      • Phi 1, 1.5, 2, and 3.
      • Mamba, Minimal Mamba.
      • Gemma v1 2b and 7b+, v2 2b and 9b.
      • Mistral 7b v0.1.
      • Mixtral 8x7b v0.1.
      • StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
      • Replit-code-v1.5-3B.
      • Bert.
      • Yi-6B and Yi-34B.
      • Qwen1.5, Qwen1.5 MoE.
      • RWKV v5 and v6.
    • Quantized LLMs.
      • Llama 7b, 13b, 70b, as well as the chat and code variants.
      • Mistral 7b, and 7b instruct.
      • Mixtral 8x7b.
      • Zephyr 7b a and b (Mistral-7b based).
      • OpenChat 3.5 (Mistral-7b based).
    • Text to text.
      • T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
      • Marian MT (Machine Translation).
    • Text to image.
      • Stable Diffusion v1.5, v2.1, XL v1.0.
      • Wurstchen v2.
    • Image to text.
      • BLIP.
      • TrOCR.
    • Audio.
      • Whisper, multi-lingual speech-to-text.
      • EnCodec, audio compression model.
      • MetaVoice-1B, text-to-speech model.
      • Parler-TTS, text-to-speech model.
    • Computer Vision Models.
      • DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera.
      • yolo-v3, yolo-v8.
      • Segment-Anything Model (SAM).
      • SegFormer.
  • File formats: load models from safetensors, npz, ggml, or PyTorch files (see the safetensors sketch just after this list).
  • Serverless (on CPU), small and fast deployments.
  • Quantization support using the llama.cpp quantized types.
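
To illustrate the safetensors support, here is a minimal sketch, mirroring the save/load entries in the cheatsheet below, that writes a named tensor to disk and loads it back onto a device:

use std::collections::HashMap;
use candle_core::{safetensors, Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;

    // Save a map of named tensors into a safetensors file.
    safetensors::save(&HashMap::from([("a", a)]), "model.safetensors")?;

    // Load every tensor stored in the file back onto the chosen device.
    let tensors = safetensors::load("model.safetensors", &device)?;
    println!("{:?}", tensors.keys());
    Ok(())
}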

How to use

Cheatsheet:

Operation | Using PyTorch | Using Candle
Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?
Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?
Indexing | tensor[:, :4] | tensor.i((.., ..4))?
Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))?
Operations | a.matmul(b) | a.matmul(&b)?
Arithmetic | a + b | &a + &b
Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)?
Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(&DType::F16)?
Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?
Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device)
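
As a quick end-to-end illustration of the cheatsheet, here is a minimal sketch exercising creation, arithmetic, reshaping, indexing, and matmul on the CPU; the IndexOp trait needs to be in scope for the .i(...) indexing syntax:

use candle_core::{Device, IndexOp, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    // Creation: mirrors torch.Tensor([[1, 2], [3, 4]]).
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &device)?;

    // Arithmetic: operators work on references and return a Result.
    let b = (&a + &a)?;

    // Operations: reshape mirrors tensor.view((1, 4)).
    let c = b.reshape((1, 4))?;

    // Indexing: .i((.., ..2)) mirrors tensor[:, :2].
    let d = c.i((.., ..2))?;

    // Matmul takes a reference to the right-hand side operand.
    let e = a.matmul(&a)?;

    println!("{d}\n{e}");
    Ok(())
}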

Structure

FAQ

Why should I use Candle?

Candle's core goal is to make serverless inference possible. Full machine learning frameworks like PyTorch are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight binaries.

Secondly, Candle lets you remove Python from production workloads. Python overhead can seriously hurt performance, and the GIL is a notorious source of headaches.

Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like safetensors and tokenizers.

Other ML frameworks

  • dfdx is a formidable crate, with shapes being included in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat. However, we found that some features still require nightly, and writing code can be a bit daunting for non-Rust experts.

    We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each other.

  • burn is a general crate that can leverage multiple backends so you can choose the best engine for your workload.

  • tch-rs: bindings to the torch library in Rust. Extremely versatile, but it brings the entire torch library into the runtime. The main contributor of tch-rs is also involved in the development of candle.

Common Errors

Missing symbols when compiling with the mkl feature.

If you get some missing symbols when compiling binaries/tests using the mkl or accelerate features, e.g. for mkl you get:

  = note: /usr/bin/ld: (....o): in function `blas::sgemm':
          .../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status

  = note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
  = note: use the `-l` flag to specify native libraries to link
  = note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo

or for accelerate:

Undefined symbols for architecture arm64:
            "_dgemm_", referenced from:
                candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
            "_sgemm_", referenced from:
                candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
          ld: symbol(s) not found for architecture arm64

This is likely due to a missing linker flag that was needed to enable the mkl library. You can try adding the following for mkl at the top of your binary:

extern crate intel_mkl_src;

or for accelerate:

extern crate accelerate_src;
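
As an example, here is a minimal sketch of a binary built with the mkl feature, assuming intel-mkl-src is listed as a dependency of your crate; the same pattern applies to accelerate_src for the accelerate feature:

// Referencing the crate forces the MKL symbols to be linked into the binary.
extern crate intel_mkl_src;

use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let a = Tensor::randn(0f32, 1., (2, 2), &Device::Cpu)?;
    println!("{}", a.matmul(&a)?);
    Ok(())
}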

Cannot run the LLaMA examples: access to source requires login credentials

Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

This is likely because you don't have permission to access the LLaMA-v2 model. To fix this, you have to register on the huggingface-hub, accept the LLaMA-v2 model conditions, and set up your authentication token. See issue #350 for more details.

Missing cute/cutlass headers when compiling flash-attn

  In file included from kernels/flash_fwd_launch_template.h:11:0,
                   from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
  kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
   #include <cute/algorithm/copy.hpp>
            ^~~~~~~~~~~~~~~~~~~~~~~~~
  compilation terminated.
  Error: nvcc error while compiling:

cutlass is provided as a git submodule, so you may want to run the following command to check it out properly.

git submodule update --init

Compiling with flash-attention fails

/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:

This is a bug in gcc-11 triggered by the CUDA compiler. To fix it, install a different, supported gcc version, for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.

env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...

Linking error on Windows when running rustdoc or mdbook tests

Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
 = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'

Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:

mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib

Extremely slow model load time with WSL

This may be caused by the models being loaded from /mnt/c; more details on StackOverflow.

Tracking down errors

You can set RUST_BACKTRACE=1 to be provided with backtraces when a candle error is generated.

CudaRC error

If on Windows you encounter an error like: called Result::unwrap() on an Err value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }, copy and rename the following 3 files and make sure they are in your path. The paths depend on your CUDA version.

c:\Windows\System32\nvcuda.dll -> cuda.dll
c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll -> cublas.dll
c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll -> curand.dll

candle's People

Contributors

bayedieng, danielclough, dependabot[bot], drbh, ericlbuehler, evgenyigumnov, fl33tw00d, gabotechs, grzuy, ivarflakstad, janimo, jbochi, jorgeantonio21, kgrewal1, laurentmazare, llukas22, lucasavila00, milkfather, narsil, nkypy, olivierdehaene, patrickvonplaten, radames, rocketknight1, santiagomed, shua, ssslakter, toluclassics, tomsanbear, v-espitalier

candle's Issues

Compilation failed for flash-attn kernels

Hi, amazing job done here!

I am trying the example provided for llama-2 multi-device inference using the following command:

cargo run --example llama_multiprocess --release --features "cuda nccl flash-attn"

which yields the following error messages during the building stage of candle-flash-attn crate:

  kernels/flash_fwd_hdim160_fp16_sm80.cu(26): here

  kernels/flash_fwd_kernel.h(325): error: argument list for class template "cute::Tensor" is missing
            detected during:
              instantiation of "void flash::compute_attn<Kernel_traits,Is_dropout,Is_causal,Is_even_N,Is_even_K,Return_softmax,Params>(const Params &) [with Kernel_traits=Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, cutlass::half_t, Flash_kernel_traits<160, 64, 64, 4, cutlass::half_t>>, Is_dropout=true, Is_causal=true, Is_even_N=true, Is_even_K=true, Return_softmax=true, Params=Flash_fwd_params]"
  kernels/flash_fwd_launch_template.h(15): here
              instantiation of "void flash_fwd_kernel<Kernel_traits,Is_dropout,Is_causal,Is_even_N,Is_even_K,Return_softmax>(Flash_fwd_params) [with Kernel_traits=Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, cutlass::half_t, Flash_kernel_traits<160, 64, 64, 4, cutlass::half_t>>, Is_dropout=true, Is_causal=true, Is_even_N=true, Is_even_K=true, Return_softmax=true]"
  kernels/flash_fwd_launch_template.h(34): here
              instantiation of "void run_flash_fwd<Kernel_traits,Is_dropout,Is_causal>(Flash_fwd_params &, cudaStream_t) [with Kernel_traits=Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, cutlass::half_t, Flash_kernel_traits<160, 64, 64, 4, cutlass::half_t>>, Is_dropout=true, Is_causal=true]"
  kernels/flash_fwd_launch_template.h(155): here
              instantiation of "void run_mha_fwd_hdim160<T>(Flash_fwd_params &, cudaStream_t) [with T=cutlass::half_t]"
  kernels/flash_fwd_hdim160_fp16_sm80.cu(26): here

  Error limit reached.
  100 errors detected in the compilation of "kernels/flash_fwd_hdim160_fp16_sm80.cu".
  Compilation terminated.
  Error: nvcc error while compiling:

  # stdout


  # stderr

warning: build failed, waiting for other jobs to finish...

The NVIDIA-related environment is as follows:

bin /home/xuzhangda/.mamba/envs/llm/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so
CUDA SETUP: CUDA runtime path found: /home/xuzhangda/.mamba/envs/llm/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /home/xuzhangda/.mamba/envs/llm/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...

I am not sure why cute::Tensor is missing here; is it because the cutlass submodule is not up-to-date? The current commit hash is c4f6b8c. Thanks!

Cannot run examples with --features cuda option

CARGO_PROFILE_RELEASE_BUILD_OVERRIDE_DEBUG=true
warning: some crates are on edition 2021 which defaults to resolver = "2", but virtual workspaces default to resolver = "1"
note: to keep the current resolver, specify workspace.resolver = "1" in the workspace root's manifest
note: to use the edition 2021 resolver, specify workspace.resolver = "2" in the workspace root's manifest
Compiling libc v0.2.147
Compiling autocfg v1.1.0
Compiling crossbeam-utils v0.8.16
Compiling proc-macro2 v1.0.66
Compiling unicode-ident v1.0.11
Compiling rayon-core v1.11.0
Compiling memchr v2.5.0
Compiling libm v0.2.7
Compiling cfg-if v1.0.0
Compiling pkg-config v0.3.27
Compiling paste v1.0.14
Compiling serde v1.0.183
Compiling serde_derive v1.0.183
Compiling scopeguard v1.2.0
Compiling syn v1.0.109
Compiling serde_json v1.0.104
Compiling seq-macro v0.3.5
Compiling vcpkg v0.2.15
Compiling crc32fast v1.3.2
Compiling ident_case v1.0.1
Compiling strsim v0.10.0
Compiling fnv v1.0.7
Compiling thiserror v1.0.44
Compiling either v1.9.0
Compiling glob v0.3.1
Compiling openssl v0.10.56
Compiling rustls v0.21.6
Compiling anyhow v1.0.72
Compiling cudarc v0.9.13
Compiling portable-atomic v1.4.2
Compiling native-tls v0.2.11
Compiling esaxx-rs v0.1.8
Compiling adler v1.0.2
Compiling rustix v0.38.7
Compiling gimli v0.27.3
Compiling macro_rules_attribute-proc_macro v0.1.3
Compiling rustc-demangle v0.1.23
Compiling miniz_oxide v0.7.1
Compiling heck v0.4.1
Compiling flate2 v1.0.26
Compiling memoffset v0.9.0
Compiling crossbeam-epoch v0.9.15
Compiling num-traits v0.2.16
Compiling zip v0.6.6
Compiling crossbeam-channel v0.5.8
Compiling aho-corasick v1.0.2
Compiling object v0.31.1
Compiling nom v7.1.3
Compiling aho-corasick v0.7.20
Compiling quote v1.0.32
Compiling macro_rules_attribute v0.1.3
Compiling syn v2.0.28
Compiling crossbeam-deque v0.8.3
Compiling num_cpus v1.16.0
Compiling getrandom v0.2.10
Compiling dirs-sys v0.4.1
Compiling console v0.15.7
Compiling memmap2 v0.7.1
Compiling regex-automata v0.3.6
Compiling cc v1.0.82
Compiling dirs v5.0.1
Compiling rand_core v0.6.4
Compiling num-complex v0.4.3
Compiling rand_chacha v0.3.1
Compiling indicatif v0.17.6
Compiling rand v0.8.5
Compiling addr2line v0.20.0
Compiling rayon v1.7.0
Compiling is-terminal v0.4.9
Compiling ring v0.16.20
Compiling openssl-sys v0.9.91
Compiling rand_distr v0.4.3
Compiling backtrace v0.3.68
Compiling onig_sys v69.8.1
Compiling anstream v0.3.2
Compiling clap_builder v4.3.21
Compiling half v2.3.1
Compiling spm_precompiled v0.1.4
Compiling regex v1.9.3
Compiling darling_core v0.14.4
Compiling fancy-regex v0.10.0
Compiling candle-kernels v0.1.0 (/mnt/source1/djbGR/ruststuffs/candle/candle-kernels)
Compiling candle-gemm-common v0.15.5
Compiling rayon-cond v0.1.0
Compiling candle-gemm-f32 v0.15.5
Compiling candle-gemm-f64 v0.15.5
Compiling candle-gemm-c64 v0.15.5
Compiling candle-gemm-c32 v0.15.5
Compiling safetensors v0.3.2
Compiling candle-examples v0.1.0 (/mnt/source1/djbGR/ruststuffs/candle/candle-examples)
Compiling tracing-chrome v0.7.1
Compiling candle-gemm-f16 v0.15.5
error: failed to run custom build command for candle-kernels v0.1.0 (/mnt/source1/djbGR/ruststuffs/candle/candle-kernels)

Caused by:
process didn't exit successfully: /mnt/source1/djbGR/ruststuffs/candle/target/release/build/candle-kernels-e21ab5b8e8daaf0a/build-script-build (exit status: 101)
--- stdout
cargo:rerun-if-changed=build.rs
cargo:rustc-env=CUDA_INCLUDE_DIR=/usr/local/cuda/include
cargo:rerun-if-changed=src/
cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP
cargo:rustc-env=CUDA_COMPUTE_CAP=sm_61

--- stderr
src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

2 errors detected in the compilation of "src/indexing.cu".
src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

2 errors detected in the compilation of "src/affine.cu".
src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

2 errors detected in the compilation of "src/cast.cu".
2 errors detected in the compilation of "src/reduce.cu".
2 errors detected in the compilation of "src/conv.cu".
src/compatibility.cuh(19): error: function "__hmax_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmax_nan(__half a, __half b) {
^

src/compatibility.cuh(22): error: function "__hmin_nan(__half, __half)" has already been defined
attribute((device)) inline attribute((always_inline)) __half __hmin_nan(__half a, __half b) {
^

2 errors detected in the compilation of "src/ternary.cu".
2 errors detected in the compilation of "src/unary.cu".
2 errors detected in the compilation of "src/binary.cu".
thread 'main' panicked at 'nvcc error while compiling "src/affine.cu":

stdout

stderr

', candle-kernels/build.rs:207:13
stack backtrace:
0: 0x557f8498d0b1 - std::backtrace_rs::backtrace::libunwind::trace::hb01a67340c9cfb71
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5
1: 0x557f8498d0b1 - std::backtrace_rs::backtrace::trace_unsynchronized::h896aca561948c930
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
2: 0x557f8498d0b1 - std::sys_common::backtrace::_print_fmt::h8627be5b68fbde29
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/sys_common/backtrace.rs:65:5
3: 0x557f8498d0b1 - <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt::h1b7758da45f4cd22
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/sys_common/backtrace.rs:44:22
4: 0x557f849b282c - core::fmt::rt::Argument::fmt::h0eb38586043a01ca
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/core/src/fmt/rt.rs:138:9
5: 0x557f849b282c - core::fmt::write::h68b52f8aa598961e
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/core/src/fmt/mod.rs:1094:21
6: 0x557f8498949e - std::io::Write::write_fmt::hc5568929b662da92
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/io/mod.rs:1714:15
7: 0x557f8498cec5 - std::sys_common::backtrace::_print::h65aecbff12ca83c8
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/sys_common/backtrace.rs:47:5
8: 0x557f8498cec5 - std::sys_common::backtrace::print::hf75ac9d60598d247
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/sys_common/backtrace.rs:34:9
9: 0x557f8498e483 - std::panicking::default_hook::{{closure}}::hc2cb8da3be7476b0
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:269:22
10: 0x557f8498e19d - std::panicking::default_hook::hefa49c86da66275b
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:288:9
11: 0x557f8498ea09 - std::panicking::rust_panic_with_hook::hd4c3b0056ba96951
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:705:13
12: 0x557f8498e907 - std::panicking::begin_panic_handler::{{closure}}::he487675683e9a525
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:597:13
13: 0x557f8498d516 - std::sys_common::backtrace::__rust_end_short_backtrace::hcff58b9b81620321
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/sys_common/backtrace.rs:151:18
14: 0x557f8498e652 - rust_begin_unwind
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:593:5
15: 0x557f848b9333 - core::panicking::panic_fmt::h1b81548733a03bd5
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/core/src/panicking.rs:67:14
16: 0x557f848c3323 - build_script_build::cuda::build_ptx::ha488acce3cd701b3
at /mnt/source1/djbGR/ruststuffs/candle/candle-kernels/build.rs:207:13
17: 0x557f848c0878 - build_script_build::main::h2523e6c20b65fa04
at /mnt/source1/djbGR/ruststuffs/candle/candle-kernels/build.rs:6:33
18: 0x557f848d40cb - core::ops::function::FnOnce::call_once::h385ddf31127d3e12
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/core/src/ops/function.rs:250:5
19: 0x557f848ccbae - std::sys_common::backtrace::__rust_begin_short_backtrace::h1cfd550c72c3e194
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/sys_common/backtrace.rs:135:18
20: 0x557f848e0130 - std::rt::lang_start::{{closure}}::h70dc5fa7783a03f7
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/rt.rs:166:18
21: 0x557f8498541b - core::ops::function::impls::<impl core::ops::function::FnOnce for &F>::call_once::h9eccf02cf11756f6
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/core/src/ops/function.rs:284:13
22: 0x557f8498541b - std::panicking::try::do_call::hc95b838862bbb45a
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:500:40
23: 0x557f8498541b - std::panicking::try::h82935254d12a76fc
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:464:19
24: 0x557f8498541b - std::panic::catch_unwind::h7fd9d11cd70fc350
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panic.rs:142:14
25: 0x557f8498541b - std::rt::lang_start_internal::{{closure}}::h0ddb191e68b650a4
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/rt.rs:148:48
26: 0x557f8498541b - std::panicking::try::do_call::h17d4693c7a6e120c
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:500:40
27: 0x557f8498541b - std::panicking::try::h684fc020e1305912
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panicking.rs:464:19
28: 0x557f8498541b - std::panic::catch_unwind::h757da538db515116
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/panic.rs:142:14
29: 0x557f8498541b - std::rt::lang_start_internal::ha6b1625a1e9a4f5b
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/rt.rs:148:20
30: 0x557f848e010a - std::rt::lang_start::h0d1360f20fc735dd
at /rustc/39f42ad9e8430a8abb06c262346e89593278c515/library/std/src/rt.rs:165:17
31: 0x557f848c43fe - main
32: 0x7fd8be429d90 - __libc_start_call_main
at ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
33: 0x7fd8be429e40 - __libc_start_main_impl
at ./csu/../csu/libc-start.c:392:3
34: 0x557f848b9a15 - _start
35: 0x0 -

Connection Failed: tls connection init failed (os error 10054)

No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
Error: request error: https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav: Connection Failed: tls connection init failed: The remote host forcibly closed an existing connection(os error 10054)

Caused by:
0: https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav: Connection Failed: tls connection init failed: The remote host forcibly closed an existing connection (os error 10054)
1: The remote host forcibly closed an existing connection (os error 10054)
error: process didn't exit successfully: target\release\examples\whisper.exe (exit code: 1)

Does anybody know how to fix this?

Specify the device only once

This looks like a great project!

I have a question: why is it necessary to specify the device for every Tensor? Wouldn't it be possible to set the device once and have all allocations made to that device?
This would work like a global allocator in Rust: https://doc.rust-lang.org/std/alloc/index.html

The drawback is that you can't use multiple backends at the same time easily.

What are your thoughts?

Llama2: starting a test server gives an error

Description

Following the example to start a test server for llama2 gives an error.

# README.MD
For llama2, run the following command to retrieve the weight files and start a test server:

cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
trunk serve --release --public-url /candle-llama2/ --port 8081
And then browse to http://localhost:8081/candle-llama2.

Expected behavior

Server starts successfully.

Current behavior

2023-08-10T08:44:41.267930Z  INFO 📦 starting build
2023-08-10T08:44:41.268424Z  INFO spawning asset pipelines
2023-08-10T08:44:41.268905Z ERROR ❌ error
error from HTML pipeline

Caused by:
    0: error getting canonical path for "/hosted/workspace/1_user/...../candle/candle-wasm-examples/llama2-c/tokenizer.json"
    1: No such file or directory (os error 2)
2023-08-10T08:44:41.269446Z  INFO 📡 serving static assets at -> /candle-llama2/
2023-08-10T08:44:41.269498Z  INFO 📡 server listening at http://127.0.0.1:8081

Steps to reproduce

git clone https://github.com/huggingface/candle.git

cd candle

cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
trunk serve --release --public-url /candle-llama2/ --port 8081

Diagnosis

  • tokenizer.json not found
# https://github.com/huggingface/candle/blob/main/candle-wasm-examples/llama2-c/index.html
# line no. 7
 <link data-trunk rel="copy-file" href="tokenizer.json" />

Randn only generates positive numbers

I can't seem to get randn to generate any negative numbers. Any ideas what is happening here?

Code:

use anyhow::Result;
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let n = 200;
    let t = Tensor::randn(0f32, 1f32, n, &Device::Cpu)?;

    let count = Tensor::sum_all(&Tensor::gt(&t, &Tensor::zeros_like(&t)?)?)?;

    println!("{count} out of {n} elements are > 0");

    Ok(())
}

Output:

[200]
Tensor[[], u8] out of 200 elements are > 0

Is `compatibility.cuh` taken from dfdx?

I was reading through some of the issues and came across #353 which mentions compatibility.cuh. I remember writing a file with the same name for dfdx.

Was the file taken from dfdx and changed?
I am not worried about attribution. I don't think it is necessary as per the relevant licenses.
Instead, I am interested in whether anything was changed that dfdx could also benefit from.

Maybe there should be a separate library for cuda kernels so that both libraries could benefit from improvements and bug fixes. Let me know what you think.

Bigram Model

Hello there,

Newbie here. I am trying to reproduce the "Let's build GPT" lecture from Andrej Karpathy in candle. At the 31 minute mark in this video, he implements a bigram model using embeddings.

This is my rust implementation,

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, loss::cross_entropy, Embedding, Module};

#[derive(Debug)]
pub struct BigramLanguageModel {
    token_embedding_table: Embedding,
}

impl BigramLanguageModel {
    // Constructor
    pub fn new(vocab_size: usize) -> Result<Self> {
        let vb = candle_nn::VarBuilder::from_varmap(&candle_nn::VarMap::new(), DType::F32, &Device::Cpu);
        let token_embedding_table = embedding(vocab_size, vocab_size, vb)?;
        Ok(BigramLanguageModel {
            token_embedding_table,
        })
    }

    // Forward pass
    pub fn forward(&self, idx: &Tensor, targets: &Tensor) -> (Tensor, Tensor) {

        let logits = self.token_embedding_table.forward(idx);
        let logits = logits.unwrap();
        
        let shape = logits.shape().dims();
        let logits = logits.reshape(&[shape[0]*shape[1], shape[2]]).unwrap();
        
        println!("shape: {:?}", logits.shape());
        println!("targets shape: {:?}", targets.shape().dims()[0]);
        if targets.shape().dims()[0] != 1 {
            let targets = targets.reshape(&[shape[0]*shape[1]]).unwrap();
            let loss = cross_entropy(&logits, &targets).unwrap();
            (logits, loss)
        }else{
            let loss = Tensor::zeros((1, 1), DType::F32, &Device::Cpu).unwrap();
            (logits, loss)
        }
        
    }
}

But during training, the loss does not decrease from -ln(1/65). Is my implementation incorrect?

Also, do you have any tips you could give to a newcomer to make adoption easy?

Steps to use on browser/WASM

Hello thank you for sharing this crate !!

Would it be possible to get the steps/code to reproduce the llama2.c web example https://laurentmazare.github.io/candle-llama2/? Compiling to wasm seems OK, but I am struggling to generate the corresponding JS glue code to make it all work.

Again, thank you for your hard work, really appreciated. :)

Parameter Grouping

Optimizers like LARS and LAMB do per-layer weight updates. Is there functionality to group certain parameters together? The PyTorch equivalent would be nn.ModuleDict.

error: failed to run custom build command for candle-kernels

cargo run --example stable-diffusion --features cuda --features image -- --prompt "a rusty robot holding a fire torch"
Compiling candle-kernels v0.1.0 (HOME/candle/candle-kernels)
error: failed to run custom build command for candle-kernels v0.1.0 (HOME/candle/candle-kernels)
note: To improve backtraces for build dependencies, set the CARGO_PROFILE_DEV_BUILD_OVERRIDE_DEBUG=true environment variable to enable debug information generation.

Caused by:
process didn't exit successfully: HOME/candle/target/debug/build/candle-kernels-c1d996e014c93c27/build-script-build (exit status: 101)
--- stdout
cargo:rerun-if-changed=build.rs
cargo:rustc-env=CUDA_INCLUDE_DIR=/usr/include
cargo:rerun-if-changed=src/
cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP
cargo:rustc-env=CUDA_COMPUTE_CAP=sm_75

--- stderr
src/compatibility.cuh(11): error: identifier "__hmax" is undefined

....

stderr

', candle-kernels/build.rs:207:13
stack backtrace:
0: rust_begin_unwind
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:593:5
1: core::panicking::panic_fmt
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/panicking.rs:67:14
2: build_script_build::cuda::build_ptx
3: build_script_build::main
4: core::ops::function::FnOnce::call_once

whisper example: Error: No such file or directory (os error 2) when binary copied to another directory

The whisper example runs fine from ./target:

$ ./target/release/examples/whisper --input ~/git/whisper-burn/output.wav
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded mel filters [80, 201]
loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
pcm data loaded 131413
loaded mel: [1, 80, 3000]
audio features: [1, 1500, 384]
3000: Segment { start: 0.0, duration: 30.0, dr: DecodingResult { tokens: [50257, 50363, 770, 318, 281, 1672, 3809, 8296, 1223, 1244, 1682, 910, 611, 314, 8296, 1223, 319, 616, 2342, 11, 314, 892, 428, 318, 703, 340, 561, 1210, 503, 13, 50763, 50256], text: " This is an example voice recording something might actually say if I recording something on my watch, I think this is how it would turn out.", avg_logprob: -0.37165226448053545, no_speech_prob: 0.09571712464094162, temperature: 0.0, compression_ratio: NaN } }, in 3.080288417s

But it crashes with an unfound file when I run from another directory:

$ RUST_BACKTRACE=full ./whisper --input ~/git/whisper-burn/output.wav
Running on CPU, to run on GPU, build this example with `--features cuda`
Error: No such file or directory (os error 2)

Stack backtrace:
   0: backtrace::backtrace::libunwind::trace
             at /Users/n8henrie/.cargo/registry/src/index.crates.io-6f17d22bba15001f/backtrace-0.3.68/src/backtrace/libunwind.rs:93:5
      backtrace::backtrace::trace_unsynchronized
             at /Users/n8henrie/.cargo/registry/src/index.crates.io-6f17d22bba15001f/backtrace-0.3.68/src/backtrace/mod.rs:66:5
   1: backtrace::backtrace::trace
             at /Users/n8henrie/.cargo/registry/src/index.crates.io-6f17d22bba15001f/backtrace-0.3.68/src/backtrace/mod.rs:53:14
   2: anyhow::backtrace::capture::Backtrace::create
             at /Users/n8henrie/.cargo/registry/src/index.crates.io-6f17d22bba15001f/anyhow-1.0.72/src/backtrace.rs:216:13
   3: anyhow::backtrace::capture::Backtrace::capture
             at /Users/n8henrie/.cargo/registry/src/index.crates.io-6f17d22bba15001f/anyhow-1.0.72/src/backtrace.rs:204:17
   4: anyhow::error::<impl core::convert::From<E> for anyhow::Error>::from
             at /Users/n8henrie/.cargo/registry/src/index.crates.io-6f17d22bba15001f/anyhow-1.0.72/src/error.rs:547:25
   5: <core::result::Result<T,F> as core::ops::try_trait::FromResidual<core::result::Result<core::convert::Infallible,E>>>::from_residual
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/result.rs:1961:27
   6: whisper::main
             at /Users/n8henrie/git/candle/candle-examples/examples/whisper/main.rs:304:32
   7: core::ops::function::FnOnce::call_once
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs:250:5
   8: std::sys_common::backtrace::__rust_begin_short_backtrace
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs:135:18
   9: std::rt::lang_start::{{closure}}
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs:166:18
  10: core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs:284:13
      std::panicking::try::do_call
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:500:40
      std::panicking::try
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:464:19
      std::panic::catch_unwind
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs:142:14
      std::rt::lang_start_internal::{{closure}}
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs:148:48
      std::panicking::try::do_call
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:500:40
      std::panicking::try
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:464:19
      std::panic::catch_unwind
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs:142:14
      std::rt::lang_start_internal
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs:148:20
  11: std::rt::lang_start
             at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs:165:17
  12: _main

Is there a way to run this as a standalone binary? Are there other files I could include via include_bytes! or the like?

Version 0.2.0 Features

  • Having Q4 (GGML like) inference
  • Stable Diffusion example (not necessarily highly optimized, but functional on CUDA + CPU)
    • Make it reasonably fast (CUDA at least :) )
  • More unit tests
  • More benchmarks points (M1, CPU, MKL, CUDA, WASM) x (Whisper, LLama2) at least
  • Optimize CPU backend
  • Better errors for older GPUs (#353)
  • More complete book (Training loop, adding Cuda debugging/profiling section, adding CPU tracing information at least)
  • VarBuilder reconsideration (We know it's not a great abstraction, we just don't have a good replacement idea yet)

Add community requirements

I don't think this project is at the point of needing a code of conduct, but setting up a contribution guide and a set of requirements for compilation error reports may alleviate a lot of back and forth for this project in the future. Open to suggestions, and if this is not a concern please feel free to close the issue.

Error while linking with accelerate on mac

I enabled accelerate features in my cargo.toml file

[dependencies]
candle = { path = "../candle/candle-core", version = "0.1.0", package = "candle-core", features=["accelerate"]}
candle-datasets = { path = "../candle/candle-datasets", version = "0.1.0" }
candle-nn = { path = "../candle/candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle/candle-transformers", version = "0.1.0" }
safetensors = "*"
serde = "*"
serde_json = "*"
num-traits = "*"
half = "*"
rand =  "*"
rand_chacha =  "*"

I am getting the following error

Undefined symbols for architecture arm64:
            "_dgemm_", referenced from:
                candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core-8c2363c344682bad.rlib(candle_core-8c2363c344682bad.3cylqiepw2bvor3t.rcgu.o)
            "_sgemm_", referenced from:
                candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core-8c2363c344682bad.rlib(candle_core-8c2363c344682bad.3cylqiepw2bvor3t.rcgu.o)
          ld: symbol(s) not found for architecture arm64

Am I doing something wrong?

Good Video Tutorials

It is possible that this will be the first framework for a lot of people who are entering the field. Would it be possible to create a video tutorial series, such as Andrej's, for newcomers? This would improve adoption by a huge margin.

Potentially incorrect outputs for q4_0 llama 7B

When running the ggml example for llama, using llama-2-7b.ggmlv3.q4_0.bin, I get the following output:

My favorite theorem is 100% of the time. nobody knows what it means. everybody knows it. nobody knows it. nobody knows it. It's a theorem.
I's a theorem.
I's a theorem
I's a theorem
I'm a theorem
I'm a theorem
I'm a theorem
I'm a theorem
I'm a theorem
I'm a theorem
I'm a theorem
I'm a theorem

This doesn't seem correct, so there might be some mismatches in the operations. I'm not using any temperature setting.

Error while building on arm64 due to candle-gemm-f16

I'm using candle-core and candle-nn as dependencies and I cannot build my project on an arm64 machine (version 0.1.1).

I have just added the final log part.

All the errors seem to point to the file candle-gemm-f16/src/microkernel.rs.

...
note: instantiated into assembly here
   --> <inline asm>:1:2
    |
1   |     fmul v4.8h, v8.8h, v3.8h
    |     ^

error: instruction requires: fullfp16
   --> /code/vendor/candle-gemm-f16/src/microkernel.rs:364:18
    |
364 |                 "fmul {0:v}.8h, {1:v}.8h, {2:v}.8h",
    |                  ^
    |
note: instantiated into assembly here
   --> <inline asm>:1:2
    |
1   |     fmul v2.8h, v8.8h, v1.8h
    |     ^

error: instruction requires: fullfp16
   --> /code/vendor/candle-gemm-f16/src/microkernel.rs:364:18
    |
364 |                 "fmul {0:v}.8h, {1:v}.8h, {2:v}.8h",
    |                  ^
    |
note: instantiated into assembly here
   --> <inline asm>:1:2
    |
1   |     fmul v1.8h, v8.8h, v0.8h
    |     ^

...

error: could not compile `candle-gemm-f16` (lib) due to 1141 previous errors
warning: build failed, waiting for other jobs to finish...

It seems there is no way to disable f16 support; please let me know if I'm wrong. That would be nice since I'm not using it, but I don't know if there is another way to compile the project for arm/arm64 devices.

Thank you for the help.

Can't run example - access rights problem

Execute: cargo run --example llama

Have error:

Running on CPU, to run on GPU, build this example with --features cuda
loading the model weights from meta-llama/Llama-2-7b-hf
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

I have an account at https://huggingface.co/ievnsk, created a token, created the token file C:\Users\igumn.cache\huggingface\token, and executed again. I get a new error:

loading the model weights from meta-llama/Llama-2-7b-hf
Error: I/O error A required privilege is not held by the client. (os error 1314)

Caused by:
A required privilege is not held by the client. (os error 1314)

Stack backtrace:
0: backtrace::backtrace::dbghelp::trace
at C:\Users\igumn.cargo\registry\src\index.crates.io-6f17d22bba15001f\backtrace-0.3.68\src\backtrace\dbghelp.rs:98
1: backtrace::backtrace::trace_unsynchronizedanyhow::backtrace::capture::impl$4::create::closure_env$0
at C:\Users\igumn.cargo\registry\src\index.crates.io-6f17d22bba15001f\backtrace-0.3.68\src\backtrace\mod.rs:66
2: backtrace::backtrace::traceanyhow::backtrace::capture::impl$4::create::closure_env$0
at C:\Users\igumn.cargo\registry\src\index.crates.io-6f17d22bba15001f\backtrace-0.3.68\src\backtrace\mod.rs:53
3: anyhow::backtrace::capture::Backtrace::create
at C:\Users\igumn.cargo\registry\src\index.crates.io-6f17d22bba15001f\anyhow-1.0.72\src\backtrace.rs:216
4: anyhow::backtrace::capture::Backtrace::capture
at C:\Users\igumn.cargo\registry\src\index.crates.io-6f17d22bba15001f\anyhow-1.0.72\src\backtrace.rs:204
5: anyhow::error::impl$1::from<enum2$<hf_hub::api::sync::ApiError> >
at C:\Users\igumn.cargo\registry\src\index.crates.io-6f17d22bba15001f\anyhow-1.0.72\src\error.rs:547
6: core::result::impl$27::from_residual<tuple$<>,enum2$<hf_hub::api::sync::ApiError>,anyhow::Error>
at /rustc/eb26296b556cef10fb713a38f3d16b9886080f26\library\core\src\result.rs:1961
7: llama::main
at .\candle-examples\examples\llama\main.rs:168

Error building candle-kernels

Hi,
Exciting project!
I'm having some issues building the candle-kernels on Windows 10 with CUDA 11.7 when trying to run the examples. Any thoughts?

   Compiling candle-kernels v0.1.0 (D:\candle\candle-kernels)
error: failed to run custom build command for `candle-kernels v0.1.0 (D:\candle\candle-kernels)`
note: To improve backtraces for build dependencies, set the CARGO_PROFILE_DEV_BUILD_OVERRIDE_DEBUG=true environment variable to enable debug information generation.

Caused by:
  process didn't exit successfully: `D:\candle\target\debug\build\candle-kernels-68d6aa5feaf84d2d\build-script-build` (exit code: 101)
  --- stdout
  cargo:rerun-if-changed=build.rs
  cargo:rustc-env=CUDA_INCLUDE_DIR=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\include
  cargo:rerun-if-changed=src/
  cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP
  cargo:rustc-env=CUDA_COMPUTE_CAP=sm_86
  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__nv_bfloat16]"
  (54): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__nv_bfloat16]"
  (54): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__half]"
  (68): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__half]"
  (68): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=float]"
  (92): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=float]"
  (92): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=double]"
  (93): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=double]"
  (93): here

  8 errors detected in the compilation of "src/unary.cu".
  unary.cu

  --- stderr
  thread 'main' panicked at 'nvcc error while compiling "src\\unary.cu":

  # stdout


  ', candle-kernels\build.rs:207:13
  stack backtrace:
     0: std::panicking::begin_panic_handler
               at /rustc/a2b1646c597329d0a25efa3889b66650f65de1de/library\std\src\panicking.rs:578
     1: core::panicking::panic_fmt
               at /rustc/a2b1646c597329d0a25efa3889b66650f65de1de/library\core\src\panicking.rs:67
     2: build_script_build::cuda::build_ptx
     3: <[T] as core::fmt::Debug>::fmt
     4: core::ops::function::FnOnce::call_once
  note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
PS D:\candle> cargo run --example whisper -- --input samples_jfk.wav
   Compiling candle-kernels v0.1.0 (D:\candle\candle-kernels)
error: failed to run custom build command for `candle-kernels v0.1.0 (D:\candle\candle-kernels)`
note: To improve backtraces for build dependencies, set the CARGO_PROFILE_DEV_BUILD_OVERRIDE_DEBUG=true environment variable to enable debug information generation.

Caused by:
  process didn't exit successfully: `D:\candle\target\debug\build\candle-kernels-68d6aa5feaf84d2d\build-script-build` (exit code: 101)
  --- stdout
  cargo:rerun-if-changed=build.rs
  cargo:rustc-env=CUDA_INCLUDE_DIR=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\include
  cargo:rerun-if-changed=src/
  cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP
  cargo:rustc-env=CUDA_COMPUTE_CAP=sm_86
  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__nv_bfloat16]"
  (54): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__nv_bfloat16]"
  (54): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__half]"
  (68): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=__half]"
  (68): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=float]"
  (92): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=float]"
  (92): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_2_SQRTPI" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=double]"
  (93): here

  D:\candle\candle-kernels\src\unary.cu(34): error: identifier "M_SQRT1_2" is undefined
            detected during instantiation of "T gelu_fwd(T) [with T=double]"
  (93): here

  8 errors detected in the compilation of "src/unary.cu".
  unary.cu

  --- stderr
  thread 'main' panicked at 'nvcc error while compiling "src\\unary.cu":

  # stdout


  stack backtrace:
     0: std::panicking::begin_panic_handler
               at /rustc/a2b1646c597329d0a25efa3889b66650f65de1de/library\std\src\panicking.rs:578
     1: core::panicking::panic_fmt
               at /rustc/a2b1646c597329d0a25efa3889b66650f65de1de/library\core\src\panicking.rs:67
     2: build_script_build::cuda::build_ptx
     3: <[T] as core::fmt::Debug>::fmt
     4: core::ops::function::FnOnce::call_once
  note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
PS D:\candle> $env:CARGO_PROFILE_DEV_BUILD_OVERRIDE_DEBUG=true      
PS D:\candle> cargo run --example whisper -- --input samples_jfk.wav
   Compiling candle-kernels v0.1.0 (D:\candle\candle-kernels)
error: failed to run custom build command for `candle-kernels v0.1.0 (D:\candle\candle-kernels)`
note: To improve backtraces for build dependencies, set the CARGO_PROFILE_DEV_BUILD_OVERRIDE_DEBUG=true environment variable to enable debug information generation.

Caused by:
  process didn't exit successfully: `D:\candle\target\debug\build\candle-kernels-68d6aa5feaf84d2d\build-script-build` (exit code: 101)
  --- stdout
  [identical to the first run: the same eight "M_2_SQRTPI" / "M_SQRT1_2" errors from src/unary.cu]

  --- stderr
  thread 'main' panicked at 'nvcc error while compiling "src\\unary.cu":

  # stdout


  # stderr
  ', candle-kernels\build.rs:207:13
  stack backtrace:
     0: std::panicking::begin_panic_handler
               at /rustc/a2b1646c597329d0a25efa3889b66650f65de1de/library\std\src\panicking.rs:578
     1: core::panicking::panic_fmt
               at /rustc/a2b1646c597329d0a25efa3889b66650f65de1de/library\core\src\panicking.rs:67
     2: build_script_build::cuda::build_ptx
     3: <[T] as core::fmt::Debug>::fmt
     4: core::ops::function::FnOnce::call_once
  note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.

Quick Tensor creation functions

Converting a Tensor to a vector is pretty straightforward. Is there an easy way to take a vector and convert it into a Tensor?
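
For reference, a minimal sketch of going the other way with the existing constructors (the variable names are just illustrative):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Build a (2, 3) tensor from a flat Vec<f32> plus an explicit shape.
    let values: Vec<f32> = vec![1., 2., 3., 4., 5., 6.];
    let t = Tensor::from_vec(values, (2, 3), &device)?;
    // Tensor::new also accepts slices and nested arrays directly.
    let u = Tensor::new(&[1u32, 2, 3], &device)?;
    println!("{t}\n{u}");
    Ok(())
}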

Generating the same sentence embedding for `all-MiniLM-L6-v2` using `candle`

// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
println!("pooled embeddings {:?}", embeddings.shape());

The BERT sentence embedding example uses a pooling strategy that produces a different sentence embedding than the HuggingFace API or other ways of running the model locally.

I would be interested in getting the same result, and I suspect the difference lies in the pooling strategy that should be used.

Any pointers would be helpful.

Thanks!
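
For what it's worth, a minimal sketch of the mask-aware mean pooling used by sentence-transformers, which averages only over non-padding tokens; the `embeddings` and `attention_mask` names are assumptions about the surrounding example code:

use candle_core::{DType, Result, Tensor};

// embeddings: (n_sentences, n_tokens, hidden_size)
// attention_mask: (n_sentences, n_tokens), 1 for real tokens, 0 for padding
fn mean_pool(embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?; // (n, t, 1)
    let summed = embeddings.broadcast_mul(&mask)?.sum(1)?;         // (n, h)
    let counts = mask.sum(1)?;                                     // (n, 1)
    summed.broadcast_div(&counts)                                  // (n, h)
}

Dividing by the number of non-padding tokens rather than by n_tokens is usually what makes the result line up with the HuggingFace pipeline.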

llama convert_checkpoint error

all_dicts = {k: v.numpy() for k, v in all_dicts.items()}

Traceback (most recent call last):                                                                                                                                                  20:17:17
  File "/home/zhengwu/workspace/Github/candle/candle-examples/examples/llama/convert_checkpoint.py", line 199, in <module>
    main()
  File "/home/zhengwu/workspace/Github/candle/candle-examples/examples/llama/convert_checkpoint.py", line 191, in main
    write_model(
  File "/home/zhengwu/workspace/Github/candle/candle-examples/examples/llama/convert_checkpoint.py", line 173, in write_model
    all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
  File "/home/zhengwu/workspace/Github/candle/candle-examples/examples/llama/convert_checkpoint.py", line 173, in <dictcomp>
    all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
TypeError: Got unsupported ScalarType BFloat16

NumPy does not support bfloat16.

If we simply convert everything to float32, float16 checkpoints are no longer handled as such, so the conversion has to be conditional:

all_dicts = {k: v.numpy() if v.dtype != torch.bfloat16 else v.to(torch.float32).numpy() for k, v in all_dicts.items()}

Maybe there is a more elegant way to handle this.

Add ONNX support

Any plans to support ONNX? An ONNX converter would be very helpful, but implementation details may vary.

I know that burn tries to generate Rust code from ONNX and then include it as a module. Codegen provides some performance benefits.

On the other hand, it is possible to create a model from ONNX at runtime, similar to tract.

Converting to ONNX also requires some effort, as branches and loops can introduce errors.

Compiling candle-examples v0.1.0 error

warning: some crates are on edition 2021 which defaults to resolver = "2", but virtual workspaces default to resolver = "1"
note: to keep the current resolver, specify workspace.resolver = "1" in the workspace root's manifest
note: to use the edition 2021 resolver, specify workspace.resolver = "2" in the workspace root's manifest
Compiling candle-examples v0.1.0 (~/github.com/huggingface/candle/candle-examples)
error[E0308]: mismatched types
   --> candle-examples/examples/whisper/main.rs:174:21
    |
174 |             .decode(tokens.clone(), true)
    |              ------ ^^^^^^^^^^^^^^ expected `&[u32]`, found `Vec<u32>`
    |              |
    |              arguments to this method are incorrect
    |
    = note: expected reference `&[u32]`
                  found struct `Vec<u32>`
note: method defined here
   --> ~/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokenizers-0.13.4/src/tokenizer/mod.rs:814:12
    |
814 |     pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result {
    |            ^^^^^^
help: consider borrowing here
    |
174 |             .decode(&tokens.clone(), true)
    |                     +

For more information about this error, try `rustc --explain E0308`.
error: could not compile `candle-examples` (example "whisper") due to previous error
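
Following the compiler's hint, a borrow is enough here and the clone is not needed (a minimal sketch, assuming `tokens` is a `Vec<u32>`):

// &Vec<u32> coerces to &[u32], so passing a reference is sufficient.
let decoded = tokenizer.decode(&tokens, true);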

Why is `candle_nn` not re-exported by `candle_core`?

It took me a while to figure out that I needed to add candle_nn as a separate dependency to get access to types such as VarBuilder. It was additionally confusing that the examples use crates under different names, such as candle instead of candle_core: https://github.com/huggingface/candle/blob/main/candle-examples/examples/mnist-training/main.rs#L7

Is the intent to have candle_nn as its own crate or is this an oversight? Please share your insight.
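
For reference, a minimal sketch of the imports involved once both candle-core and candle-nn are declared as dependencies (VarBuilder and VarMap live in candle-nn):

use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Fetch (and, for a VarMap-backed builder, create) a named variable.
    let w = vb.get((2, 3), "w")?;
    println!("{w}");
    Ok(())
}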

Cannot run llama example : access to source requires login credentials

cargo run --example llama --release
warning: some crates are on edition 2021 which defaults to resolver = "2", but virtual workspaces default to resolver = "1"
note: to keep the current resolver, specify workspace.resolver = "1" in the workspace root's manifest
note: to use the edition 2021 resolver, specify workspace.resolver = "2" in the workspace root's manifest
Finished release [optimized] target(s) in 0.17s
Running target/release/examples/llama
Running on CPU, to run on GPU, build this example with --features cuda
loading the model weights from meta-llama/Llama-2-7b-hf
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

Caused by:
https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

Using your own fine tuned weights

How do you provide a path to your own fine-tuned weights? It's clear how to do this for llama2, but not for the other models.

Equivalent to pytorch tril / masked_fill

How would you implement the following PyTorch code in candle?

wei = torch.ones(T,T)
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x
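
A minimal sketch of one way to express this in candle, under the assumption that x has shape (T, C): build the lower-triangular mask on the host, use where_cond as a stand-in for masked_fill, and take the softmax from candle_nn (the causal_weights helper is just illustrative):

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::ops::softmax;

fn causal_weights(x: &Tensor, t: usize, device: &Device) -> Result<Tensor> {
    // Lower-triangular (T, T) mask: 1 on and below the diagonal, 0 above it.
    let mask: Vec<u8> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
        .collect();
    let mask = Tensor::from_vec(mask, (t, t), device)?;
    // masked_fill(tril == 0, -inf): keep the ones where the mask is set, -inf elsewhere.
    let ones = Tensor::ones((t, t), DType::F32, device)?;
    let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?.broadcast_as((t, t))?;
    let wei = mask.where_cond(&ones, &neg_inf)?;
    // Softmax over the last dimension, then wei @ x.
    let wei = softmax(&wei, 1)?;
    wei.matmul(x)
}

The -inf entries get zero weight after the softmax, which matches what masked_fill with float('-inf') does in the PyTorch snippet.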

Serialize/Deserialize Varmap in memory

Is there any way to serialize a VarMap instance into a u8 vector instead of writing it to a file? If not, could it be added?

I'm using candle 0.1.0 and I'm unable to accomplish that.

Thank you in advance.

Contributing

Cool project, I would love to contribute in my free time. What needs to be done at this point?

Tensor::avg_pool2d not working as expected

I'm an ML beginner and a Rust beginner, and I don't know whether there's something wrong with my usage or my understanding, but the avg_pool2d function doesn't seem to work as expected: for the 4x4 input below, a 2x2 average pool with stride 2 should give [[0.5, 1.0], [1.0, 1.0]] (only the top-left window contains zeros), yet every element of the result comes out as 0.5.

main.rs

use candle_core::{Device, Tensor};

fn main() {
    let device = Device::Cpu;

    // 4x4 input whose only zeros sit in the top-left 2x2 window.
    let data: Vec<f32> = vec![1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
    let t = Tensor::from_vec(data, (1, 1, 4, 4), &device).unwrap();

    // 2x2 pooling window with a (2, 2) stride.
    let pool = t.avg_pool2d((2, 2), (2, 2)).unwrap();

    println!("{t}");
    println!("{pool}");
}

output

scale:0.25
sum:2
sum:2
sum:2
sum:2
[[[[1., 1., 1., 1.],
   [0., 0., 1., 1.],
   [1., 1., 1., 1.],
   [1., 1., 1., 1.]]]]
Tensor[[1, 1, 4, 4], f32]

[[[[0.5000, 0.5000],
   [0.5000, 0.5000]]]]
Tensor[[1, 1, 2, 2], f32]

License

Hi, is the License file missing?

Support for broadcast matmul

If I understand correctly, for now matmul needs batch dimensions to be exactly the same. Taken from candle_core::tensor doc:
b1, b2, ..., bi, m, k x b1, b2, ..., bi, k, n -> b1, b2, ..., bi, m, n
This is batch matrix multiplication (like torch.bmm).

It would be great to support broadcast, as in torch.matmul: "For example, if input is a (j×1×n×n) tensor and other is a (k×n×n) tensor, out will be a (j×k×n×n) tensor."
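
Until broadcasting matmul is supported directly, a minimal sketch of a manual workaround for the torch.matmul example above, with j, k, n as known dimensions (this materializes the broadcast copies, so it trades memory for the missing feature):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (j, k, n) = (2, 3, 4);
    let a = Tensor::randn(0f32, 1., (j, 1, n, n), &device)?; // (j, 1, n, n)
    let b = Tensor::randn(0f32, 1., (k, n, n), &device)?;    // (k, n, n)

    // Broadcast both operands to a common (j, k, n, n) batch shape by hand,
    // then fall back to the existing exact-shape batched matmul.
    let a = a.broadcast_as((j, k, n, n))?.contiguous()?;
    let b = b.unsqueeze(0)?.broadcast_as((j, k, n, n))?.contiguous()?;
    let out = a.matmul(&b)?; // (j, k, n, n)
    println!("{:?}", out.shape());
    Ok(())
}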

What is the minimum required Intel MKL version?

Hello, thanks for the great work!

I've got an error while compiling with the `--features mkl` (`-F mkl`) option, for example:
cargo install --git https://github.com/huggingface/candle.git candle-examples --examples bert -F mkl

The error says:

  = note: /usr/bin/ld: /workspaces/Kuberian/searcher/target/debug/deps/libcandle_core-0afc8671b4dae8af.rlib(candle_core-0afc8671b4dae8af.candle_core.b11884625c01537d-cgu.13.rcgu.o): in function `candle_core::mkl::hgemm':
          /usr/local/cargo/git/checkouts/candle-0c2b4fa9e5801351/60cd155/candle-core/src/mkl.rs:162: undefined reference to `hgemm_'
          collect2: error: ld returned 1 exit status
          
  = note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
  = note: use the `-l` flag to specify native libraries to link
  = note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)

I initially thought that I had not installed the Intel MKL libraries properly, but I found that:

  1. intel-mkl-src automatically downloads the required library from ghcr
  2. Intel MKL 2020.01, which is automatically downloaded from there, simply does not implement hgemm, while it does implement sgemm and dgemm
  3. the latest version of Intel MKL does implement hgemm

So I tried the latest version of Intel MKL, but it seems intel-mkl-src does not support it.

I'm wondering which Intel MKL version you use in your development environment?

Trainable batch normalization

I am trying to translate some code I wrote with tch-rs into candle as an experiment to see what the library is like.
It looks like I stumbled into a roadblock almost immediately. I have a convolutional neural network made up of many residual blocks, and each residual block internally uses batch normalization.

In tch-rs, I could use nn::batch_norm_2d. Is batch normalization not implemented in candle yet?

AMD hardware support for training and Inference

Hi,

This library is cool. Rust for deep learning is nice, and great work from Hugging Face. I am curious whether there are plans for AMD hardware support for training and inference.

Thanks

WebGPU support

Is WebGPU support on the roadmap as an alternative GPU-accelerated backend? This would be especially useful for inference on the web or for non-CUDA environments.

Equivalent of torch.add function

I am trying to write code similar to the following:

import torch
a = torch.ones((4, 8, 32))
b = torch.ones((8,32))
print(a+b)

For the following code:

use candle::*;
fn main() -> Result<()> {
    let a = Tensor::randn(0f32, 1., (4,8, 32), &Device::Cpu)?;
    let b = Tensor::randn(0f32, 1., (8, 32), &Device::Cpu)?;

    let c = a.add(&b)?;
    println!("{c}");
    Ok(())
}

I am getting this error. Am I missing something?

ShapeMismatchBinaryOp { lhs: [4, 8, 32], rhs: [8, 32], op: "add" }
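
A minimal sketch of the same computation with explicit broadcasting; candle keeps plain add shape-strict and exposes PyTorch-style broadcasting through the broadcast_* variants:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let a = Tensor::randn(0f32, 1., (4, 8, 32), &device)?;
    let b = Tensor::randn(0f32, 1., (8, 32), &device)?;
    // (4, 8, 32) + (8, 32) -> (4, 8, 32), matching a + b in PyTorch.
    let c = a.broadcast_add(&b)?;
    println!("{:?}", c.shape());
    Ok(())
}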

Link website to github project

It took me a while to notice that this repository has a book 😓. Would it be possible to link the website from the project's GitHub sidebar?

Support for signed integer

It would be useful to have a signed integer DType.
candle_core::dtype already supports several float types and unsigned integers, but there is no signed integer option.
I suggest we add i32.

Support for quantisation

rustformers/llm supports Q2 through Q8 quantization in various varieties. Would it be possible to quantize the existing models and run them in this repo?

request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json

$ cargo run --example llama --release
Finished release [optimized] target(s) in 0.09s
Running target/release/examples/llama
Running on CPU, to run on GPU, build this example with --features cuda
loading the model weights from meta-llama/Llama-2-7b-hf
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

Caused by:
https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

Stack backtrace:
0: llama::main
1: std::sys_common::backtrace::__rust_begin_short_backtrace
2: std::rt::lang_start::{{closure}}
3: core::ops::function::impls::<impl core::ops::function::FnOnce for &F>::call_once
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs:284:13
std::panicking::try::do_call
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:500:40
std::panicking::try
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:464:19
std::panic::catch_unwind
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs:142:14
std::rt::lang_start_internal::{{closure}}
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs:148:48
std::panicking::try::do_call
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:500:40
std::panicking::try
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs:464:19
std::panic::catch_unwind
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs:142:14
std::rt::lang_start_internal
at /rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs:148:20
4: main
5: __libc_start_call_main
at ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
6: __libc_start_main_impl
at ./csu/../csu/libc-start.c:392:3
7: _start
