Skip to content

Commit

Permalink
binding for extended syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Mar 27, 2023
1 parent a93de3a commit 6bb7752
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions ffi/py/tests/mobilenet_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def test_typed_model_to_nnef_and_back():
assert str(reloaded.output_fact(0)) == "B,1000,F32"

path = tmpdirname / "nnef.tar.gz"
nnef = nnef.with_extended_identifier_syntax()
nnef.write_model_to_tar_gz(typed, path)
reloaded = nnef.model_for_path(path)
assert str(reloaded.input_fact(0)) == "B,3,224,224,F32"
Expand Down
8 changes: 8 additions & 0 deletions ffi/py/tract/nnef.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def with_pulse(self) -> "Nnef":
check(lib.tract_nnef_enable_pulse(self.ptr))
return self

def with_extended_identifier_syntax(self) -> "Nnef":
"""
Enable tract-opl extensions to NNEF for extended identifiers (will support PyTorch 2 path-like ids)
"""
self._valid()
check(lib.tract_nnef_allow_extended_identifier_syntax(self.ptr, True))
return self

def write_model_to_dir(self, model: Model, path: Union[str, Path]) -> None:
"""
Save `model` as a NNEF directory model in `path`.
Expand Down
9 changes: 9 additions & 0 deletions ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ pub unsafe extern "C" fn tract_nnef_enable_pulse(nnef: *mut TractNnef) -> TRACT_
})
}

#[no_mangle]
pub unsafe extern "C" fn tract_nnef_allow_extended_identifier_syntax(nnef: *mut TractNnef, enable: bool) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef);
(*nnef).0.allow_extended_identifier_syntax(enable);
Ok(())
})
}

/// Destroy the NNEF parser. It is safe to detroy the NNEF parser once the model had been loaded.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_destroy(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
Expand Down

0 comments on commit 6bb7752

Please sign in to comment.