Skip to content

Commit

Permalink
Merge branch 'ccrouzet/fix-overload-resolution' into 'main'
Browse files Browse the repository at this point in the history
Fix incorrect user function overloading

See merge request omniverse/warp!811
  • Loading branch information
mmacklin committed Oct 22, 2024
2 parents df32b56 + 28b5d56 commit c3291b7
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

- Fix printing vector and matrix adjoints in backward kernels.
- Fix kernel compile error when printing structs.
- Fix an incorrect user function being sometimes resolved when multiple overloads are available with array parameters with different `dtype` values.

## [1.4.1] - 2024-10-15

Expand Down
36 changes: 35 additions & 1 deletion warp/tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import math
import unittest
from typing import Tuple
from typing import Any, Tuple

import numpy as np

Expand Down Expand Up @@ -191,6 +191,37 @@ def test_user_func_return_multiple_values():
wp.expect_eq(b, 54756.0)


@wp.func
def user_func_overload(
b: wp.array(dtype=Any),
i: int,
):
return b[i] * 2.0


@wp.kernel
def user_func_overload_resolution_kernel(
a: wp.array(dtype=Any),
b: wp.array(dtype=Any),
):
i = wp.tid()
a[i] = user_func_overload(b, i)


def test_user_func_overload_resolution(test, device):
a0 = wp.array((1, 2, 3), dtype=wp.vec3)
b0 = wp.array((2, 3, 4), dtype=wp.vec3)

a1 = wp.array((5,), dtype=float)
b1 = wp.array((6,), dtype=float)

wp.launch(user_func_overload_resolution_kernel, a0.shape, (a0, b0))
wp.launch(user_func_overload_resolution_kernel, a1.shape, (a1, b1))

assert_np_equal(a0.numpy()[0], (4, 6, 8))
assert a1.numpy()[0] == 12


devices = get_test_devices()


Expand Down Expand Up @@ -375,6 +406,9 @@ def test_native_function_error_resolution(self):
dim=1,
devices=devices,
)
add_function_test(
TestFunc, func=test_user_func_overload_resolution, name="test_user_func_overload_resolution", devices=devices
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,7 +1492,7 @@ def types_equal(a, b, match_generic=False):

return True

if is_array(a) and type(a) is type(b):
if is_array(a) and type(a) is type(b) and types_equal(a.dtype, b.dtype, match_generic=match_generic):
return True

# match NewStructInstance and Struct dtype
Expand Down

0 comments on commit c3291b7

Please sign in to comment.