diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90d9bd20c..b914bf745 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@
 - Remove ``wp.synchronize()`` from PyTorch autograd function example
 - ``Tape.check_kernel_array_access()`` and ``Tape.reset_array_read_flags()`` are now private methods.
 - Fix reporting unmatched argument types
+- Fix errors when launching a CUDA graph after a module is reloaded
+- Minor improvements to kernel launch performance
 
 ## [1.3.0] - 2024-07-25
 
diff --git a/warp/context.py b/warp/context.py
index 50f9a8ce6..0404b2f9f 100644
--- a/warp/context.py
+++ b/warp/context.py
@@ -1411,12 +1411,65 @@ def codegen(self, device):
 
         return source
 
 
+# ModuleExec holds the compiled executable code for a specific device.
+# It can be used to obtain kernel hooks on that device and serves
+# as a reference-counted wrapper of the loaded module.
+# Clients can keep a reference to a ModuleExec object to prevent the
+# executable code from being unloaded prematurely.
+# For example, the Graph class retains references to all the CUDA modules
+# needed by a graph. This ensures that graphs remain valid even if
+# the original Modules get reloaded.
+class ModuleExec:
+    def __new__(cls, *args, **kwargs):
+        instance = super(ModuleExec, cls).__new__(cls)
+        instance.handle = None
+        return instance
+
+    def __init__(self, handle, device):
+        self.handle = handle
+        self.device = device
+        self.kernel_hooks = {}
+
+    # release the loaded module
+    def __del__(self):
+        if self.handle is not None:
+            if self.device.is_cuda:
+                # use CUDA context guard to avoid side effects during garbage collection
+                with self.device.context_guard:
+                    runtime.core.cuda_unload_module(self.device.context, self.handle)
+            else:
+                runtime.llvm.unload_obj(self.handle.encode("utf-8"))
+
+    # lookup and cache kernel entry points
+    def get_kernel_hooks(self, kernel):
+        hooks = self.kernel_hooks.get(kernel)
+        if hooks is not None:
+            return hooks
+
+        name = kernel.get_mangled_name()
+
+        if self.device.is_cuda:
+            forward = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, (name + "_cuda_kernel_forward").encode("utf-8")
+            )
+            backward = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, (name + "_cuda_kernel_backward").encode("utf-8")
+            )
+        else:
+            func = ctypes.CFUNCTYPE(None)
+            forward = func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8")))
+            backward = func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
+
+        hooks = KernelHooks(forward, backward)
+        self.kernel_hooks[kernel] = hooks
+
+        return hooks
+
+
 # -----------------------------------------------------
 # stores all functions and kernels for a Python module
 # creates a hash of the function to use for checking
 # build cache
-
-
 class Module:
     def __init__(self, name, loader):
         self.name = name
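
Note: a minimal usage sketch (not part of the patch) of how the new wrapper is
meant to be held by clients; the kernel `my_kernel` is hypothetical:

    device = wp.get_device("cuda:0")
    module_exec = my_kernel.module.load(device)      # load() now returns a ModuleExec (see below)
    hooks = module_exec.get_kernel_hooks(my_kernel)  # resolved once per kernel, then cached
    # while `module_exec` is referenced, the compiled code stays loaded,
    # even if the owning Module is reloaded or unloaded
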
@@ -1427,8 +1480,8 @@ def __init__(self, name, loader):
         self.constants = {}  # Any constants referenced in this module including those defined in other modules
         self.structs = {}
 
-        self.cpu_module = None
-        self.cuda_modules = {}  # module lookup by CUDA context
+        self.cpu_exec = None  # executable CPU module
+        self.cuda_execs = {}  # executable CUDA module lookup by CUDA context
 
         self.cpu_build_failed = False
         self.cuda_build_failed = False
@@ -1441,11 +1494,6 @@ def __init__(self, name, loader):
             "mode": warp.config.mode,
         }
 
-        # kernel hook lookup per device
-        # hooks are stored with the module so they can be easily cleared when the module is reloaded.
-        # -> See ``Module.get_kernel_hooks()``
-        self.kernel_hooks = {}
-
         # Module dependencies are determined by scanning each function
         # and kernel for references to external functions and structs.
         #
@@ -1685,27 +1733,26 @@ def hash_recursive(module, visited):
 
         return hash_recursive(self, visited=set())
 
-    def load(self, device) -> bool:
-        from warp.utils import ScopedTimer
-
-        device = get_device(device)
+    def load(self, device) -> ModuleExec:
+        device = runtime.get_device(device)
 
         if device.is_cpu:
             # check if already loaded
-            if self.cpu_module:
-                return True
+            if self.cpu_exec:
+                return self.cpu_exec
             # avoid repeated build attempts
             if self.cpu_build_failed:
-                return False
+                return None
            if not warp.is_cpu_available():
                 raise RuntimeError("Failed to build CPU module because no CPU buildchain was found")
         else:
             # check if already loaded
-            if device.context in self.cuda_modules:
-                return True
+            cuda_exec = self.cuda_execs.get(device.context)
+            if cuda_exec is not None:
+                return cuda_exec
             # avoid repeated build attempts
             if self.cuda_build_failed:
-                return False
+                return None
             if not warp.is_cuda_available():
                 raise RuntimeError("Failed to build CUDA module because CUDA is not available")
 
@@ -1715,7 +1762,7 @@ def load(self, device) -> bool:
         # use a unique module path using the module short hash
         module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
 
-        with ScopedTimer(
+        with warp.ScopedTimer(
             f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
         ) as module_load_timer:
             # -----------------------------------------------------------
@@ -1787,7 +1834,7 @@ def load(self, device) -> bool:
                 output_path = os.path.join(build_dir, output_name)
 
                 # build object code
-                with ScopedTimer("Compile x86", active=warp.config.verbose):
+                with warp.ScopedTimer("Compile x86", active=warp.config.verbose):
                     warp.build.build_cpu(
                         output_path,
                         source_code_path,
@@ -1815,7 +1862,7 @@ def load(self, device) -> bool:
                 output_path = os.path.join(build_dir, output_name)
 
                 # generate PTX or CUBIN
-                with ScopedTimer("Compile CUDA", active=warp.config.verbose):
+                with warp.ScopedTimer("Compile CUDA", active=warp.config.verbose):
                     warp.build.build_cuda(
                         source_code_path,
                         output_arch,
@@ -1868,12 +1915,14 @@ def load(self, device) -> bool:
             # Load CPU or CUDA binary
             if device.is_cpu:
                 runtime.llvm.load_obj(binary_path.encode("utf-8"), module_name.encode("utf-8"))
-                self.cpu_module = module_name
+                module_exec = ModuleExec(module_name, device)
+                self.cpu_exec = module_exec
             elif device.is_cuda:
                 cuda_module = warp.build.load_cuda(binary_path, device)
                 if cuda_module is not None:
-                    self.cuda_modules[device.context] = cuda_module
+                    module_exec = ModuleExec(cuda_module, device)
+                    self.cuda_execs[device.context] = module_exec
                 else:
                     module_load_timer.extra_msg = " (error)"
                     raise Exception(f"Failed to load CUDA module '{self.name}'")
 
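
Note: since load() now returns a ModuleExec (or None after a failed build)
instead of a bool, callers switch from testing a flag to holding the returned
reference; a sketch of the updated calling convention, assuming a kernel `k`:

    module_exec = k.module.load(device)
    if not module_exec:
        return  # the build failed; load() returns None rather than False
    hooks = module_exec.get_kernel_hooks(k)  # hooks come from the executable, not the Module
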
@@ -1884,65 +1933,27 @@ def load(self, device) -> bool:
             # clean up build_dir used for this process regardless
             shutil.rmtree(build_dir, ignore_errors=True)
 
-        return True
+        return module_exec
 
     def unload(self):
-        if self.cpu_module:
-            runtime.llvm.unload_obj(self.cpu_module.encode("utf-8"))
-            self.cpu_module = None
-
-        # need to unload the CUDA module from all CUDA contexts where it is loaded
-        # note: we ensure that this doesn't change the current CUDA context
-        if self.cuda_modules:
-            saved_context = runtime.core.cuda_context_get_current()
-            for context, module in self.cuda_modules.items():
-                device = runtime.context_map[context]
-                if device.is_capturing:
-                    raise RuntimeError(f"Failed to unload CUDA module '{self.name}' because graph capture is active")
-                runtime.core.cuda_unload_module(context, module)
-            runtime.core.cuda_context_set_current(saved_context)
-            self.cuda_modules = {}
-
-        # clear kernel hooks
-        self.kernel_hooks = {}
+        # clear loaded modules
+        self.cpu_exec = None
+        self.cuda_execs = {}
 
         # clear content hash
         self.content_hash = None
 
-    # lookup and cache kernel entry points based on name, called after compilation / module load
+    # lookup kernel entry points based on name, called after compilation / module load
     def get_kernel_hooks(self, kernel, device):
-        # get all hooks for this device
-        device_hooks = self.kernel_hooks.get(device.context)
-        if device_hooks is None:
-            self.kernel_hooks[device.context] = device_hooks = {}
-
-        # look up this kernel
-        hooks = device_hooks.get(kernel)
-        if hooks is not None:
-            return hooks
-
-        name = kernel.get_mangled_name()
-
-        if device.is_cpu:
-            func = ctypes.CFUNCTYPE(None)
-            forward = func(
-                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))
-            )
-            backward = func(
-                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))
-            )
+        if device.is_cuda:
+            module_exec = self.cuda_execs.get(device.context)
         else:
-            cu_module = self.cuda_modules[device.context]
-            forward = runtime.core.cuda_get_kernel(
-                device.context, cu_module, (name + "_cuda_kernel_forward").encode("utf-8")
-            )
-            backward = runtime.core.cuda_get_kernel(
-                device.context, cu_module, (name + "_cuda_kernel_backward").encode("utf-8")
-            )
+            module_exec = self.cpu_exec
 
-        hooks = KernelHooks(forward, backward)
-        device_hooks[kernel] = hooks
-        return hooks
+        if module_exec is not None:
+            return module_exec.get_kernel_hooks(kernel)
+        else:
+            raise RuntimeError(f"Module is not loaded on device {device}")
 
 
 # -------------------------------------------
@@ -2199,8 +2210,8 @@ def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
         self._stream = None
         self.null_stream = None
 
-        # set of streams where capture has started
-        self.captures = set()
+        # maps streams to started graph captures
+        self.captures = {}
 
         self.context_guard = ContextGuard(self)
@@ -2437,20 +2448,25 @@ def can_access(self, other):
 class Graph:
     def __new__(cls, *args, **kwargs):
         instance = super(Graph, cls).__new__(cls)
-        instance.exec = None
+        instance.graph_exec = None
         return instance
 
-    def __init__(self, device: Device, exec: ctypes.c_void_p):
+    def __init__(self, device: Device, capture_id: int):
         self.device = device
-        self.exec = exec
+        self.capture_id = capture_id
+        self.module_execs = set()
 
     def __del__(self):
-        if not self.exec:
+        if not self.graph_exec:
            return
 
         # use CUDA context guard to avoid side effects during garbage collection
         with self.device.context_guard:
-            runtime.core.cuda_graph_destroy(self.device.context, self.exec)
+            runtime.core.cuda_graph_destroy(self.device.context, self.graph_exec)
+
+    # retain executable CUDA modules used by this graph, which prevents them from being unloaded
+    def retain_module_exec(self, module_exec: ModuleExec):
+        self.module_execs.add(module_exec)
 
 
 class Runtime:
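
Note: a Graph now survives module reloads by holding strong references to the
executables it was captured from; a sketch of the intended lifetime, with
`graph` and `module_exec` standing in for live objects:

    graph.retain_module_exec(module_exec)  # the Graph keeps the executable alive
    del module_exec                        # Module.unload() likewise only drops its own reference
    # the compiled code is released (ModuleExec.__del__) only once the Graph
    # itself is garbage collected
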
@@ -2491,6 +2507,9 @@ def __init__(self):
         else:
             self.llvm = None
 
+        # maps capture ids to graphs
+        self.captures = {}
+
         # setup c-types for warp.dll
         try:
             self.core.get_error_string.argtypes = []
@@ -3026,6 +3045,8 @@ def __init__(self):
         self.core.cuda_stream_wait_stream.restype = None
         self.core.cuda_stream_is_capturing.argtypes = [ctypes.c_void_p]
         self.core.cuda_stream_is_capturing.restype = ctypes.c_int
+        self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
+        self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
 
         self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
         self.core.cuda_event_create.restype = ctypes.c_void_p
@@ -4493,13 +4514,14 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
     def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
+        # retain the module executable so it doesn't get unloaded
+        self.module_exec = kernel.module.load(device)
+        if not self.module_exec:
+            raise RuntimeError(f"Failed to load module {kernel.module.name} on device {device}")
+
         # if not specified look up hooks
         if not hooks:
-            module = kernel.module
-            if not module.load(device):
-                return
-
-            hooks = module.get_kernel_hooks(kernel, device)
+            hooks = self.module_exec.get_kernel_hooks(kernel)
 
         # if not specified set a zero bound
         if not bounds:
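
Note: Launch objects pin their module the same way; a usage sketch of the
record-and-replay path, assuming a kernel `k` and an array `a` already exist:

    launch = wp.launch(k, dim=1, inputs=[a], record_cmd=True)  # returns a Launch
    launch.launch()  # replay remains valid even if k's Module is reloaded in
                     # between, because the Launch references the original ModuleExec
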
@@ -4597,6 +4619,15 @@ def launch(self, stream=None) -> Any:
         else:
             if stream is None:
                 stream = self.device.stream
+
+            # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
+            # before the captured graph is released.
+            if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
+                capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
+                graph = runtime.captures.get(capture_id)
+                if graph is not None:
+                    graph.retain_module_exec(self.module_exec)
+
             runtime.core.cuda_launch_kernel(
                 self.device.context,
                 self.hooks.forward,
@@ -4692,12 +4723,12 @@ def pack_args(args, params, adjoint=False):
         kernel = kernel.add_overload(fwd_types)
 
     # delay load modules, including new overload if needed
-    module = kernel.module
-    if not module.load(device):
+    module_exec = kernel.module.load(device)
+    if not module_exec:
         return
 
     # late bind
-    hooks = module.get_kernel_hooks(kernel, device)
+    hooks = module_exec.get_kernel_hooks(kernel)
 
     pack_args(fwd_args, params)
     pack_args(adj_args, params, adjoint=True)
@@ -4733,6 +4764,14 @@ def pack_args(args, params, adjoint=False):
         if stream is None:
             stream = device.stream
 
+        # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
+        # before the captured graph is released.
+        if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
+            capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
+            graph = runtime.captures.get(capture_id)
+            if graph is not None:
+                graph.retain_module_exec(module_exec)
+
         if adjoint:
             if hooks.backward is None:
                 raise RuntimeError(
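
Note: both launch paths consult the global capture table, so every module
launched during capture ends up retained by the Graph being recorded; the
end-to-end flow, sketched with a hypothetical kernel `foo` and array `a`:

    wp.capture_begin()                 # creates a Graph keyed by the stream's capture id (below)
    wp.launch(foo, dim=1, inputs=[a])  # capturing, so foo's ModuleExec is added to the Graph
    graph = wp.capture_end()           # the Graph receives graph_exec and keeps the modules alive
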
@@ -5017,11 +5056,18 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
     if force_module_load:
         force_load(device)
 
-    device.captures.add(stream)
-
     if not runtime.core.cuda_graph_begin_capture(device.context, stream.cuda_stream, int(external)):
         raise RuntimeError(runtime.get_error_string())
 
+    capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
+    graph = Graph(device, capture_id)
+
+    # add to ongoing captures on the device
+    device.captures[stream] = graph
+
+    # add to lookup table by globally unique capture id
+    runtime.captures[capture_id] = graph
+
 
 
 def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
     """Ends the capture of a CUDA graph
@@ -5043,21 +5089,27 @@ def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
             raise RuntimeError("Must be a CUDA device")
         stream = device.stream
 
-    if stream not in device.captures:
+    # get the graph being captured
+    graph = device.captures.get(stream)
+
+    if graph is None:
         raise RuntimeError("Graph capture is not active on this stream")
 
-    device.captures.remove(stream)
+    del device.captures[stream]
+    del runtime.captures[graph.capture_id]
 
-    graph = ctypes.c_void_p()
-    result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(graph))
+    # get the graph executable
+    graph_exec = ctypes.c_void_p()
+    result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(graph_exec))
 
     if not result:
         # A concrete error should've already been reported, so we don't need to go into details here
         raise RuntimeError(f"CUDA graph capture failed. {runtime.get_error_string()}")
 
-    # note that for external captures, we do not return a graph, because we don't instantiate it ourselves
-    if graph:
-        return Graph(device, graph)
+    # set the graph executable
+    graph.graph_exec = graph_exec
+
+    return graph
 
 
 def capture_launch(graph: Graph, stream: Stream = None):
@@ -5076,7 +5128,7 @@ def capture_launch(graph: Graph, stream: Stream = None):
         device = graph.device
         stream = device.stream
 
-    if not runtime.core.cuda_graph_launch(graph.exec, stream.cuda_stream):
+    if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
         raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
 
 
diff --git a/warp/native/warp.cpp b/warp/native/warp.cpp
index b7ad19a30..8d51e83bc 100644
--- a/warp/native/warp.cpp
+++ b/warp/native/warp.cpp
@@ -1019,6 +1019,7 @@ WP_API void cuda_stream_synchronize(void* stream) {}
 WP_API void cuda_stream_wait_event(void* stream, void* event) {}
 WP_API void cuda_stream_wait_stream(void* stream, void* other_stream, void* event) {}
 WP_API int cuda_stream_is_capturing(void* stream) { return 0; }
+WP_API uint64_t cuda_stream_get_capture_id(void* stream) { return 0; }
 
 WP_API void* cuda_event_create(void* context, unsigned flags) { return NULL; }
 WP_API void cuda_event_destroy(void* event) {}
diff --git a/warp/native/warp.cu b/warp/native/warp.cu
index de19b2251..0c24862fe 100644
--- a/warp/native/warp.cu
+++ b/warp/native/warp.cu
@@ -2263,6 +2263,11 @@ int cuda_stream_is_capturing(void* stream)
     return int(status != cudaStreamCaptureStatusNone);
 }
 
+uint64_t cuda_stream_get_capture_id(void* stream)
+{
+    return get_capture_id(static_cast<CUstream>(stream));
+}
+
 void* cuda_event_create(void* context, unsigned flags)
 {
     ContextGuard guard(context, true);
diff --git a/warp/native/warp.h b/warp/native/warp.h
index 37e8c720d..a4be12400 100644
--- a/warp/native/warp.h
+++ b/warp/native/warp.h
@@ -295,6 +295,7 @@ extern "C"
     WP_API void cuda_stream_wait_event(void* stream, void* event);
     WP_API void cuda_stream_wait_stream(void* stream, void* other_stream, void* event);
     WP_API int cuda_stream_is_capturing(void* stream);
+    WP_API uint64_t cuda_stream_get_capture_id(void* stream);
 
     WP_API void* cuda_event_create(void* context, unsigned flags);
     WP_API void cuda_event_destroy(void* event);
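
Note: the capture id likely comes from the driver's capture-status query
(cuStreamGetCaptureInfo reports a sequence id that is unique over the lifetime
of the process), with the CPU stub returning 0 for "not capturing". From Python
it is reachable through the new ctypes binding; a sketch, assuming an active
capture on `stream`:

    if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
        capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
        graph = runtime.captures.get(capture_id)  # the Graph registered in capture_begin()
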
diff --git a/warp/tests/test_reload.py b/warp/tests/test_reload.py
index 626955dd2..d5550f172 100644
--- a/warp/tests/test_reload.py
+++ b/warp/tests/test_reload.py
@@ -189,7 +189,29 @@ def test_reload_references(test, device):
     test_dependent.run(expect=4.0, device=device)  # 2 * 2 = 4
 
 
+def test_graph_launch_after_module_reload(test, device):
+    @wp.kernel
+    def foo(a: wp.array(dtype=int)):
+        a[0] = 42
+
+    with wp.ScopedDevice(device):
+        a = wp.zeros(1, dtype=int)
+
+        # capture a launch
+        with wp.ScopedCapture() as capture:
+            wp.launch(foo, dim=1, inputs=[a])
+
+        # unload the module
+        foo.module.unload()
+
+        # launch previously captured graph
+        wp.capture_launch(capture.graph)
+
+        test.assertEqual(a.numpy()[0], 42)
+
+
 devices = get_test_devices()
+cuda_devices = get_cuda_test_devices()
 
 
 class TestReload(unittest.TestCase):
@@ -200,6 +222,9 @@ class TestReload(unittest.TestCase):
 add_function_test(TestReload, "test_reload", test_reload, devices=devices)
 add_function_test(TestReload, "test_reload_class", test_reload_class, devices=devices)
 add_function_test(TestReload, "test_reload_references", test_reload_references, devices=devices)
+add_function_test(
+    TestReload, "test_graph_launch_after_module_reload", test_graph_launch_after_module_reload, devices=cuda_devices
+)
 
 
 if __name__ == "__main__":