grain's Introduction

Grain - Feeding JAX Models

Grain is a library for reading data for training and evaluating JAX models. It's open source, fast and deterministic.

  • Installation: pip install grain
  • Docs
  • Grain is used by MaxText, a simple, performant and scalable JAX codebase for LLMs.
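
A minimal end-to-end usage sketch, assembled from the examples in the issues below (the data source here is a toy placeholder):

import numpy as np
import grain.python as grain

# Any object with __len__ and __getitem__ can serve as a random-access data source.
class ToySource:
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return np.full((3,), idx, dtype=np.float32)

source = ToySource()
sampler = grain.IndexSampler(
    num_records=len(source),
    num_epochs=1,
    shard_options=grain.NoSharding(),
    shuffle=True,
    seed=0,
)
loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[grain.Batch(batch_size=10, drop_remainder=True)],
    worker_count=0,
)

for batch in loader:
    ...  # batch is a NumPy array of shape (10, 3)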

grain's People

Contributors

aayooush, changm, charlesbeattie, claudiofantacci, conchylicultor, fabianp, gauravmishra, hamzamerzic, iindyk, jacek1727, jimlinntu, jpuigcerver, marvin182, nathanielmanistaatgoogle, peterjliu, ppwwyyxx, protoget, qwlouse, rchen152, texasmichelle, untom, yangustc07

grain's Issues

AttributeError: 'NoneType' object has no attribute 'mmap'

This seems to happen at shutdown in any data pipeline that has NumPy arrays. Here is the full stacktrace:

INFO:absl:Process 0 exiting.
INFO:absl:Processing complete for process with worker_index 0
INFO:absl:Grain pool is exiting.
INFO:absl:Shutting down multiprocessing system.
INFO:absl:Shutting down multiprocessing system.
Exception ignored in: <function SharedMemoryArray.__del__ at 0x7e3b780a8a60>
Traceback (most recent call last):
  File "/home/black/micromamba/envs/trainpi/lib/python3.10/site-packages/grain/_src/python/shared_memory_array.py", line 139, in __del__
AttributeError: 'NoneType' object has no attribute 'mmap'
/home/black/micromamba/envs/trainpi/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Even if it's not an actual problem, it's a bit annoying because it overwhelms the logging output when you have many workers.

Here's the simplest possible repro:

import grain.python as grain
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    class DataSource:
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return np.zeros(1)

    source = DataSource()
    sampler = grain.IndexSampler(
        num_records=len(source),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False
    )
    loader = grain.DataLoader(
        data_source=source,
        sampler=sampler,
        worker_count=1,
    )

    for batch in loader:
        pass

Advice on using a JIT function inside a transform?

I want to run JIT-compiled, batched JAX data augmentation inside my Grain data loader. For now I'm standing in for that augmentation with jitted batch inference of a Flax model. With worker_count=0 it smoothly processes about 390-400 batches per second, but with worker_count=1 it becomes sporadic and slower. I suppose worker_count=0 is acceptable, and I can use it to feed a model for training. However, it would be useful to have a spare batch ready with worker_count=1 and worker_buffer_size=2, assuming my GPU has enough memory for two instances of the jitted function to run in parallel. In this case it does, and I still see the issue even when I make the Flax model much smaller. What is your advice?

from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random

import flax.linen as nn

from tqdm import tqdm
from absl import logging

import grain.python as grain


class Model(nn.Module):

    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):

        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)

        return x


Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})

B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20

dummy_input = jnp.zeros(shape=(B, IN_FEATURES))

model = Model(n_layers=N_LAYERS, features=FEATURES)

params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']

print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))


@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)


class DataSimpleSource(grain.RandomAccessDataSource):

    def __init__(self, num_steps):

        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))


class JITBatchTransform(grain.MapTransform):

    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)

        x = jit_batch_inference(batch)
        return x


if __name__ == '__main__':

    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0  # todo:
    worker_buffer_size = 1  # todo:

    datasource = DataSimpleSource(num_steps=num_steps)

    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )

    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]

    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # todo:
    )

    for x in tqdm(batched_dataloader, total=num_steps, desc='Grain Dataset'):
        pass

Allow older versions of Orbax

Currently, if an older version of Orbax is installed, grain throws an error on import:

[...]
     81 try:
     82   # Register the handler to be used with the new checkpointing API if Orbax is
     83   # present.
     84   import orbax.checkpoint as ocp  # pylint:disable=g-import-not-at-top # pytype:disable=import-error
---> 86   @ocp.args.register_with_handler(PyGrainCheckpointHandler, for_save=True)  # pytype:disable=wrong-arg-types
     87   @dataclasses.dataclass
     88   class PyGrainCheckpointSave(ocp.args.CheckpointArgs):
     89     item: Any
     91   @ocp.args.register_with_handler(PyGrainCheckpointHandler, for_restore=True)  # pytype:disable=wrong-arg-types
     92   @dataclasses.dataclass
     93   class PyGrainCheckpointRestore(ocp.args.CheckpointArgs):

AttributeError: module 'orbax.checkpoint' has no attribute 'args'

This is because the code relies on the latest Orbax API (ocp.args).

There's no need to support the old version, but could you also catch AttributeError in the try block, so that if an old version is installed the handlers are simply not registered? This would help those of us who want to use Grain for reasons other than checkpointing the data loader.
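
For concreteness, a sketch of the requested change, assuming the registration block quoted above currently sits in a try that only catches ImportError (this mirrors the quoted code, not grain's actual source):

try:
    # Register the handlers only when a sufficiently new Orbax is present.
    import orbax.checkpoint as ocp

    @ocp.args.register_with_handler(PyGrainCheckpointHandler, for_save=True)
    @dataclasses.dataclass
    class PyGrainCheckpointSave(ocp.args.CheckpointArgs):
        item: Any

    @ocp.args.register_with_handler(PyGrainCheckpointHandler, for_restore=True)
    @dataclasses.dataclass
    class PyGrainCheckpointRestore(ocp.args.CheckpointArgs):
        item: Any
except (ImportError, AttributeError):
    # Old Orbax (no ocp.args) or no Orbax at all: skip checkpoint handler registration.
    pass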

ImportError: cannot import name 'index_shuffle_module'

File "C:\Users\moriyantez\PycharmProjects\tf210\lib\site-packages\grain\_src\python\lazy_dataset\transformations\shuffle.py", line 18, in <module>
from grain._src.python.experimental.index_shuffle.python import index_shuffle_module as index_shuffle
ImportError: cannot import name 'index_shuffle_module' from 'grain._src.python.experimental.index_shuffle.python' (unknown location)

Minor bug in `grain.python.RangeDataSource.__getitem__()`

There may be a bug in grain.python.RangeDataSource.__getitem__().

Minimal code to reproduce:

import grain.python as pygrain 
x = pygrain.RangeDataSource(start=1,stop=10,step=2)
print(list(x))

Outputs:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
[<ipython-input-5-88dd2177f112>](https://localhost:8080/#) in <cell line: 3>()
      1 import grain.python as pygrain
      2 x = pygrain.RangeDataSource(start=1,stop=10,step=2)
----> 3 print(list(x))

[/usr/local/lib/python3.10/dist-packages/grain/_src/python/data_sources.py](https://localhost:8080/#) in __getitem__(self, record_key)
     92   def __getitem__(self, record_key: SupportsIndex) -> int:
     93     record_key = record_key.__index__()
---> 94     assert record_key >= 0 and record_key < self._len
     95     return self._start + record_key * self._step
     96 
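
A likely explanation: list() falls back to calling __getitem__ with increasing indices and relies on IndexError to stop, but the assert raises AssertionError first, so iteration crashes instead of terminating. A possible fix, sketched below (not the library's actual code), would make list(x) return [1, 3, 5, 7, 9]:

  def __getitem__(self, record_key: SupportsIndex) -> int:
    record_key = record_key.__index__()
    # Raise IndexError so iteration via __getitem__ terminates cleanly.
    if not 0 <= record_key < self._len:
      raise IndexError(
          f"record_key {record_key} is out of range [0, {self._len})")
    return self._start + record_key * self._step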

Wheel marked as arch independent, while being arch dependent

I'm trying to use T5X on a Grace Hopper machine, which has an ARM-based CPU. T5X depends on grain-nightly and installs it from PyPI. pip install grain-nightly works on ARM, but the installed wheel fails at import because it tries to load an .so that was built for x86.

Can the wheel be marked as architecture dependent, so that it isn't found and installed on ARM?

Here is a PR that fixes the same issue in another project:
https://github.com/google/array_record/pull/79/files

I can't test it, as I'm not able to build this project on either x86 or ARM.

Batching into shared memory is deprecated, but essential for performance

I was doing some profiling of my data pipeline and found that the Batch transformation was a severe bottleneck. Here are the critical lines in operations.py:

def stacking_function(*args):
  first_arg = np.asanyarray(args[0])
  shape, dtype = (len(args),) + first_arg.shape, first_arg.dtype
  if not self._use_shared_memory or dtype.hasobject:
    return np.stack(args)
  return np.stack(args, out=SharedMemoryArray(shape, dtype=dtype)).metadata

I found that self._use_shared_memory == True iff you used the deprecated grain.BatchOperation, rather than the "recommended" grain.Batch. And what do you know, switching to grain.BatchOperation gave me a 3x increase in throughput! This matches up with my intuition, because in the self._use_shared_memory == True branch, there is only one copy that goes directly into shared memory. But in the self._use_shared_memory == False branch, the np.stack will induce one copy into private memory, and then the later CopyNumPyArrayToSharedMemory transform performs an explicit second copy into shared memory. It's not too surprising that adding another copy of all of the pipeline's data could slow things down significantly.
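
For concreteness, a minimal sketch of that swap in a standard DataLoader setup (the source and sampler here are placeholders):

import grain.python as grain

operations = [
    # grain.Batch(batch_size=32, drop_remainder=True),        # stacks into private memory;
    #                                                          # a second copy into shared memory follows
    grain.BatchOperation(batch_size=32, drop_remainder=True),  # deprecated, but stacks straight into shared memory
]

loader = grain.DataLoader(
    data_source=my_source,   # placeholder: any random-access data source
    sampler=my_sampler,      # placeholder: e.g. a grain.IndexSampler
    operations=operations,
    worker_count=4,
)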

Here comes the real problem -- I want to use grain through airio, which doesn't go through the standard DataLoader, but the much more complex lazy_dataset API. In lazy_dataset, batching is done through a different code path that does not have an option to enable this optimization. It always batches into private memory, and then the MultiprocessPrefetchLazyIterDataset does a second copy into shared memory.

I manually added a (slightly hacky) solution that enables batching directly into shared memory iff the batch operation is a parent of a MultiprocessPrefetchLazyIterDataset. Indeed, I saw a significant performance increase when using grain through airio. Is this something that could possibly be upstreamed into grain?

Slow data loading of large arrays from sharded dataset

Hi. First off, I'm unsure whether I should post this issue here, in the array_record repo, or in the tensorflow_datasets repo. My goal is to ultimately use Grain in my project because I really like the idea of deterministic data loading and easily checkpointing the state, shuffling, etc., and I'm obviously using JAX.

The problem is that I can't seem to load ArrayRecords quickly with Grain for my data. Using TFRecords with TFDS is a lot faster, which isn't what I'd expect. I suspect this might be because my dataset consists of large arrays.

Data

My dataset has around 50,000 samples, where each sample is a NumPy array of shape (100, 500, 99) with float32 dtype. Currently the dataset lives in 50,000 .npy files; I'm testing with a subset of 5,000 of them.

Conversion to ArrayRecord

...

# arbitrarily chose 50 arrays per ArrayRecord shard because I read online that ~1 GB is a reasonable shard size
num_arrays_shard = 50
filenames = np.array(list(DATA_DIR.iterdir()))  # .npy filenames 
num_shards = len(filenames) // num_arrays_shard  # 100 shards for my subset of the dataset
group_size = 1

features = tfds.features.FeaturesDict({
    "arr": tfds.features.Tensor(shape=(100,500,99), dtype=np.float32)
})

def _write_arrayrecord_shard(shard: int):
  writer = array_record.ArrayRecordWriter(
    f"{GRAIN_DATA_DIR}/data.array_record-{shard:05d}-of-{num_shards - 1:05d}",
    f"group_size:{group_size}"
  )
  for fname in filenames[shard * num_arrays_shard : shard * num_arrays_shard + num_arrays_shard]:
    _arr = np.load(fname).astype(np.float32)
    tf_example = features.serialize_example({"arr": _arr})
    writer.write(tf_example)
  writer.close()

_ = process_map(_write_arrayrecord_shard, range(num_shards), max_workers=multiprocessing.cpu_count())

Loading with grain

import dataclasses

import grain.python as grain

ds = grain.ArrayRecordDataSource([str(f) for f in GRAIN_DATA_DIR.iterdir()])

@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
  def map(self, _features):
    return features.deserialize_example_np(_features)

sampler = grain.SequentialSampler(num_records=len(filenames), shard_options=grain.NoSharding())
loader = grain.DataLoader(
  data_source=ds,
  operations=[ParseFeatures(), grain.Batch(5)],
  sampler=sampler,
  worker_buffer_size=1000
)

The problem

I benchmark the resulting loader with tfds.benchmark(loader, batch_size=5) and I'm getting 3 examples per second, which seems really slow. Manually looping through the DataLoader and timing it is not any better, so I don't think this is a bug with the benchmark.

Reading each individual NumPy file from the filesystem with numpy.load yields about 140 examples per second.

In an identical setup where I instead use tf.io.TFRecordWriter in the conversion step, load everything as a tf.data.Dataset, and benchmark it as follows:

ds = ds.batch(5, num_parallel_calls=5)
ds = ds.as_numpy_iterator()
tfds.benchmark(ds, num_iter=990, batch_size=5)

then I get roughly 130 samples per second, which isn't great but it's at least close to the naive solution of reading directly from the disk.

Without NumPy conversion / deserialisation it's faster, but not as fast as I'd expect: around 53 examples per second without the ParseFeatures() operation. Also, when I set worker_count= in the DataLoader I get an error, "Processing Failed. Shutting down.", though that is probably worth its own issue.

TLDR

I'm trying to load a few thousand big arrays (each float32, shape (100, 500, 99)) from ArrayRecord files with Grain, but it's slow: slower than TFRecords with tf.data, and slower than just loading the .npy files directly from disk.

Reproduction notebook here

Am I missing the point of Grain / is it just not a good fit for my use case? Or are some of my settings wrong (shard size / buffer size / serialisation strategy)?

I'm using grain_nightly==0.0.6 and array_record==0.5.0. I'm on a 1 TB NVMe SSD and have a Ryzen 9 7950X CPU with 64GB of DDR5 RAM on Linux.

Cannot pip install the library

Hi,

I'm trying to pip install grain but I get the following error:

ERROR: Could not find a version that satisfies the requirement grain (from versions: none)
ERROR: No matching distribution found for grain

Does someone have an idea why this is happening?

Thank you

shard_options in IndexSampler

If shard_options is specified in IndexSampler, isn't the dataset being sharded twice?

DataLoader shards the dataset if hasattr(self._sampler, "_shard_options"), but the sampler will shard it again with ShardLazyDataset(), since that hasn't been disabled:

self._record_keys = lazy_dataset.ShardLazyDataset(

local_offset = self._shard_options.shard_index - self._global_num_workers # pytype: disable=attribute-error
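
For concreteness, a sketch of the configuration this question is about (shard counts and the source are placeholders): the ShardOptions passed to the IndexSampler is the same object the DataLoader discovers through the sampler's _shard_options attribute.

import grain.python as grain

shard_options = grain.ShardOptions(shard_index=0, shard_count=8, drop_remainder=True)

sampler = grain.IndexSampler(
    num_records=1_000_000,
    num_epochs=1,
    shard_options=shard_options,  # the sampler shards the record keys itself ...
    shuffle=True,
    seed=0,
)

loader = grain.DataLoader(
    data_source=my_source,  # placeholder data source
    sampler=sampler,        # ... and DataLoader also sees sampler._shard_options
)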

Does Grain support streaming datasets?

Does Grain provide an iterable-style dataset, similar to torch.utils.data.IterableDataset? When the original index files are very large in total (e.g. 4 TB), they are difficult to load directly into memory.
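
For reference, a minimal sketch of the PyTorch pattern the question refers to, in which records are streamed lazily shard by shard instead of being indexed in memory (file names are placeholders):

from torch.utils.data import DataLoader, IterableDataset

class StreamingSource(IterableDataset):
    def __init__(self, shard_paths):
        self.shard_paths = shard_paths

    def __iter__(self):
        # Only one shard is open at a time, so memory use stays bounded.
        for path in self.shard_paths:
            with open(path) as f:
                for line in f:
                    yield line.rstrip("\n")

loader = DataLoader(StreamingSource(["shard-00000.txt", "shard-00001.txt"]), batch_size=32)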
