Hi. First off, I'm not sure whether I should post this issue here, in the array_record repo, or in the tensorflow_datasets repo. My goal is ultimately to use grain in my project because I really like the idea of deterministic data loading and easy checkpointing of the loader state, shuffling, etc., and I'm obviously using JAX.
The problem is that I can't seem to load ArrayRecords quickly with grain for my data. Using TFRecords with TFDS seems to be a lot faster, which isn't what I'd expect. I suspect this might be related to my dataset consisting of large arrays.
Data
My dataset has around 50000 samples, where each sample is a NumPy array of shape (100, 500, 99) with float32 dtype. The dataset currently lives in 50000 .npy files, and I'm testing with a subset of 5000 of them.
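For scale, each sample is roughly 20 MB uncompressed, so 50 samples per shard (see below) come out to about 1 GB. Quick back-of-the-envelope check:

# Rough size estimate (float32 = 4 bytes per element).
sample_bytes = 100 * 500 * 99 * 4        # ~19.8 MB per sample
shard_bytes = 50 * sample_bytes          # ~0.99 GB for 50 samples per shard
subset_bytes = 5000 * sample_bytes       # ~99 GB for the 5000-sample subset
print(sample_bytes / 1e6, shard_bytes / 1e9, subset_bytes / 1e9)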
Conversion to ArrayRecord
...
# arbitrarily chose 50 arrays per ArrayRecord shard, since I read online that ~1 GB is a reasonable shard size
num_arrays_shard = 50
filenames = np.array(list(DATA_DIR.iterdir())) # .npy filenames
num_shards = len(filenames) // num_arrays_shard # 100 shards for my subset of the dataset
group_size = 1
features = tfds.features.FeaturesDict({
    "arr": tfds.features.Tensor(shape=(100, 500, 99), dtype=np.float32)
})

def _write_arrayrecord_shard(shard: int):
    writer = array_record.ArrayRecordWriter(
        f"{GRAIN_DATA_DIR}/data.array_record-{shard:05d}-of-{num_shards - 1:05d}",
        f"group_size:{group_size}"
    )
    for fname in filenames[shard * num_arrays_shard : (shard + 1) * num_arrays_shard]:
        _arr = np.load(fname).astype(np.float32)
        tf_example = features.serialize_example({"arr": _arr})
        writer.write(tf_example)
    writer.close()
_ = process_map(_write_arrayrecord_shard, range(num_shards), max_workers=multiprocessing.cpu_count())
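To rule out bad shards, here is a quick spot-check that reads one record back and deserialises it. This is only a sketch and assumes array_record here refers to array_record.python.array_record_module (the same module that provides ArrayRecordWriter above), plus the variables from the conversion snippet:

# Read the first record of the first shard back and check its shape/dtype.
reader = array_record.ArrayRecordReader(
    f"{GRAIN_DATA_DIR}/data.array_record-00000-of-{num_shards - 1:05d}"
)
print(reader.num_records())  # expected: num_arrays_shard (50)
example = features.deserialize_example_np(reader.read_all()[0])
print(example["arr"].shape, example["arr"].dtype)  # expected: (100, 500, 99) float32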
Loading with grain
import dataclasses

import grain.python as grain

ds = grain.ArrayRecordDataSource([str(f) for f in GRAIN_DATA_DIR.iterdir()])

@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
    def map(self, _features):
        return features.deserialize_example_np(_features)

sampler = grain.SequentialSampler(num_records=len(filenames), shard_options=grain.NoSharding())

loader = grain.DataLoader(
    data_source=ds,
    operations=[ParseFeatures(), grain.Batch(5)],
    sampler=sampler,
    worker_buffer_size=1000
)
The problem
I benchmark the resulting loader with tfds.benchmark(loader, batch_size=5) and I'm getting about 3 examples per second, which seems really slow. Manually looping through the DataLoader and timing it isn't any better, so I don't think this is a bug in the benchmark.
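By manually looping I mean roughly the following (a minimal timing sketch, not the exact code from my notebook):

import time

# Rough throughput measurement by iterating the DataLoader directly.
start = time.perf_counter()
num_examples = 0
for batch in loader:
    num_examples += batch["arr"].shape[0]
    if num_examples >= 500:
        break
print(f"{num_examples / (time.perf_counter() - start):.1f} examples/s")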
Reading each individual NumPy file from the filesystem with numpy.load yields about 140 examples per second.
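That figure comes from a plain loop over the .npy files, roughly:

import time
import numpy as np

# Baseline: read each .npy file directly from disk.
start = time.perf_counter()
for fname in filenames[:500]:
    _ = np.load(fname)
print(f"{500 / (time.perf_counter() - start):.1f} examples/s")  # ~140 on my machine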
In an identical setup, where I use tf.io.TFRecordWriter in the data conversion step and load everything as a TF Dataset, I benchmark as follows:
ds = ds.batch(5, num_parallel_calls=5)
ds = ds.as_numpy_iterator()
tfds.benchmark(ds, num_iter=990, batch_size=5)
With that I get roughly 130 examples per second, which isn't great, but it's at least close to the naive solution of reading directly from disk.
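For completeness, the ds in that snippet is built roughly like this (a sketch; TFRECORD_DATA_DIR is a placeholder for wherever the TFRecord shards live, and deserialisation goes through the same tfds FeaturesDict as above):

import tensorflow as tf

# Sketch of the comparison pipeline (TFRECORD_DATA_DIR is hypothetical).
tfrecord_files = [str(f) for f in TFRECORD_DATA_DIR.iterdir()]
ds = tf.data.TFRecordDataset(tfrecord_files, num_parallel_reads=tf.data.AUTOTUNE)
ds = ds.map(features.deserialize_example, num_parallel_calls=tf.data.AUTOTUNE)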
Without conversion to NumPy / deserialisation it's faster, but still not as fast as I'd expect: I'm getting around 53 examples per second without the ParseFeatures() operation. I also tried setting worker_count in the DataLoader, but I get the error "Processing Failed. Shutting down.", though that is probably worth its own issue.
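Concretely, the failing variant is just the same DataLoader with worker_count set (the value here is only an example):

# Same setup as above, but with multiprocessing workers enabled; this is what fails for me.
loader = grain.DataLoader(
    data_source=ds,
    operations=[ParseFeatures(), grain.Batch(5)],
    sampler=sampler,
    worker_count=4,   # example value
    worker_buffer_size=1000,
)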
TLDR
I'm trying to load a few thousand big arrays (each float32, shape=(100, 500, 99)) from ArrayRecord files with Grain, but it's slow: slower than TFRecords with a TF Dataset, and slower than just loading the .npy files directly from disk.
Reproduction notebook here
Am I missing the point of Grain / is it just not a good fit for my use case? Or are some of my settings wrong (shard size / buffer size / serialisation strategy)?
I'm using grain_nightly==0.0.6 and array_record==0.5.0. I'm on a 1 TB NVMe SSD and have a Ryzen 9 7950X CPU with 64 GB of DDR5 RAM, running Linux.