Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow postponing dataset integrity checks in NextGenHDFDataset #1323

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 84 additions & 13 deletions returnn/datasets/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,10 @@ class StreamParser(object):
Stream parser.
"""

def __init__(self, seq_names, stream):
def __init__(self, seq_names, stream, use_lazy_data_integrity_checks=False):
self.seq_names = seq_names
self.stream = stream
self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks

self.num_features = None
self.feature_type = None # 1 for sparse, 2 for dense
Expand Down Expand Up @@ -518,8 +519,10 @@ def __init__(self, *args, **kwargs):
if self.dtype is None:
self.dtype = str(seq_data.dtype)

assert seq_data.shape[1] == self.num_features
assert str(seq_data.dtype) == self.dtype
if self.use_lazy_data_integrity_checks:
break

self.check_data_integrity(seq_data, s)

self.feature_type = 2

Expand All @@ -528,7 +531,12 @@ def get_data(self, seq_name):
:param str seq_name:
:rtype: numpy.ndarray
"""
return self.stream["data"][seq_name][...]
data = self.stream["data"][seq_name][...]

if self.use_lazy_data_integrity_checks:
self.check_data_integrity(data, seq_name)

return data

def get_seq_length(self, seq_name):
"""
Expand All @@ -537,6 +545,18 @@ def get_seq_length(self, seq_name):
"""
return self.stream["data"][seq_name].shape[0]

def check_data_integrity(self, data, seq_name):
    """
    Validate a dense feature matrix loaded for one sequence:
    it must be 2-dimensional, its feature dim must match the stream,
    and its dtype must match the dtype seen at init time.

    :param numpy.ndarray data: (time, feature) matrix read from the HDF stream
    :param str seq_name: sequence name, used only in error messages
    """

    assert len(data.shape) == 2, f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
    assert (
        self.num_features == data.shape[1]
    ), f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be {self.num_features})"
    # "in" added to the message for consistency with the other stream parsers
    assert self.dtype == str(data.dtype), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"


class SparseStreamParser(StreamParser):
"""
Expand All @@ -552,7 +572,11 @@ def __init__(self, *args, **kwargs):

if self.dtype is None:
self.dtype = str(seq_data.dtype)
assert str(seq_data.dtype) == self.dtype

if self.use_lazy_data_integrity_checks:
break

self.check_data_integrity(seq_data, s)

self.num_features = self.stream["feature_names"].shape[0]
self.feature_type = 1
Expand All @@ -562,7 +586,12 @@ def get_data(self, seq_name):
:param str seq_name:
:rtype: numpy.ndarray
"""
return self.stream["data"][seq_name][:]
data = self.stream["data"][seq_name][:]

if self.use_lazy_data_integrity_checks:
self.check_data_integrity(data, seq_name)

return data

def get_seq_length(self, seq_name):
"""
Expand All @@ -571,6 +600,17 @@ def get_seq_length(self, seq_name):
"""
return self.stream["data"][seq_name].shape[0]

def check_data_integrity(self, data, seq_name):
    """
    Validate a sparse label sequence loaded for one sequence:
    it must be 1-dimensional and its dtype must match the dtype
    seen at init time.

    :param numpy.ndarray data: 1-D array of label indices read from the HDF stream
    :param str seq_name: sequence name, used only in error messages
    """

    # message fixed: a sparse stream sequence is 1-dimensional, not 2-dimensional
    assert len(data.shape) == 1, f"shape length mismatch in {seq_name}: {data.shape} (should be 1-dimensional)"
    assert self.dtype == str(
        data.dtype
    ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"


class SegmentAlignmentStreamParser(StreamParser):
"""
Expand All @@ -585,10 +625,11 @@ def __init__(self, *args, **kwargs):

if self.dtype is None:
self.dtype = str(seq_data.dtype)
assert str(seq_data.dtype) == self.dtype

assert len(seq_data.shape) == 2
assert seq_data.shape[1] == 2
if self.use_lazy_data_integrity_checks:
break

self.check_data_integrity(seq_data, s)

self.num_features = self.stream["feature_names"].shape[0]
self.feature_type = 1
Expand All @@ -602,6 +643,9 @@ def get_data(self, seq_name):
length = self.get_seq_length(seq_name) // 2
segments = self.stream["data"][seq_name][:]

if self.use_lazy_data_integrity_checks:
self.check_data_integrity(segments, seq_name)

alignment = numpy.zeros((length, 2), dtype=self.dtype)
num_segments = segments.shape[0]
seg_end = 0
Expand All @@ -621,6 +665,18 @@ def get_seq_length(self, seq_name):
"""
return 2 * sum(self.stream["data"][seq_name][:, 1])

def check_data_integrity(self, data, seq_name):
    """
    Validate a segment-alignment matrix loaded for one sequence:
    it must be 2-dimensional with exactly 2 columns, and its dtype
    must match the dtype seen at init time.

    :param numpy.ndarray data: (num_segments, 2) matrix; column 1 is a segment
        length (see get_seq_length); column 0 is presumably the segment label
        -- TODO confirm against the HDF writer
    :param str seq_name: sequence name, used only in error messages
    """

    assert len(data.shape) == 2, f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
    # message fixed: the check is on the size of axis 1 (must be 2),
    # "2-dimensional" described the wrong property
    assert data.shape[1] == 2, f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be 2)"
    assert self.dtype == str(
        data.dtype
    ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"


class NextGenHDFDataset(CachedDataset2):
"""
Expand All @@ -633,7 +689,7 @@ class NextGenHDFDataset(CachedDataset2):
"segment_alignment": SegmentAlignmentStreamParser,
}

def __init__(self, input_stream_name, files=None, **kwargs):
def __init__(self, input_stream_name, files=None, use_lazy_data_integrity_checks=False, **kwargs):
"""
:param str input_stream_name:
:param None|list[str] files:
Expand All @@ -649,6 +705,7 @@ def __init__(self, input_stream_name, files=None, **kwargs):
self.file_indices = []
self.seq_order = []
self.all_parsers = collections.defaultdict(list)
self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks

if files:
for fn in files:
Expand Down Expand Up @@ -684,7 +741,9 @@ def add_file(self, path):
)

parsers = {
name: NextGenHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream)
name: NextGenHDFDataset.parsers[stream.attrs["parser"]](
norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks
)
for name, stream in cur_file["streams"].items()
}
for k, v in parsers.items():
Expand Down Expand Up @@ -807,7 +866,15 @@ class SiameseHDFDataset(CachedDataset2):
"segment_alignment": SegmentAlignmentStreamParser,
}

def __init__(self, input_stream_name, seq_label_stream="words", class_distribution=None, files=None, **kwargs):
def __init__(
self,
input_stream_name,
seq_label_stream="words",
class_distribution=None,
files=None,
use_lazy_data_integrity_checks=False,
**kwargs,
):
"""
:param str input_stream_name: name of a feature stream
:param str seq_label_stream: name of a stream with labels
Expand All @@ -833,6 +900,8 @@ def __init__(self, input_stream_name, seq_label_stream="words", class_distributi
self.target_to_seqs = {} # (int) class_index -> (string) sequence_names
self.curr_epoch_triplets = []
self.targets_stream = seq_label_stream
self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks

if files:
for fn in files:
self.add_file(fn)
Expand Down Expand Up @@ -872,7 +941,9 @@ def add_file(self, path):
)

parsers = {
name: SiameseHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream)
name: SiameseHDFDataset.parsers[stream.attrs["parser"]](
norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks
)
for name, stream in cur_file["streams"].items()
} # name - stream name (words, features, orth_features)
for k, v in parsers.items():
Expand Down