Advice on using a JIT function inside a transform? #495

Open

DBraun opened this issue Jul 10, 2024 · 1 comment

DBraun commented Jul 10, 2024

I want to put a JAX-jitted, batched data augmentation inside my Grain DataLoader. For now I'm pretending the augmentation is jitted batch inference of a Flax model. With worker_count=0 it smoothly processes about 390-400 batches per second, but with worker_count=1 it becomes more sporadic and slower. I suppose worker_count=0 is acceptable, and I can use it to feed a model for training. However, it might be useful to have a spare batch ready with worker_count=1 and worker_buffer_size=2, assuming my GPU has the memory for two invocations of the jitted function to run in parallel. In this case it does, and I still see issues even when I make the Flax model much smaller. What is your advice?

from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random

import flax.linen as nn

from tqdm import tqdm
from absl import logging

import grain.python as grain


class Model(nn.Module):

    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):

        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)

        return x


Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})

B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20

dummy_input = jnp.zeros(shape=(B, IN_FEATURES))

model = Model(n_layers=N_LAYERS, features=FEATURES)

params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']

print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))


@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)


class DataSimpleSource(grain.RandomAccessDataSource):

    def __init__(self, num_steps):

        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))


class JITBatchTransform(grain.MapTransform):

    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)

        x = jit_batch_inference(batch)
        return x


if __name__ == '__main__':

    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0  # TODO: compare with worker_count=1
    worker_buffer_size = 1  # TODO: compare with worker_buffer_size=2

    datasource = DataSimpleSource(num_steps=num_steps)

    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )

    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]

    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # TODO: enable to profile worker performance
    )

    for x in tqdm(batched_dataloader, total=num_steps, desc='Grain Dataset'):
        pass
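
For reference, the slower configuration mentioned above differs from the script only in the DataLoader settings below (a sketch of the values from the description, with everything else unchanged):

worker_count = 1        # run Batch + JITBatchTransform in a separate worker process
worker_buffer_size = 2  # keep one spare batch ready, assuming the GPU fits two jitted calls

batched_dataloader = grain.DataLoader(
    data_source=datasource,
    sampler=index_sampler,
    operations=pygrain_ops,
    worker_count=worker_count,
    worker_buffer_size=worker_buffer_size,
    enable_profiling=False,
)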

DBraun commented Jul 20, 2024

Here's another version that uses flax.jax_utils.prefetch_to_device. Maybe this achieves what worker_buffer_size = 2 normally would, without actually setting worker_buffer_size = 2. However, I would still like a way to multiprocess the DataSimpleSource: multiple random arrays could be generated in parallel, and the random generation here is just a stand-in for some other data-loading process that I want to parallelize.

from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random

import flax.linen as nn
from flax.jax_utils import prefetch_to_device

from tqdm import tqdm
from absl import logging

import grain.python as grain


class Model(nn.Module):

    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):

        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)

        return x


Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})

B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20

dummy_input = jnp.zeros(shape=(B, IN_FEATURES))

model = Model(n_layers=N_LAYERS, features=FEATURES)

params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']

print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))


@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)


class DataSimpleSource(grain.RandomAccessDataSource):

    def __init__(self, num_steps):

        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))


class JITBatchTransform(grain.MapTransform):

    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)

        x = jit_batch_inference(batch)
        return x


if __name__ == '__main__':

    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0  # TODO: compare with worker_count=1
    worker_buffer_size = 1  # TODO: compare with worker_buffer_size=2
    prefetch_size = 2  # TODO: compare with prefetch_size=1 (no device prefetch)

    datasource = DataSimpleSource(num_steps=num_steps)

    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )

    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]

    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # TODO: enable to profile worker performance
    )

    def prepare_for_prefetch(xs):
        local_device_count = jax.local_device_count()

        def _prepare(x):
            return x.reshape((local_device_count, -1) + x.shape[1:])

        return jax.tree_util.tree_map(_prepare, xs)

    # Add a leading device axis to each batch (similar in spirit to flax.jax_utils.replicate).
    batched_dataloader = map(prepare_for_prefetch, batched_dataloader)

    if prefetch_size > 1:
        # For prefetch to work, we must have already used prepare_for_prefetch
        batched_dataloader = prefetch_to_device(batched_dataloader, size=prefetch_size)

    for x in tqdm(batched_dataloader, total=num_steps, desc='Grain Dataset'):
        pass
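
To make the reshape above concrete: assuming a single local device, prepare_for_prefetch only adds a leading device axis, so the (B, FEATURES) batches produced by JITBatchTransform come out as (1, B, FEATURES), which is the layout prefetch_to_device expects. A minimal check, using the names defined in the script above:

if jax.local_device_count() == 1:
    example = jnp.zeros((B, FEATURES))  # same shape as a batch after jit_batch_inference
    assert prepare_for_prefetch(example).shape == (1, B, FEATURES)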
