Advice on using a JIT function inside a transform? #495
Here's another version, which uses from typing import SupportsIndex:
from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn
from flax.jax_utils import prefetch_to_device
from tqdm import tqdm
from absl import logging
import grain.python as grain


class Model(nn.Module):
    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):
        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)
        return x


# Vectorize the model over the batch axis while sharing one set of parameters.
Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})

B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20

dummy_input = jnp.zeros(shape=(B, IN_FEATURES))
model = Model(n_layers=N_LAYERS, features=FEATURES)
params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']
print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))


@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)


class DataSimpleSource(grain.RandomAccessDataSource):
    def __init__(self, num_steps):
        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))


class JITBatchTransform(grain.MapTransform):
    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)
        x = jit_batch_inference(batch)
        return x


if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0  # todo:
    worker_buffer_size = 1  # todo:
    prefetch_size = 2  # todo:

    datasource = DataSimpleSource(num_steps=num_steps)
    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )
    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]
    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # todo:
    )

    def prepare_for_prefetch(xs):
        # Add a leading device axis so each batch can be split across local devices.
        local_device_count = jax.local_device_count()

        def _prepare(x):
            return x.reshape((local_device_count, -1) + x.shape[1:])

        return jax.tree_util.tree_map(_prepare, xs)

    # Similar to flax.jax_utils.replicate.
    batched_dataloader = map(prepare_for_prefetch, batched_dataloader)

    if prefetch_size > 1:
        # For prefetch_to_device to work, prepare_for_prefetch must already have been applied.
        batched_dataloader = prefetch_to_device(batched_dataloader, size=prefetch_size)

    for x in tqdm(batched_dataloader, total=num_steps, desc='Grain Dataset'):
        pass
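For reference, a quick shape check of the prefetch staging, assuming a single local device; this snippet is not part of the script above and reuses B, FEATURES, jnp, and prepare_for_prefetch from it:

batch = jnp.zeros((B, FEATURES))   # (4, 20), the shape a batch has after JITBatchTransform
staged = prepare_for_prefetch(batch)
print(staged.shape)                # (1, 4, 20): (local_device_count, per-device batch, features)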
I want to put a JAX-jitted, batched data augmentation inside my grain dataloader. As a stand-in, I'm treating this augmentation as jitted batch inference with a Flax model. With worker_count=0, it smoothly processes about 390-400 batches per second. With worker_count=1, however, throughput becomes more sporadic and slower.

I suppose worker_count=0 is acceptable, and I can use it to feed a model for training. Still, it would be useful to have a spare batch ready with worker_count=1 and worker_buffer_size=2, assuming my GPU has enough memory for two instances of the jitted function to run in parallel. In this case it does, yet I still see the issue even when I make the Flax model much smaller. What is your advice?
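For comparison, here is a minimal sketch that keeps all JAX work in the main process, under the assumption (not verified here) that the sporadic behaviour with worker_count=1 comes from the worker process doing JAX work: in the script above both the jax.random data source and the jitted transform run inside the worker. NumpySimpleSource is a hypothetical numpy-only variant of DataSimpleSource; index_sampler, B, IN_FEATURES, num_steps, jnp, and jit_batch_inference are reused from the script above.

import numpy as np

class NumpySimpleSource(grain.RandomAccessDataSource):
    # Hypothetical numpy-only data source, so the worker process never touches JAX.
    def __init__(self, num_steps):
        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        rng = np.random.default_rng(int(record_key))
        return rng.uniform(size=(IN_FEATURES,)).astype(np.float32)

raw_dataloader = grain.DataLoader(
    data_source=NumpySimpleSource(num_steps=num_steps),
    sampler=index_sampler,
    operations=[grain.Batch(batch_size=B, drop_remainder=True)],  # batching only; no jitted call here
    worker_count=1,
    worker_buffer_size=2,
)

for batch in raw_dataloader:
    x = jit_batch_inference(jnp.asarray(batch))  # the jitted call stays in the main process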