RecurrentGemma

RecurrentGemma is a family of open-weights language models from Google DeepMind, based on the novel Griffin architecture. Griffin achieves fast inference when generating long sequences by replacing global attention with a mixture of local attention and linear recurrences.
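To make the "local attention" half of that description concrete, here is a minimal plain-Python sketch of a causal sliding-window mask. It is not taken from this repository, and the window sizes used are arbitrary illustrative values, not RecurrentGemma's actual settings. Each query position may attend only to itself and the previous window - 1 positions, so the number of keys a sampler has to keep around is bounded by the window instead of growing with the sequence length.

```python
def local_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Return mask[q][k] == True when query position q may attend to key position k.

    Causal: k <= q. Local: q - k < window, so each query sees at most `window` keys.
    """
    return [[0 <= q - k < window for k in range(seq_len)] for q in range(seq_len)]


def max_keys_per_query(mask: list[list[bool]]) -> int:
    """Largest number of keys any single query attends to (a proxy for cache size)."""
    return max(sum(row) for row in mask)


if __name__ == "__main__":
    mask = local_causal_mask(seq_len=6, window=3)
    for row in mask:
        print("".join("x" if allowed else "." for allowed in row))
    # Global causal attention would let the last query see all 6 keys;
    # with a window of 3 no query ever sees more than 3.
    print(max_keys_per_query(mask))
```

Swapping the window for the full sequence length recovers ordinary global causal attention, which is what makes the cache grow linearly with the number of generated tokens.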

This repository contains the model implementation and examples for sampling and fine-tuning. We recommend that most users adopt the Flax implementation, which is highly optimized. We also provide an unoptimized PyTorch implementation for reference.

Learn more about RecurrentGemma

Quick start

Installation

Using Poetry

RecurrentGemma uses Poetry for dependency management.

To install dependencies for the full project:

  • Check out the code.
  • Run poetry install -E full to create a virtual environment with all dependencies.
  • Run poetry shell to activate the created virtual environment.

If you only need a subset of the dependencies, use one of the library-specific commands below instead.

Using pip

If you want to use pip instead of Poetry, first create and activate a virtual environment (run python -m venv recurrentgemma-demo, then . recurrentgemma-demo/bin/activate), and then:

  • Check out the code.
  • Run pip install .[full]
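The same pip-based setup can also be scripted with the standard-library venv module. This is a minimal sketch: the helper names create_env and install_extra are hypothetical, the interpreter location (bin/ on POSIX, Scripts\ on Windows) follows venv's usual layout, and the install step simply runs the pip install .[full] command from the steps above inside the new environment, from the root of the checked-out repository.

```python
import os
import subprocess
import venv
from pathlib import Path


def create_env(env_dir: str, with_pip: bool = True) -> Path:
    """Create a virtual environment (like `python -m venv <env_dir>`) and return its interpreter."""
    venv.create(env_dir, with_pip=with_pip, symlinks=(os.name != "nt"))
    # venv places the interpreter in Scripts\ on Windows and bin/ elsewhere.
    subdir = "Scripts" if os.name == "nt" else "bin"
    exe = "python.exe" if os.name == "nt" else "python"
    return Path(env_dir) / subdir / exe


def install_extra(python: Path, extra: str = "full", repo_root: str = ".") -> None:
    """Run `pip install .[<extra>]` with the environment's interpreter from the repo root."""
    subprocess.run(
        [str(python), "-m", "pip", "install", f".[{extra}]"],
        cwd=repo_root,
        check=True,  # raise CalledProcessError if pip fails
    )
```

For example, install_extra(create_env("recurrentgemma-demo"), "full") mirrors the manual steps; pass "jax", "torch", or "test" instead of "full" for the narrower extras.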

Installing library-specific packages

JAX

To install dependencies only for the JAX pathway, use poetry install -E jax (or pip install .[jax]).

PyTorch

To install dependencies only for the PyTorch pathway, use poetry install -E torch (or pip install .[torch]).

Tests

To install the dependencies required for running unit tests, use poetry install -E test (or pip install .[test]).
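After installing one of these extras, a quick way to confirm which pathway is usable is to check whether the relevant top-level packages can be found. The sketch below uses importlib.util.find_spec, which locates a module without importing it. The mapping from extra to packages (jax and flax for the JAX extra, torch for the PyTorch extra, pytest for the test extra) is an assumption for illustration; the project's pyproject.toml is the authoritative list.

```python
import importlib.util

# Assumed mapping from install extra to top-level modules it should make importable.
# Check the project's pyproject.toml for the actual dependency lists.
EXTRA_MODULES = {
    "jax": ["jax", "flax"],
    "torch": ["torch"],
    "test": ["pytest"],
}


def available_extras(extra_modules: dict[str, list[str]] = EXTRA_MODULES) -> dict[str, bool]:
    """Report, per extra, whether every module it needs can be found (without importing it)."""
    return {
        extra: all(importlib.util.find_spec(name) is not None for name in modules)
        for extra, modules in extra_modules.items()
    }


if __name__ == "__main__":
    for extra, ok in available_extras().items():
        print(f"{extra:6s} {'installed' if ok else 'missing'}")
```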

Downloading the models

The model checkpoints are available through Kaggle at http://kaggle.com/models/google/recurrentgemma. Select either the Flax or the PyTorch model variation, click the ⤓ button to download the model archive, then extract the contents to a local directory.

In both cases, the archive contains the model weights and the tokenizer.
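Before running anything, it can help to sanity-check the extracted archive. The helper below is a hypothetical sketch that assumes the layout implied by the sampling command in the Examples section: a tokenizer.model file at the top level of the extracted contents and one subdirectory per checkpoint variant (such as 2b). The real archive layout may differ between the Flax and PyTorch variations, so adjust the paths if it does.

```python
from pathlib import Path


def inspect_archive(root: str) -> tuple[Path, list[Path]]:
    """Return (tokenizer path, checkpoint directories) for an extracted model archive.

    Assumes tokenizer.model sits directly under `root` and that each checkpoint
    variant is a subdirectory of `root`. Raises FileNotFoundError otherwise.
    """
    root_path = Path(root)
    tokenizer = root_path / "tokenizer.model"
    if not tokenizer.is_file():
        raise FileNotFoundError(f"no tokenizer.model under {root_path}")
    checkpoints = sorted(p for p in root_path.iterdir() if p.is_dir())
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoint directories under {root_path}")
    return tokenizer, checkpoints
```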

Running the unit tests

To run the tests, install the optional [test] dependencies (e.g. using pip install .[test]) from the root of the source tree, then:

pytest .

Examples

To run the example sampling script, pass the paths to the weights directory and tokenizer:

python examples/sampling_jax.py \
  --path_checkpoint=/path/to/archive/contents/2b/ \
  --path_tokenizer=/path/to/archive/contents/tokenizer.model
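If you prefer to launch the sampling script from Python, the same invocation can be assembled programmatically. In this sketch the two flags come from the command above, while the helper names, the default variant "2b", and the archive root are illustrative assumptions. It should be run from the repository root so that the relative script path resolves.

```python
import shlex
import subprocess
import sys
from pathlib import Path


def sampling_command(archive_root: str, variant: str = "2b") -> list[str]:
    """Build argv for examples/sampling_jax.py using the flags shown in the example."""
    root = Path(archive_root)
    return [
        sys.executable,
        "examples/sampling_jax.py",
        f"--path_checkpoint={root / variant}/",
        f"--path_tokenizer={root / 'tokenizer.model'}",
    ]


def run_sampling(archive_root: str, variant: str = "2b") -> None:
    """Run the sampling script; raises CalledProcessError if it exits non-zero."""
    cmd = sampling_command(archive_root, variant)
    print("running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
```

Passing an argument list to subprocess.run, rather than a single shell string, avoids quoting problems when the archive path contains spaces.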

Colab notebook tutorials

To run these notebooks, you need a Kaggle account, and you must first read and accept the Gemma license terms and conditions on the RecurrentGemma page. After that, the notebooks automatically download the weights and tokenizer from Kaggle.

The notebooks are currently supported on the following hardware:

Hardware T4 P100 V100 A100 TPUv2 TPUv3+
Sampling in Jax
Sampling in PyTorch
Finetuning in Jax

System Requirements

RecurrentGemma code can run on CPU, GPU, or TPU. The Flax implementation is optimized for TPU: it contains a low-level Pallas kernel that performs the linear scan in the recurrent layers.
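The linear scan evaluates a recurrence of the form h_t = a_t * h_{t-1} + b_t. The pure-Python sketch below is not the Pallas kernel or any code from this repository; it only illustrates why such a scan can be parallelized. Each step is an affine map h -> a*h + b, and composing two such maps is associative, so prefixes can be combined in any grouping (here, a balanced tree) and still match the sequential loop. Scalars are used for clarity; in a model the recurrence is applied elementwise over channels.

```python
def sequential_scan(a: list[float], b: list[float], h0: float = 0.0) -> list[float]:
    """Compute h_t = a_t * h_{t-1} + b_t step by step, returning every h_t."""
    hs, h = [], h0
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        hs.append(h)
    return hs


def combine(left: tuple[float, float], right: tuple[float, float]) -> tuple[float, float]:
    """Compose affine maps h -> a*h + b: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def tree_reduce(pairs: list[tuple[float, float]]) -> tuple[float, float]:
    """Combine pairs by recursively splitting in halves (the grouping a parallel scan uses)."""
    if len(pairs) == 1:
        return pairs[0]
    mid = len(pairs) // 2
    return combine(tree_reduce(pairs[:mid]), tree_reduce(pairs[mid:]))


def scan_via_combine(a: list[float], b: list[float], h0: float = 0.0) -> list[float]:
    """Same result as sequential_scan, but each prefix is reduced as a tree."""
    pairs = list(zip(a, b))
    hs = []
    for t in range(len(pairs)):
        A, B = tree_reduce(pairs[: t + 1])
        hs.append(A * h0 + B)
    return hs
```

Reducing every prefix separately is quadratic and is only meant to show that the grouping does not change the result; a real parallel scan shares partial results across prefixes. JAX exposes this general pattern as jax.lax.associative_scan.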

Contributing

We are open to bug reports and issues. Please see CONTRIBUTING.md for details on PRs.

License

Copyright 2024 DeepMind Technologies Limited

This code is licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Disclaimer

This is not an official Google product.

Contributors

nush395, samsmithgdm