In the Dataset tutorial, the conversion of MapDataset to IterDataset is not quite clear #640

Open
Tomas542 opened this issue Nov 29, 2024 · 0 comments

The tutorial says the following:
Any MapDataset can be turned into a IterDataset by calling to_iter_dataset. When possible this should happen late in the pipeline since it will restrict the transformations that can come after it (e.g. global shuffle must come before). This conversion by default skips None elements.
However, calling map_dataset.to_iter_dataset() returns a PrefetchIterDataset. The problem is that this object is not an iterator: next() fails on it, and it has no get_state() method until it is wrapped in iter(). So it is hard to understand why the tutorial describes the conversion this way, and why to_iter_dataset is not used anywhere in the tutorial.
Example to reproduce:

# !pip install -q jax-ai-stack[grain]==2024.11.1
import chex
import grain.python as pygrain

class Source(pygrain.RandomAccessDataSource):
    def __init__(self, x:chex.Array, y:chex.Array) -> None:
        assert (len(x) == len(y)), "must be the same length"
        self.x = x
        self.y = y

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int) -> tuple[chex.Array, chex.Array]:
        return self.x[idx], self.y[idx]

x = range(10)
y = [0]*5 + [1]*5
data_source = Source(x, y)
seed = 42  # arbitrary seed for the shuffle below

dataset = (
    pygrain.MapDataset.source(data_source)
    .shuffle(seed=seed)
    .map(lambda x: x)
    .batch(batch_size=5)
)
# from the tutorial, works fine
iter_dataset = iter(dataset)
print(iter_dataset.get_state())
print(next(iter_dataset))

# creates PrefetchIterDataset with to_iter_dataset()
iter_dataset2 = dataset.to_iter_dataset()
print(type(iter_dataset2))

# AttributeError: 'PrefetchIterDataset' object has no attribute 'get_state'
print(iter_dataset2.get_state())
# TypeError: 'PrefetchIterDataset' object is not an iterator
print(next(iter_dataset2))

# works fine
iter_dataset2 = iter(iter_dataset2)
print(iter_dataset2.get_state())
print(next(iter_dataset2))
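
The behavior above looks like the usual Python split between an iterable and an iterator, and it may help to see it without Grain. The sketch below is pure Python and only an illustration: `ToyIterDataset` and `ToyDatasetIterator` are hypothetical names, not Grain classes, and the state format is an assumption made for the example.

```python
class ToyDatasetIterator:
    """Hypothetical stateful iterator: supports next() and get_state()."""

    def __init__(self, items: list[int]) -> None:
        self._items = items
        self._next_index = 0

    def __iter__(self) -> "ToyDatasetIterator":
        return self

    def __next__(self) -> int:
        if self._next_index >= len(self._items):
            raise StopIteration
        value = self._items[self._next_index]
        self._next_index += 1
        return value

    def get_state(self) -> dict[str, int]:
        # The position lives on the iterator, not on the dataset.
        return {"next_index": self._next_index}


class ToyIterDataset:
    """Hypothetical iterable: describes the data but holds no position."""

    def __init__(self, items: list[int]) -> None:
        self._items = items

    def __iter__(self) -> ToyDatasetIterator:
        # Each call to iter() creates a fresh, independent iterator.
        return ToyDatasetIterator(self._items)


toy_dataset = ToyIterDataset([10, 20, 30])

# The dataset defines __iter__ but not __next__, so next() raises TypeError.
try:
    next(toy_dataset)
except TypeError as err:
    print(f"TypeError: {err}")

# Wrapping it in iter() yields an object with both next() and get_state().
toy_iterator = iter(toy_dataset)
first = next(toy_iterator)
state = toy_iterator.get_state()
print(first, state)
```

If PrefetchIterDataset follows the same pattern, then the errors in the reproduction are expected, and the open question is only whether the tutorial should say explicitly that iter() is needed after to_iter_dataset().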