-
Notifications
You must be signed in to change notification settings - Fork 0
/
h5tools.py
55 lines (43 loc) · 1.77 KB
/
h5tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import h5py
def split(fname_src: str, fname_dest_prefix: str, maxsize_per_file: float):
    """
    Splits an `h5` file into smaller parts named
    `{fname_dest_prefix}{idx}.h5` with idx = 0, 1, 2, ...

    Top-level members of the source are copied one at a time into the
    current part. Once the on-disk size of that part exceeds
    `maxsize_per_file` (in bytes), the part is closed and the next member
    starts a new one. The size is checked after a member has been copied,
    so a part can overshoot the limit by up to the size of the last member
    copied into it; individual members are never split across parts.

    The root attributes of the source are copied into every part. If the
    source has no top-level members, a single part holding only those
    attributes is written.

    Returns the list of part file names in creation order.
    """
    dest_fnames = []
    dest = None  # part currently being filled; None when a new one is needed
    with h5py.File(fname_src, "r") as src:
        try:
            for group in src:
                if dest is None:
                    fname = f"{fname_dest_prefix}{len(dest_fnames)}.h5"
                    dest = h5py.File(fname, "w")
                    dest.attrs.update(src.attrs)
                    dest_fnames.append(fname)
                group_id = dest.require_group(src[group].parent.name)
                src.copy(f"/{group}", group_id, name=group)
                # Push buffered data to disk so the size check below sees
                # what has actually been written to this part.
                dest.flush()
                if os.path.getsize(fname) > maxsize_per_file:
                    dest.close()
                    dest = None
            if not dest_fnames:
                # Source had no members: still emit one part so the root
                # attributes survive a split/combine round trip.
                fname = f"{fname_dest_prefix}0.h5"
                dest = h5py.File(fname, "w")
                dest.attrs.update(src.attrs)
                dest_fnames.append(fname)
        finally:
            # Close the last (or partially written) part even on error.
            if dest is not None:
                dest.close()
    return dest_fnames
def combine(fname_in: list, fname_out: str):
    """
    Merges several `h5` part files into one output file.

    `fname_out` is created (truncating any existing file). For each path in
    `fname_in`, in order, the part's root attributes are merged into the
    output's root attributes (a later part overwrites an earlier value
    stored under the same key) and every top-level member of the part is
    copied into the output under its original name.
    """
    with h5py.File(fname_out, "w") as merged:
        for part_path in fname_in:
            with h5py.File(part_path, "r") as part:
                merged.attrs.update(part.attrs)
                for member in part:
                    # Recreate (or reuse) the member's parent group in the
                    # output, then copy the member into it.
                    parent_path = part[member].parent.name
                    target = merged.require_group(parent_path)
                    part.copy(f"/{member}", target, name=member)
if __name__ == "__main__":
    # Example: break a large weights file into parts, then rebuild it.
    size_max = 90 * 1024**2  # size threshold per part, in bytes (90 MiB)
    fname_src = "path_to_large_model_weights_file.h5"
    prefix = "model_weights_part"
    fname_parts = split(
        fname_src,
        fname_dest_prefix=prefix,
        maxsize_per_file=size_max,
    )
    combine(fname_in=fname_parts, fname_out="model_weights.h5")