I'm using code similar to the 8-way batch data parallelism example here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#way-batch-data-parallelism
So I think that if `batch` is shaped (B, H, W, C) like an image and B is a multiple of 8, then the data is evenly distributed among the 8 devices. Is Grain able to prepare a batch like this? I haven't been able to figure it out from looking at the `ShardOptions`.
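To make the layout I have in mind concrete, here is a minimal sketch of the arithmetic only, in plain Python. It does not use Grain or JAX; the helper name `per_device_slices` is hypothetical and just illustrates how the leading B axis would be split when B is a multiple of the device count.

```python
def per_device_slices(batch_size: int, num_devices: int = 8) -> list[tuple[int, int]]:
    """Return the [start, stop) range of batch indices each device would hold
    when the leading (B) axis is split evenly across the devices.

    Hypothetical helper for illustration; not part of Grain or JAX.
    """
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch size {batch_size} is not a multiple of {num_devices} devices"
        )
    per_device = batch_size // num_devices
    # Device d holds the contiguous block of examples [d * per_device, (d + 1) * per_device).
    return [(d * per_device, (d + 1) * per_device) for d in range(num_devices)]


if __name__ == "__main__":
    # With B = 32 and 8 devices, each device would get 4 examples of shape (H, W, C).
    for device, (start, stop) in enumerate(per_device_slices(32)):
        print(f"device {device}: batch[{start}:{stop}]")
```

So the question is whether Grain can hand me batches whose leading dimension satisfies this divisibility, so that the split across the 8 devices comes out even.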