forked from karpathy/llama2.c
-
-
Notifications
You must be signed in to change notification settings - Fork 41
/
export.py
567 lines (485 loc) · 23.9 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
"""
This script has functions and utilties for model export.
Basically, we have a bunch of versions of the model, and we
want to export them to .bin files to be read from and inferenced in C.
Among the "input" versions of PyTorch files/models:
- Official Llama 2 weights released by Meta
- Huggingface weights available on the hub
- llama2.c (this repo) trained models
Among the "output" versions of .bin files:
- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
- v1-vN: Improved .bin files with a proper header, cache alignment, etc.
This script aspires to provide all of these conversions.
"""
import os
import gzip
import shutil
import struct
import argparse
import json
from pathlib import Path
import numpy as np
import torch
from torch import nn
from model import ModelArgs, Transformer
# -----------------------------------------------------------------------------
# common utilities
def serialize_fp32(file, tensor):
""" writes one fp32 tensor to file that is open in wb mode """
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
b = struct.pack(f'{len(d)}f', *d)
file.write(b)
def serialize_int8(file, tensor):
""" writes one int8 tensor to file that is open in wb mode """
d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
b = struct.pack(f'{len(d)}b', *d)
file.write(b)
def quantize_q80(w, group_size):
"""
takes a tensor and returns the Q8_0 quantized version
i.e. symmetric quantization into int8, range [-127,127]
"""
assert w.numel() % group_size == 0
ori_shape = w.shape
w = w.float() # convert to float32
w = w.reshape(-1, group_size)
# find the max in each group
wmax = torch.abs(w).max(dim=1).values
# calculate the scaling factor such that float = quant * scale
scale = wmax / 127.0
# scale into range [-127, 127]
quant = w / scale[:,None]
# round to nearest integer
int8val = torch.round(quant).to(torch.int8)
# dequantize by rescaling
fp32val = (int8val.float() * scale[:,None]).view(-1)
fp32valr = fp32val.reshape(-1, group_size)
# calculate the max error in each group
err = torch.abs(fp32valr - w).max(dim=1).values
# find the max error across all groups
maxerr = err.max().item()
return int8val, scale, maxerr
# -----------------------------------------------------------------------------
# legacy
def legacy_export(model, filepath):
""" Original export of llama2.c bin files, i.e. version v0 """
out_file = open(filepath, 'wb')
# first write out the header
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
p = model.params
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
# legacy format uses negative/positive vocab size as a shared classifier flag
if not shared_classifier:
p.vocab_size = -p.vocab_size
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
out_file.write(header)
# next write out the embedding weights
serialize_fp32(out_file, model.tok_embeddings.weight)
# now all the layers
# attention weights
for layer in model.layers:
serialize_fp32(out_file, layer.attention_norm.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wq.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wk.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wv.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wo.weight)
# ffn weights
for layer in model.layers:
serialize_fp32(out_file, layer.ffn_norm.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.feed_forward.w1.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.feed_forward.w2.weight)
for layer in model.layers:
serialize_fp32(out_file, layer.feed_forward.w3.weight)
# final rmsnorm
serialize_fp32(out_file, model.norm.weight)
# freqs_cis
serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
# final classifier weights
if not shared_classifier:
serialize_fp32(out_file, model.output.weight)
# write to binary file
out_file.close()
print(f"wrote {filepath}")
# -----------------------------------------------------------------------------
# new version
def version1_export(model, filepath):
"""
Export the model weights in full float32 .bin file to be read from C.
This is same as legacy_export, but with a proper header.
"""
version = 1
out_file = open(filepath, 'wb')
# first write out the header. the header will be 256 bytes
# 1) write magic, which will be uint32 of "ak42" in ASCII
out_file.write(struct.pack('I', 0x616b3432))
# 2) write version, which will be int
out_file.write(struct.pack('i', version))
# 3) write the params, which will be 7 ints
p = model.params
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
out_file.write(header)
# 4) write some other flags
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
out_file.write(struct.pack('B', int(shared_classifier)))
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
assert pad >= 0
out_file.write(b'\0' * pad)
# now let's write out all the params
weights = [
*[layer.attention_norm.weight for layer in model.layers],
*[layer.ffn_norm.weight for layer in model.layers],
model.norm.weight,
model.tok_embeddings.weight,
*[layer.attention.wq.weight for layer in model.layers],
*[layer.attention.wk.weight for layer in model.layers],
*[layer.attention.wv.weight for layer in model.layers],
*[layer.attention.wo.weight for layer in model.layers],
*[layer.feed_forward.w1.weight for layer in model.layers],
*[layer.feed_forward.w2.weight for layer in model.layers],
*[layer.feed_forward.w3.weight for layer in model.layers],
]
if not shared_classifier:
weights.append(model.output.weight)
for w in weights:
serialize_fp32(out_file, w)
# write to binary file
out_file.close()
print(f"wrote {filepath}")
def version2_export(model, filepath, group_size=64):
"""
Export the model weights in Q8_0 into .bin file to be read from C.
That is:
- quantize all weights to symmetric int8, in range [-127, 127]
- all other tensors (the rmsnorm params) are kept and exported in fp32
- quantization is done in groups of group_size to reduce the effects of any outliers
"""
version = 2
# let's first do some validation for this export type
while model.params.dim % group_size != 0:
group_size //= 2
print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
weights = [
model.tok_embeddings.weight,
*[layer.attention.wq.weight for layer in model.layers],
*[layer.attention.wk.weight for layer in model.layers],
*[layer.attention.wv.weight for layer in model.layers],
*[layer.attention.wo.weight for layer in model.layers],
*[layer.feed_forward.w1.weight for layer in model.layers],
*[layer.feed_forward.w2.weight for layer in model.layers],
*[layer.feed_forward.w3.weight for layer in model.layers],
]
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
if not shared_classifier:
weights.append(model.output.weight)
for w in weights:
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
# write
out_file = open(filepath, 'wb')
# first write out the header. the header will be 256 bytes
# 1) write magic, which will be uint32 of "ak42" in ASCII
out_file.write(struct.pack('I', 0x616b3432))
# 2) write version, which will be int
out_file.write(struct.pack('i', version))
# 3) write the params, which will be 7 ints
p = model.params
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
out_file.write(header)
# 4) write some other flags
out_file.write(struct.pack('B', int(shared_classifier)))
out_file.write(struct.pack('i', group_size)) # group size used for quantization
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
assert pad >= 0
out_file.write(b'\0' * pad)
# now that the header is done, let's write out the model
# first let's write out all the params that we are keeping in fp32: the norms
for layer in model.layers: # attention norms
serialize_fp32(out_file, layer.attention_norm.weight)
for layer in model.layers: # MLP norms
serialize_fp32(out_file, layer.ffn_norm.weight)
serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
# now let's write out all the params that we are quantizing to Q8_0
# note we skip classifier weights, which are shared with the embedding
ew = []
for i, w in enumerate(weights):
# quantize this weight
q, s, err = quantize_q80(w, group_size)
# save the int8 weights to file
serialize_int8(out_file, q) # save the tensor in int8
serialize_fp32(out_file, s) # save scale factors
# logging
ew.append((err, w.shape))
print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
# print the highest error across all weights, should be very small, e.g. O(~0.001)
ew.sort(reverse=True)
print(f"max quantization group error across all weights: {ew[0][0]}")
# write to binary file
out_file.close()
print(f"wrote {filepath}")
def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
""" Generate the pytorch_model.bin state_dict and config.json for HuggingFace """
try:
from transformers.models.llama.configuration_llama import LlamaConfig
except ImportError:
print("Error: transformers package is required to load huggingface models")
print("Please run `pip install transformers` to install it")
return None
# Generate LlamaModel state_dict
hf_state_dict = {}
# Sometimes we have repeated key values for the heads
dim = llama_model.params.dim
num_key_value_heads = llama_model.params.n_kv_heads
n_rep = llama_model.params.n_heads // num_key_value_heads
key_value_dim = dim // n_rep
# HuggingFace needs the weights permuted.
# See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
# Transfer weights from llama model to the HF state dictionary format
hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)
# Add each layer's weights to the HF state dictionary
for i, layer in enumerate(llama_model.layers):
layer_id = layer.layer_id
hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
# llama2.c usually uses tied weights -> reference the embed_tokens.weights instead
hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
# We check that the embeddings are tied, else use manual output weights
_embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
if not _embeddings_are_tied:
hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
# Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
# Extract necessary attributes from llama.c model
vocab_size = llama_model.params.vocab_size
hidden_size = llama_model.params.dim
intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
num_hidden_layers = llama_model.params.n_layers
num_attention_heads = llama_model.params.n_heads
num_key_value_heads = llama_model.params.n_kv_heads
max_position_embeddings = llama_model.params.max_seq_len
rms_norm_eps = llama_model.params.norm_eps
# TODO check values for:
# pretraining_tp, initializer_range, use_cache,
# rope_theta, and rope_scaling.
config = LlamaConfig(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=rms_norm_eps,
tie_word_embeddings=_embeddings_are_tied,
# Manual
architectures=["LlamaForCausalLM"],
hidden_act="silu",
)
# Save files in directory filepath
# First make the directory if it doesn't exist
os.makedirs(filepath, exist_ok=True)
# Save the state dictionary in .bin format, and config as .json
torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
config.save_pretrained(filepath)
# -----------------------------------------------------------------------------
# Load / import functions
def load_checkpoint(checkpoint):
# load the provided model checkpoint
checkpoint_dict = torch.load(checkpoint, map_location='cpu')
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
model.eval()
return model
def load_meta_model(model_path):
params_path = os.path.join(model_path, 'params.json')
with open(params_path) as f:
params = json.load(f)
print(params)
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
models = [torch.load(p, map_location='cpu') for p in model_paths]
def concat_weights(models):
state_dict = {}
for name in list(models[0]):
tensors = [model[name] for model in models]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
state_dict[name] = tensors[0]
continue
is_axis_1 = (
name.startswith('tok_embeddings.')
or name.endswith('.attention.wo.weight')
or name.endswith('.feed_forward.w2.weight')
)
axis = 1 if is_axis_1 else 0
state_dict[name] = torch.cat(tensors, dim=axis)
for model in models:
del model[name]
return state_dict
state_dict = concat_weights(models)
del models
# set ModelArgs
config = ModelArgs()
config.dim = params["dim"]
config.n_layers = params["n_layers"]
config.n_heads = params["n_heads"]
config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
config.multiple_of = params["multiple_of"]
config.norm_eps = params["norm_eps"]
config.vocab_size = state_dict['tok_embeddings.weight'].shape[0]
config.max_seq_len = 2048
# create a new Transformer object and set weights
model = Transformer(config)
model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
model.norm.weight = nn.Parameter(state_dict['norm.weight'])
for layer in model.layers:
i = layer.layer_id
layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
# final classifier
model.output.weight = nn.Parameter(state_dict['output.weight'])
model.eval()
return model
def load_hf_model(model_path):
try:
from transformers import AutoModelForCausalLM
except ImportError:
print("Error: transformers package is required to load huggingface models")
print("Please run `pip install transformers` to install it")
return None
# load HF model
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_dict = hf_model.state_dict()
# convert LlamaConfig to ModelArgs
config = ModelArgs()
config.dim = hf_model.config.hidden_size
config.n_layers = hf_model.config.num_hidden_layers
config.n_heads = hf_model.config.num_attention_heads
config.n_kv_heads = hf_model.config.num_attention_heads
config.vocab_size = hf_model.config.vocab_size
config.hidden_dim = hf_model.config.intermediate_size
config.norm_eps = hf_model.config.rms_norm_eps
config.max_seq_len = hf_model.config.max_position_embeddings
# create a new Transformer object and set weights
model = Transformer(config)
model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
# huggingface permutes WQ and WK, this function reverses it
def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
for layer in model.layers:
i = layer.layer_id
layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
# final classifier
model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
model.eval()
return model
# -----------------------------------------------------------------------------
# API entrypoint
def model_export(model, filepath, version, dtype=torch.float32):
"""
Versions docs:
v-1:huggingface export, i.e. intended for use outside of this repo, in HF
v0: legacy llama2.c float format, DEPRECATED
v1: float32 export
v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
# TODO: add dtype export support for other versions (?)
"""
if version == 0:
legacy_export(model, filepath)
elif version == 1:
version1_export(model, filepath)
elif version == 2:
version2_export(model, filepath)
elif version == -1:
hf_export(model, filepath, dtype)
else:
raise ValueError(f"unknown version {version}")
def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
"""
(This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
Saves the model as a TorchScript.
The resulting file can be loaded in C++ code and then used for training or
inference with:
#include <torch/script.h>
torch::jit::Module module = torch::jit::load("model.pt")
Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately you can zero out the parameters before saving it and
it will gzip down to 780K.
"""
# If requested zero params before saving the model. This is useful in
# conjunction with gzip_output.
if zero_params:
for p in model.parameters():
p.detach().zero_()
torch.jit.save(torch.jit.script(model), filepath)
if gzip_output:
with open(filepath, "rb") as f_in:
with gzip.open(f"{filepath}.gz", "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(filepath)
# -----------------------------------------------------------------------------
# CLI entrypoint
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--version", default=0, type=int, help="the version to export with")
parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
group.add_argument("--meta-llama", type=str, help="meta llama model path")
group.add_argument("--hf", type=str, help="huggingface model path")
args = parser.parse_args()
dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
if args.checkpoint:
model = load_checkpoint(args.checkpoint)
elif args.meta_llama:
model = load_meta_model(args.meta_llama)
elif args.hf:
model = load_hf_model(args.hf)
if model is None:
parser.error("Can't load input model!")
# export
model_export(model, args.filepath, args.version, args.dtype)