
feat: Allow postponing dataset integrity checks to training time
NeoLegends committed May 5, 2023
1 parent e9a9920 commit 6e037bd
Showing 1 changed file with 73 additions and 13 deletions.
86 changes: 73 additions & 13 deletions returnn/datasets/hdf.py
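
The new use_lazy_data_integrity_checks option (default False) defers the per-sequence shape and dtype assertions from dataset construction to the first read of each sequence, so large HDF files no longer require a full pass over all sequences at startup. A minimal usage sketch, assuming a standard RETURNN config dict; the file path and stream name are placeholders:

    train = {
        "class": "NextGenHDFDataset",
        "input_stream_name": "features",  # placeholder stream name
        "files": ["/path/to/train.hdf"],  # placeholder path
        # Skip the full integrity scan over all sequences at load time;
        # each sequence is validated on first access in get_data() instead.
        "use_lazy_data_integrity_checks": True,
    }

With the flag left at its default, behavior is unchanged: every sequence is checked eagerly while the file is added.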
@@ -471,9 +471,10 @@ class StreamParser(object):
     Stream parser.
     """
 
-    def __init__(self, seq_names, stream):
+    def __init__(self, seq_names, stream, use_lazy_data_integrity_checks=False):
         self.seq_names = seq_names
         self.stream = stream
+        self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
 
         self.num_features = None
         self.feature_type = None  # 1 for sparse, 2 for dense
@@ -518,8 +519,10 @@ def __init__(self, *args, **kwargs):
             if self.dtype is None:
                 self.dtype = str(seq_data.dtype)
 
-            assert seq_data.shape[1] == self.num_features
-            assert str(seq_data.dtype) == self.dtype
+            if self.use_lazy_data_integrity_checks:
+                break
+
+            self.check_data_integrity(seq_data, s)
 
         self.feature_type = 2
 
@@ -528,7 +531,12 @@ def get_data(self, seq_name):
         :param str seq_name:
         :rtype: numpy.ndarray
         """
-        return self.stream["data"][seq_name][...]
+        data = self.stream["data"][seq_name][...]
+
+        if self.use_lazy_data_integrity_checks:
+            self.check_data_integrity(data, seq_name)
+
+        return data
 
     def get_seq_length(self, seq_name):
         """
@@ -537,6 +545,13 @@ def get_seq_length(self, seq_name):
         """
         return self.stream["data"][seq_name].shape[0]
 
+    def check_data_integrity(self, data, seq_name):
+        assert len(data.shape) == 2, f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
+        assert (
+            self.num_features == data.shape[1]
+        ), f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be {self.num_features})"
+        assert self.dtype == str(data.dtype), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"
 
 
 class SparseStreamParser(StreamParser):
     """
@@ -552,7 +567,11 @@ def __init__(self, *args, **kwargs):
 
             if self.dtype is None:
                 self.dtype = str(seq_data.dtype)
-            assert str(seq_data.dtype) == self.dtype
+
+            if self.use_lazy_data_integrity_checks:
+                break
+
+            self.check_data_integrity(seq_data, s)
 
         self.num_features = self.stream["feature_names"].shape[0]
         self.feature_type = 1
@@ -562,7 +581,12 @@ def get_data(self, seq_name):
         :param str seq_name:
         :rtype: numpy.ndarray
         """
-        return self.stream["data"][seq_name][:]
+        data = self.stream["data"][seq_name][:]
+
+        if self.use_lazy_data_integrity_checks:
+            self.check_data_integrity(data, seq_name)
+
+        return data
 
     def get_seq_length(self, seq_name):
         """
@@ -571,6 +595,12 @@ def get_seq_length(self, seq_name):
         """
         return self.stream["data"][seq_name].shape[0]
 
+    def check_data_integrity(self, data, seq_name):
+        assert len(data.shape) == 1, f"shape length mismatch in {seq_name}: {data.shape} (should be 1-dimensional)"
+        assert self.dtype == str(
+            data.dtype
+        ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"
 
 
 class SegmentAlignmentStreamParser(StreamParser):
     """
@@ -585,10 +615,11 @@ def __init__(self, *args, **kwargs):
 
             if self.dtype is None:
                 self.dtype = str(seq_data.dtype)
-            assert str(seq_data.dtype) == self.dtype
-
-            assert len(seq_data.shape) == 2
-            assert seq_data.shape[1] == 2
+
+            if self.use_lazy_data_integrity_checks:
+                break
+
+            self.check_data_integrity(seq_data, s)
 
         self.num_features = self.stream["feature_names"].shape[0]
         self.feature_type = 1
@@ -602,6 +633,9 @@ def get_data(self, seq_name):
         length = self.get_seq_length(seq_name) // 2
         segments = self.stream["data"][seq_name][:]
 
+        if self.use_lazy_data_integrity_checks:
+            self.check_data_integrity(segments, seq_name)
+
         alignment = numpy.zeros((length, 2), dtype=self.dtype)
         num_segments = segments.shape[0]
         seg_end = 0
@@ -621,6 +655,17 @@ def get_seq_length(self, seq_name):
         """
         return 2 * sum(self.stream["data"][seq_name][:, 1])
 
+    def check_data_integrity(self, data, seq_name):
+        assert (
+            len(data.shape) == 2
+        ), f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
+        assert (
+            data.shape[1] == 2
+        ), f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be 2)"
+        assert self.dtype == str(
+            data.dtype
+        ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"
 
 
 class NextGenHDFDataset(CachedDataset2):
     """
@@ -633,7 +678,7 @@ class NextGenHDFDataset(CachedDataset2):
         "segment_alignment": SegmentAlignmentStreamParser,
     }
 
-    def __init__(self, input_stream_name, files=None, **kwargs):
+    def __init__(self, input_stream_name, files=None, use_lazy_data_integrity_checks=False, **kwargs):
         """
         :param str input_stream_name:
         :param None|list[str] files:
@@ -649,6 +694,7 @@ def __init__(self, input_stream_name, files=None, **kwargs):
         self.file_indices = []
         self.seq_order = []
         self.all_parsers = collections.defaultdict(list)
+        self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
 
         if files:
             for fn in files:
@@ -684,7 +730,9 @@ def add_file(self, path):
         )
 
         parsers = {
-            name: NextGenHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream)
+            name: NextGenHDFDataset.parsers[stream.attrs["parser"]](
+                norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks
+            )
             for name, stream in cur_file["streams"].items()
         }
         for k, v in parsers.items():
@@ -807,7 +855,15 @@ class SiameseHDFDataset(CachedDataset2):
         "segment_alignment": SegmentAlignmentStreamParser,
     }
 
-    def __init__(self, input_stream_name, seq_label_stream="words", class_distribution=None, files=None, **kwargs):
+    def __init__(
+        self,
+        input_stream_name,
+        seq_label_stream="words",
+        class_distribution=None,
+        files=None,
+        use_lazy_data_integrity_checks=False,
+        **kwargs,
+    ):
         """
         :param str input_stream_name: name of a feature stream
         :param str seq_label_stream: name of a stream with labels
@@ -833,6 +889,8 @@ def __init__(self, input_stream_name, seq_label_stream="words", class_distributi
         self.target_to_seqs = {}  # (int) class_index -> (string) sequence_names
         self.curr_epoch_triplets = []
         self.targets_stream = seq_label_stream
+        self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
+
         if files:
             for fn in files:
                 self.add_file(fn)
@@ -872,7 +930,9 @@ def add_file(self, path):
         )
 
         parsers = {
-            name: SiameseHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream)
+            name: SiameseHDFDataset.parsers[stream.attrs["parser"]](
+                norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks
+            )
             for name, stream in cur_file["streams"].items()
         }  # name - stream name (words, features, orth_features)
         for k, v in parsers.items():
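For reference, a stripped-down sketch of the pattern the commit applies to all three stream parsers (a simplified stand-in, not the actual RETURNN classes): the constructor still reads the first sequence to infer metadata, but in lazy mode it breaks out of the scan, and get_data() validates each sequence on first read instead.

    class LazyCheckParser:
        def __init__(self, stream, use_lazy_data_integrity_checks=False):
            self.stream = stream  # dict-like: seq_name -> numpy array
            self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
            self.dtype = None
            for seq_name, data in stream.items():
                if self.dtype is None:
                    # metadata still comes from the first sequence
                    self.dtype = str(data.dtype)
                if use_lazy_data_integrity_checks:
                    break  # skip the full scan over all sequences
                self.check_data_integrity(data, seq_name)

        def get_data(self, seq_name):
            data = self.stream[seq_name]
            if self.use_lazy_data_integrity_checks:
                # deferred: validate each sequence when it is first read
                self.check_data_integrity(data, seq_name)
            return data

        def check_data_integrity(self, data, seq_name):
            assert str(data.dtype) == self.dtype, f"dtype mismatch in {seq_name}"

The trade-off of the lazy mode is that malformed sequences surface only once training actually reads them, rather than failing fast while the dataset is constructed.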
