diff --git a/cuda_core/cuda/core/__init__.py b/cuda_core/cuda/core/__init__.py index dfd52accea3..fe1ba76806a 100644 --- a/cuda_core/cuda/core/__init__.py +++ b/cuda_core/cuda/core/__init__.py @@ -29,7 +29,15 @@ def _import_versioned_module(): from cuda.core import system, utils +from cuda.core._context import Context, ContextOptions from cuda.core._device import Device +from cuda.core._device_resources import ( + DeviceResources, + SMResource, + SMResourceOptions, + WorkqueueResource, + WorkqueueResourceOptions, +) from cuda.core._event import Event, EventOptions from cuda.core._graphics import GraphicsResource from cuda.core._launch_config import LaunchConfig diff --git a/cuda_core/cuda/core/_context.pxd b/cuda_core/cuda/core/_context.pxd index 9e1a460f50f..92fa5700a06 100644 --- a/cuda_core/cuda/core/_context.pxd +++ b/cuda_core/cuda/core/_context.pxd @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from cuda.core._resource_handles cimport ContextHandle +from cuda.core._resource_handles cimport ContextHandle, GreenCtxHandle cdef class Context: """Cython declaration for Context class. 
@@ -18,3 +18,8 @@ cdef class Context: @staticmethod cdef Context _from_handle(type cls, ContextHandle h_context, int device_id) + + @staticmethod + cdef Context _from_green_ctx(type cls, GreenCtxHandle h_green_ctx, int device_id) + + cpdef close(self) diff --git a/cuda_core/cuda/core/_context.pyx b/cuda_core/cuda/core/_context.pyx index b2b21465c81..225500c7093 100644 --- a/cuda_core/cuda/core/_context.pyx +++ b/cuda_core/cuda/core/_context.pyx @@ -2,18 +2,34 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Sequence from dataclasses import dataclass +from cuda.bindings cimport cydriver +from cuda.core._device_resources cimport DeviceResources, SMResource, WorkqueueResource +from cuda.core._device_resources import SMResource, WorkqueueResource from cuda.core._resource_handles cimport ( ContextHandle, + GreenCtxHandle, + as_cu, + create_context_handle_from_green_ctx, + get_context_green_ctx, + get_last_error, as_intptr, as_py, ) +from cuda.core._stream import Stream, StreamOptions +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN __all__ = ['Context', 'ContextOptions'] +DeviceResourcesT = Sequence[SMResource | WorkqueueResource] + + cdef class Context: """CUDA context wrapper. 
@@ -32,10 +48,21 @@ cdef class Context: ctx._device_id = device_id return ctx + @staticmethod + cdef Context _from_green_ctx(type cls, GreenCtxHandle h_green_ctx, int device_id): + """Create Context from an owning green context handle.""" + cdef ContextHandle h_context = create_context_handle_from_green_ctx(h_green_ctx) + if not h_context: + HANDLE_RETURN(get_last_error()) + raise RuntimeError("Failed to create CUDA context view from green context") + return Context._from_handle(cls, h_context, device_id) + @property def handle(self): """Return the underlying CUcontext handle.""" - if self._h_context.get() == NULL: + if not self._h_context: + return None + if as_cu(self._h_context) == NULL: return None return as_py(self._h_context) @@ -43,6 +70,66 @@ cdef class Context: def _handle(self): return self.handle + @property + def is_green(self) -> bool: + """True if this context was created from device resources.""" + if not self._h_context: + return False + return get_context_green_ctx(self._h_context).get() != NULL + + @property + def resources(self) -> DeviceResources: + """Query the hardware resources provisioned for this context. + + For green contexts, returns the resources this context was created + with (SM partition, workqueue config). For primary contexts, returns + the full device resources. + + Raises :class:`RuntimeError` if the context has been closed. + """ + if not self._h_context: + raise RuntimeError("Cannot query resources on a closed context") + return DeviceResources._init_from_ctx(self._h_context, self._device_id) + + def create_stream(self, options: StreamOptions | None = None): + """Create a new stream bound to this green context. + + This method is only available on green contexts. For primary + contexts, use :meth:`Device.create_stream` instead. + + Parameters + ---------- + options : :obj:`~_stream.StreamOptions`, optional + Customizable dataclass for stream creation options. 
+ + Returns + ------- + :obj:`~_stream.Stream` + Newly created stream object. + """ + if not self._h_context: + raise RuntimeError("Cannot create a stream on a closed context") + if not self.is_green: + raise RuntimeError( + "Context.create_stream() is only supported on green contexts. " + "Use Device.create_stream() for primary contexts." + ) + + return Stream._init(options=options, device_id=self._device_id, ctx=self) + + cpdef close(self): + """Release this context wrapper's underlying CUDA handles.""" + cdef cydriver.CUcontext current_ctx + if self._h_context and as_cu(self._h_context) != NULL: + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current_ctx)) + if current_ctx == as_cu(self._h_context): + raise RuntimeError( + "Cannot close a CUDA context while it is current. " + "Restore a previous context before closing this context." + ) + self._h_context.reset() + def __eq__(self, other): if not isinstance(other, Context): return NotImplemented @@ -57,9 +144,12 @@ cdef class Context: @dataclass -class ContextOptions: +cdef class ContextOptions: """Options for context creation. - Currently unused, reserved for future use. + Attributes + ---------- + resources : :obj:`~cuda.core.typing.DeviceResourcesT` + Device resources used to create a green context. """ - pass # TODO + resources: DeviceResourcesT diff --git a/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md b/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md index cbfc609686b..089f98acd93 100644 --- a/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md +++ b/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md @@ -29,7 +29,8 @@ carries timing/IPC flags, `KernelBox` carries the library dependency). Without this level, a round-tripped handle would produce a new Box with default metadata, losing information that was set at creation. -Instances: `event_registry`, `kernel_registry`, `graph_node_registry`. +Instances: `context_registry`, `stream_registry`, `event_registry`, +`kernel_registry`, `graph_node_registry`. 
## Level 2: Resource Handle -> Python Object (Cython) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index a21cd8a8aa5..2413d9473c7 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -29,6 +29,12 @@ namespace cuda_core { decltype(&cuDevicePrimaryCtxRetain) p_cuDevicePrimaryCtxRetain = nullptr; decltype(&cuDevicePrimaryCtxRelease) p_cuDevicePrimaryCtxRelease = nullptr; decltype(&cuCtxGetCurrent) p_cuCtxGetCurrent = nullptr; +decltype(&cuGreenCtxCreate) p_cuGreenCtxCreate = nullptr; +decltype(&cuGreenCtxDestroy) p_cuGreenCtxDestroy = nullptr; +decltype(&cuCtxFromGreenCtx) p_cuCtxFromGreenCtx = nullptr; +decltype(&cuDevResourceGenerateDesc) p_cuDevResourceGenerateDesc = nullptr; + +decltype(&cuGreenCtxStreamCreate) p_cuGreenCtxStreamCreate = nullptr; decltype(&cuStreamCreateWithPriority) p_cuStreamCreateWithPriority = nullptr; decltype(&cuStreamDestroy) p_cuStreamDestroy = nullptr; @@ -223,12 +229,112 @@ void clear_last_error() noexcept { namespace { struct ContextBox { CUcontext resource; + GreenCtxHandle h_green_ctx; +}; + +struct GreenCtxBox { + CUgreenCtx resource; }; + +static const ContextBox* get_box(const ContextHandle& h) noexcept { + const CUcontext* p = h.get(); + return reinterpret_cast( + reinterpret_cast(p) - offsetof(ContextBox, resource) + ); +} + +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) +static HandleRegistry context_registry; } // namespace ContextHandle create_context_handle_ref(CUcontext ctx) { - auto box = std::make_shared(ContextBox{ctx}); - return ContextHandle(box, &box->resource); + if (!ctx) { + return {}; + } + if (auto h = context_registry.lookup(ctx)) { + return h; + } + auto box = std::shared_ptr( + new ContextBox{ctx, {}}, + [](const ContextBox* b) { + context_registry.unregister_handle(b->resource); + delete b; + } + ); + ContextHandle h(box, &box->resource); + 
context_registry.register_handle(ctx, h); + return h; +} + +ContextHandle create_context_handle_from_green_ctx(const GreenCtxHandle& h_green_ctx) { + if (!h_green_ctx) { + return {}; + } + if (!p_cuCtxFromGreenCtx) { + err = CUDA_ERROR_NOT_SUPPORTED; + return {}; + } + + GILReleaseGuard gil; + CUcontext ctx = nullptr; + if (CUDA_SUCCESS != (err = p_cuCtxFromGreenCtx(&ctx, as_cu(h_green_ctx)))) { + return {}; + } + + auto box = std::shared_ptr( + new ContextBox{ctx, h_green_ctx}, + [](const ContextBox* b) { + context_registry.unregister_handle(b->resource); + delete b; + } + ); + ContextHandle h(box, &box->resource); + context_registry.register_handle(ctx, h); + return h; +} + +GreenCtxHandle get_context_green_ctx(const ContextHandle& h) noexcept { + if (!h) { + return {}; + } + return get_box(h)->h_green_ctx; +} + +GreenCtxHandle create_green_ctx_handle(CUdevResource* resources, unsigned int nbResources, + CUdevice dev, unsigned int flags) { + if (!p_cuDevResourceGenerateDesc || !p_cuGreenCtxCreate || !p_cuGreenCtxDestroy) { + err = CUDA_ERROR_NOT_SUPPORTED; + return {}; + } + + GILReleaseGuard gil; + CUdevResourceDesc desc = nullptr; + if (CUDA_SUCCESS != (err = p_cuDevResourceGenerateDesc(&desc, resources, nbResources))) { + return {}; + } + + CUgreenCtx green_ctx = nullptr; + if (CUDA_SUCCESS != (err = p_cuGreenCtxCreate(&green_ctx, desc, dev, flags))) { + return {}; + } + + auto box = std::shared_ptr( + new GreenCtxBox{green_ctx}, + [](const GreenCtxBox* b) { + GILReleaseGuard gil; + p_cuGreenCtxDestroy(b->resource); + delete b; + } + ); + return GreenCtxHandle(box, &box->resource); +} + +GreenCtxHandle create_green_ctx_handle_ref(CUgreenCtx green_ctx) { + if (!green_ctx) { + return {}; + } + auto box = std::make_shared(GreenCtxBox{green_ctx}); + return GreenCtxHandle(box, &box->resource); } // Thread-local cache of primary contexts indexed by device ID @@ -250,14 +356,16 @@ ContextHandle get_primary_context(int device_id) { } auto box = std::shared_ptr( - new 
ContextBox{ctx}, + new ContextBox{ctx, {}}, [device_id](const ContextBox* b) { + context_registry.unregister_handle(b->resource); GILReleaseGuard gil; p_cuDevicePrimaryCtxRelease(device_id); delete b; } ); auto h = ContextHandle(box, &box->resource); + context_registry.register_handle(ctx, h); // Update cache if (static_cast(device_id) >= primary_context_cache.size()) { @@ -286,33 +394,79 @@ ContextHandle get_current_context() { namespace { struct StreamBox { CUstream resource; + ContextHandle h_context; }; + +static const StreamBox* get_box(const StreamHandle& h) noexcept { + const CUstream* p = h.get(); + return reinterpret_cast( + reinterpret_cast(p) - offsetof(StreamBox, resource) + ); +} + +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) +static HandleRegistry stream_registry; } // namespace StreamHandle create_stream_handle(const ContextHandle& h_ctx, unsigned int flags, int priority) { GILReleaseGuard gil; CUstream stream; - if (CUDA_SUCCESS != (err = p_cuStreamCreateWithPriority(&stream, flags, priority))) { - return {}; + + // Dispatch: green context uses cuGreenCtxStreamCreate, primary uses cuStreamCreateWithPriority + GreenCtxHandle h_green = get_context_green_ctx(h_ctx); + if (h_green) { + if (!p_cuGreenCtxStreamCreate) { + err = CUDA_ERROR_NOT_SUPPORTED; + return {}; + } + if (CUDA_SUCCESS != (err = p_cuGreenCtxStreamCreate(&stream, as_cu(h_green), flags, priority))) { + return {}; + } + } else { + if (CUDA_SUCCESS != (err = p_cuStreamCreateWithPriority(&stream, flags, priority))) { + return {}; + } } auto box = std::shared_ptr( - new StreamBox{stream}, - [h_ctx](const StreamBox* b) { + new StreamBox{stream, h_ctx}, + [](const StreamBox* b) { + stream_registry.unregister_handle(b->resource); GILReleaseGuard gil; p_cuStreamDestroy(b->resource); delete b; } ); - return StreamHandle(box, &box->resource); + StreamHandle h(box, &box->resource); + stream_registry.register_handle(stream, h); + return h; } StreamHandle 
create_stream_handle_ref(CUstream stream) { - auto box = std::make_shared(StreamBox{stream}); - return StreamHandle(box, &box->resource); + if (auto h = stream_registry.lookup(stream)) { + return h; + } + auto box = std::shared_ptr( + new StreamBox{stream, {}}, + [](const StreamBox* b) { + stream_registry.unregister_handle(b->resource); + delete b; + } + ); + StreamHandle h(box, &box->resource); + stream_registry.register_handle(stream, h); + return h; } StreamHandle create_stream_handle_with_owner(CUstream stream, PyObject* owner) { + if (auto h = stream_registry.lookup(stream)) { + // Reuse handles that already carry structural context metadata, e.g. + // cuda-core-owned streams. Owner-backed foreign streams still need a + // fresh handle so the supplied owner is retained. + if (get_box(h)->h_context) { + return h; + } + } if (!owner) { return create_stream_handle_ref(stream); } @@ -324,8 +478,9 @@ StreamHandle create_stream_handle_with_owner(CUstream stream, PyObject* owner) { } Py_INCREF(owner); auto box = std::shared_ptr( - new StreamBox{stream}, + new StreamBox{stream, {}}, [owner](const StreamBox* b) { + stream_registry.unregister_handle(b->resource); GILAcquireGuard gil; if (gil.acquired()) { Py_DECREF(owner); @@ -333,7 +488,13 @@ StreamHandle create_stream_handle_with_owner(CUstream stream, PyObject* owner) { delete b; } ); - return StreamHandle(box, &box->resource); + StreamHandle h(box, &box->resource); + stream_registry.register_handle(stream, h); + return h; +} + +ContextHandle get_stream_context(const StreamHandle& h) noexcept { + return h ? 
get_box(h)->h_context : ContextHandle{}; } StreamHandle get_legacy_stream() { diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index d63fb869973..73d3364ba5f 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -59,6 +59,12 @@ void clear_last_error() noexcept; extern decltype(&cuDevicePrimaryCtxRetain) p_cuDevicePrimaryCtxRetain; extern decltype(&cuDevicePrimaryCtxRelease) p_cuDevicePrimaryCtxRelease; extern decltype(&cuCtxGetCurrent) p_cuCtxGetCurrent; +extern decltype(&cuGreenCtxCreate) p_cuGreenCtxCreate; +extern decltype(&cuGreenCtxDestroy) p_cuGreenCtxDestroy; +extern decltype(&cuCtxFromGreenCtx) p_cuCtxFromGreenCtx; +extern decltype(&cuDevResourceGenerateDesc) p_cuDevResourceGenerateDesc; + +extern decltype(&cuGreenCtxStreamCreate) p_cuGreenCtxStreamCreate; extern decltype(&cuStreamCreateWithPriority) p_cuStreamCreateWithPriority; extern decltype(&cuStreamDestroy) p_cuStreamDestroy; @@ -142,6 +148,7 @@ extern NvJitLinkDestroyFn p_nvJitLinkDestroy; // ============================================================================ using ContextHandle = std::shared_ptr; +using GreenCtxHandle = std::shared_ptr; using StreamHandle = std::shared_ptr; using EventHandle = std::shared_ptr; using MemoryPoolHandle = std::shared_ptr; @@ -164,6 +171,21 @@ using FileDescriptorHandle = std::shared_ptr; // Function to create a non-owning context handle (references existing context). ContextHandle create_context_handle_ref(CUcontext ctx); +// Create a context handle for the CUcontext view of the provided green context. +// The returned ContextHandle keeps the green context alive, but the CUcontext +// view is non-owning and is not destroyed independently. +ContextHandle create_context_handle_from_green_ctx(const GreenCtxHandle& h_green_ctx); + +// Return the green context dependency associated with a ContextHandle, if any. 
+GreenCtxHandle get_context_green_ctx(const ContextHandle& h) noexcept; + +// Create an owning green context handle from a list of device resources. +GreenCtxHandle create_green_ctx_handle(CUdevResource* resources, unsigned int nbResources, + CUdevice dev, unsigned int flags); + +// Create a non-owning green context handle. +GreenCtxHandle create_green_ctx_handle_ref(CUgreenCtx ctx); + // Get handle to the primary context for a device (with thread-local caching) // Returns empty handle on error (caller must check) ContextHandle get_primary_context(int device_id); @@ -193,6 +215,9 @@ StreamHandle create_stream_handle_ref(CUstream stream); // The owner is responsible for keeping the stream's context alive. StreamHandle create_stream_handle_with_owner(CUstream stream, PyObject* owner); +// Return the context dependency associated with a stream handle, if any. +ContextHandle get_stream_context(const StreamHandle& h) noexcept; + // Get non-owning handle to the legacy default stream (CU_STREAM_LEGACY) // Note: Legacy stream has no specific context dependency. StreamHandle get_legacy_stream(); @@ -501,6 +526,10 @@ inline CUcontext as_cu(const ContextHandle& h) noexcept { return h ? *h : nullptr; } +inline CUgreenCtx as_cu(const GreenCtxHandle& h) noexcept { + return h ? *h : nullptr; +} + inline CUstream as_cu(const StreamHandle& h) noexcept { return h ? 
*h : nullptr; } @@ -559,6 +588,10 @@ inline std::intptr_t as_intptr(const ContextHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } +inline std::intptr_t as_intptr(const GreenCtxHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + inline std::intptr_t as_intptr(const StreamHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } @@ -649,6 +682,10 @@ inline PyObject* as_py(const ContextHandle& h) noexcept { return detail::make_py("cuda.bindings.driver", "CUcontext", as_intptr(h)); } +inline PyObject* as_py(const GreenCtxHandle& h) noexcept { + return detail::make_py("cuda.bindings.driver", "CUgreenCtx", as_intptr(h)); +} + inline PyObject* as_py(const StreamHandle& h) noexcept { return detail::make_py("cuda.bindings.driver", "CUstream", as_intptr(h)); } diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index 1ea2df564c4..32b96acb99d 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -7,19 +7,24 @@ from __future__ import annotations cimport cpython from cuda.bindings cimport cydriver -from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from cuda.core._utils.cuda_utils cimport check_or_create_options, HANDLE_RETURN +from libc.stdlib cimport free, malloc import threading from cuda.core._context cimport Context from cuda.core._context import ContextOptions +from cuda.core._device_resources cimport DeviceResources, SMResource, WorkqueueResource from cuda.core._event cimport Event as cyEvent from cuda.core._event import Event, EventOptions from cuda.core._memory._buffer cimport Buffer, MemoryResource from cuda.core._resource_handles cimport ( ContextHandle, + GreenCtxHandle, create_context_handle_ref, + create_green_ctx_handle, get_primary_context, + get_last_error, as_cu, ) @@ -954,7 +959,16 @@ class Device: Default value of `None` return the currently used device. 
""" - __slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context", "__weakref__") + __slots__ = ( + "_device_id", + "_memory_resource", + "_has_inited", + "_properties", + "_resources", + "_uuid", + "_context", + "__weakref__", + ) def __new__(cls, device_id: Device | int | None = None): if isinstance(device_id, Device): @@ -1100,6 +1114,13 @@ class Device: return self._properties + @property + def resources(self) -> DeviceResources: + """Return the hardware resource query namespace for this device.""" + if self._resources is None: + self._resources = DeviceResources._init(self._device_id) + return self._resources + @property def compute_capability(self) -> ComputeCapability: """Return a named tuple with 2 fields: major and minor.""" @@ -1219,6 +1240,7 @@ class Device: """ cdef ContextHandle h_context cdef cydriver.CUcontext prev_ctx, curr_ctx + cdef Context prev_owned = None if ctx is not None: # TODO: revisit once Context is cythonized @@ -1228,6 +1250,8 @@ class Device: "the provided context was created on the device with" f" id={ctx._device_id}, which is different from the target id={self._device_id}" ) + if self._has_inited and self._context is not None: + prev_owned = self._context # prev_ctx is the previous context curr_ctx = as_cu(ctx._h_context) prev_ctx = NULL @@ -1237,6 +1261,8 @@ class Device: self._has_inited = True self._context = ctx # Store owning context reference if prev_ctx != NULL: + if prev_owned is not None and as_cu(prev_owned._h_context) == prev_ctx: + return prev_owned return Context._from_handle(Context, create_context_handle_ref(prev_ctx), self._device_id) else: # use primary ctx @@ -1266,7 +1292,63 @@ class Device: Newly created context object. 
""" - raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189") + cdef int n_resources + cdef int i + cdef object resources + cdef object res + cdef SMResource sm_res + cdef WorkqueueResource wq_res + cdef cydriver.CUdevResource* c_resources = NULL + cdef GreenCtxHandle h_green + + if options is None: + raise ValueError( + "options with device resources must be provided to create a green context" + ) + + options = check_or_create_options(ContextOptions, options, "Context options") + if options.resources is None: + raise ValueError( + "ContextOptions.resources must be provided to create a green context" + ) + + resources = tuple(options.resources) + if len(resources) == 0: + raise ValueError("ContextOptions.resources must not be empty") + + n_resources = len(resources) + c_resources = malloc( + n_resources * sizeof(cydriver.CUdevResource) + ) + if c_resources == NULL: + raise MemoryError() + + try: + for i, res in enumerate(resources): + if isinstance(res, SMResource): + sm_res = res + if not sm_res._is_usable: + raise ValueError("dry-run SMResource objects cannot be used to create a context") + c_resources[i] = sm_res._resource + elif isinstance(res, WorkqueueResource): + wq_res = res + c_resources[i] = wq_res._wq_config_resource + else: + raise TypeError(f"Unsupported context resource type: {type(res)}") + + h_green = create_green_ctx_handle( + c_resources, + (n_resources), + (self._device_id), + (cydriver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM), + ) + if h_green.get() == NULL: + HANDLE_RETURN(get_last_error()) + raise RuntimeError("Failed to create CUDA green context") + + return Context._from_green_ctx(Context, h_green, self._device_id) + finally: + free(c_resources) def create_stream(self, obj: IsStreamT | None = None, options: StreamOptions | None = None) -> Stream: """Create a Stream object. 
@@ -1429,6 +1511,7 @@ cdef inline list Device_ensure_tls_devices(cls): device._memory_resource = None device._has_inited = False device._properties = None + device._resources = None device._uuid = None device._context = None devices.append(device) diff --git a/cuda_core/cuda/core/_device_resources.pxd b/cuda_core/cuda/core/_device_resources.pxd new file mode 100644 index 00000000000..d618c24cf10 --- /dev/null +++ b/cuda_core/cuda/core/_device_resources.pxd @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport ContextHandle, GreenCtxHandle + + +cdef class SMResource: + cdef: + cydriver.CUdevResource _resource + unsigned int _sm_count + unsigned int _min_partition_size + unsigned int _coscheduled_alignment + unsigned int _flags + bint _is_usable + object __weakref__ + + @staticmethod + cdef SMResource _from_dev_resource(cydriver.CUdevResource res, int device_id) + + @staticmethod + cdef SMResource _from_split_resource(cydriver.CUdevResource res, SMResource parent, bint is_usable) + + +cdef class WorkqueueResource: + cdef: + cydriver.CUdevResource _wq_config_resource + cydriver.CUdevResource _wq_resource + object __weakref__ + + @staticmethod + cdef WorkqueueResource _from_dev_resources( + cydriver.CUdevResource wq_config, + cydriver.CUdevResource wq, + ) + + +cdef class DeviceResources: + cdef: + int _device_id + ContextHandle _h_context # NULL for device-level queries + object __weakref__ + + @staticmethod + cdef DeviceResources _init(int device_id) + + @staticmethod + cdef DeviceResources _init_from_ctx(ContextHandle h_context, int device_id) + + cdef inline int _query_sm(self, cydriver.CUdevResource* res) except?-1 nogil diff --git a/cuda_core/cuda/core/_device_resources.pyx b/cuda_core/cuda/core/_device_resources.pyx new file mode 100644 index 00000000000..e4851a73a41 --- 
/dev/null +++ b/cuda_core/cuda/core/_device_resources.pyx @@ -0,0 +1,663 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Sequence as SequenceABC +from dataclasses import dataclass + +from libc.stdint cimport intptr_t +from libc.stdlib cimport free, malloc +from libc.string cimport memset + +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport ContextHandle, GreenCtxHandle, as_cu, get_context_green_ctx +from cuda.core._utils.cuda_utils cimport check_or_create_options, HANDLE_RETURN +from cuda.core._utils.cuda_utils import is_sequence +from cuda.core._utils.version cimport cy_binding_version, cy_driver_version + + +__all__ = [ + "DeviceResources", + "SMResource", + "SMResourceOptions", + "WorkqueueResource", + "WorkqueueResourceOptions", +] + + +# Module-level cached version checks (trinary: 0=unchecked, 1=supported, -1=unsupported) +cdef int _green_ctx_checked = 0 +cdef int _workqueue_checked = 0 +cdef str _green_ctx_err_msg = "" +cdef str _workqueue_err_msg = "" + + +cdef inline int _check_green_ctx_support() except?-1: + global _green_ctx_checked, _green_ctx_err_msg + if _green_ctx_checked == 1: + return 0 + if _green_ctx_checked == -1: + raise RuntimeError(_green_ctx_err_msg) + cdef tuple drv = cy_driver_version() + cdef tuple bind = cy_binding_version() + if drv < (12, 4, 0): + _green_ctx_err_msg = ( + "Green context support requires CUDA driver 12.4 or newer " + f"(current driver: {'.'.join(map(str, drv))})" + ) + _green_ctx_checked = -1 + raise RuntimeError(_green_ctx_err_msg) + if bind < (12, 4, 0): + _green_ctx_err_msg = ( + "Green context support requires cuda.bindings 12.4 or newer " + f"(current bindings: {'.'.join(map(str, bind))})" + ) + _green_ctx_checked = -1 + raise RuntimeError(_green_ctx_err_msg) + _green_ctx_checked = 1 + return 0 + + +cdef inline int 
_check_workqueue_support() except?-1: + global _workqueue_checked, _workqueue_err_msg + if _workqueue_checked == 1: + return 0 + if _workqueue_checked == -1: + raise RuntimeError(_workqueue_err_msg) + cdef tuple drv = cy_driver_version() + cdef tuple bind = cy_binding_version() + if drv < (13, 1, 0): + _workqueue_err_msg = ( + "WorkqueueResource requires CUDA driver 13.1 or newer " + f"(current driver: {'.'.join(map(str, drv))})" + ) + _workqueue_checked = -1 + raise RuntimeError(_workqueue_err_msg) + if bind < (13, 1, 0): + _workqueue_err_msg = ( + "WorkqueueResource requires cuda.bindings 13.1 or newer " + f"(current bindings: {'.'.join(map(str, bind))})" + ) + _workqueue_checked = -1 + raise RuntimeError(_workqueue_err_msg) + _workqueue_checked = 1 + return 0 + + +@dataclass +cdef class SMResourceOptions: + """Customizable :obj:`SMResource.split` options. + + Each field accepts a scalar (for a single group) or a ``Sequence`` + (for multiple groups). ``count`` drives the number of groups; other + ``Sequence`` fields must match its length. + + Attributes + ---------- + count : int or Sequence[int], optional + Requested SM count per group. ``None`` means discovery mode + (auto-detect). (Default to ``None``) + coscheduled_sm_count : int or Sequence[int], optional + Minimum number of SMs guaranteed to be co-scheduled in each + group. (Default to ``None``) + preferred_coscheduled_sm_count : int or Sequence[int], optional + Preferred co-scheduled SM count; the driver tries to satisfy + this but may fall back to ``coscheduled_sm_count``. + (Default to ``None``) + """ + + count: int | SequenceABC | None = None + coscheduled_sm_count: int | SequenceABC | None = None + preferred_coscheduled_sm_count: int | SequenceABC | None = None + + +@dataclass +cdef class WorkqueueResourceOptions: + """Customizable :obj:`WorkqueueResource.configure` options. + + Attributes + ---------- + sharing_scope : str, optional + Workqueue sharing scope. 
Accepted values: ``"device_ctx"`` + or ``"green_ctx_balanced"``. (Default to ``None``) + """ + + sharing_scope: str | None = None + + +cdef inline int _validate_split_field_length( + object value, str field_name, int n_groups, bint count_is_scalar +) except?-1: + if count_is_scalar: + if is_sequence(value): + raise ValueError( + f"{field_name} is a Sequence but count is scalar; " + "count must be a Sequence to specify multiple groups" + ) + elif is_sequence(value) and len(value) != n_groups: + raise ValueError( + f"{field_name} has length {len(value)}, expected {n_groups} " + "(must match count)" + ) + return 0 + + +cdef inline int _resolve_group_count(SMResourceOptions options) except?-1: + cdef object count = options.count + cdef int n_groups + cdef bint count_is_scalar + + if count is None or isinstance(count, int): + n_groups = 1 + count_is_scalar = True + elif is_sequence(count): + n_groups = len(count) + if n_groups == 0: + raise ValueError("count sequence must not be empty") + count_is_scalar = False + else: + raise TypeError(f"count must be int, Sequence, or None, got {type(count)}") + + _validate_split_field_length( + options.coscheduled_sm_count, + "coscheduled_sm_count", + n_groups, + count_is_scalar, + ) + _validate_split_field_length( + options.preferred_coscheduled_sm_count, + "preferred_coscheduled_sm_count", + n_groups, + count_is_scalar, + ) + return n_groups + + +cdef inline object _broadcast_field(object value, int n_groups): + if is_sequence(value): + return list(value) + return [value] * n_groups + + +cdef inline unsigned int _to_sm_count(object value) except? 0: + """Convert a count value to unsigned int. None maps to 0 (discovery).""" + if value is None: + return 0 + if value < 0: + raise ValueError(f"count must be non-negative, got {value}") + return (value) + + +cdef int _structured_split_checked = 0 + +cdef inline bint _can_use_structured_sm_split(): + """Check if cuDevSmResourceSplit (13.1+) is available. 
Cached.""" + global _structured_split_checked + if _structured_split_checked != 0: + return _structured_split_checked == 1 + IF CUDA_CORE_BUILD_MAJOR >= 13: + if cy_driver_version() >= (13, 1, 0) and cy_binding_version() >= (13, 1, 0): + _structured_split_checked = 1 + return True + _structured_split_checked = -1 + return False + + +cdef object _resolve_split_by_count_request(SMResourceOptions options): + cdef int n_groups = _resolve_group_count(options) + cdef list counts = _broadcast_field(options.count, n_groups) + cdef object first = counts[0] + cdef object value + cdef unsigned int min_count + + if options.coscheduled_sm_count is not None: + raise RuntimeError( + "SMResourceOptions.coscheduled_sm_count requires the CUDA 13.1 " + "structured SM split API" + ) + if options.preferred_coscheduled_sm_count is not None: + raise RuntimeError( + "SMResourceOptions.preferred_coscheduled_sm_count requires the " + "CUDA 13.1 structured SM split API" + ) + + for value in counts[1:]: + if value != first: + raise RuntimeError( + "CUDA 12 SM splitting only supports homogeneous count values; " + "use CUDA 13.1 or newer for per-group counts" + ) + + min_count = _to_sm_count(first) + return n_groups, min_count + + +IF CUDA_CORE_BUILD_MAJOR >= 13: + cdef inline int _fill_group_params( + cydriver.CU_DEV_SM_RESOURCE_GROUP_PARAMS* params, + int n_groups, + SMResourceOptions options, + ) except?-1: + cdef list counts = _broadcast_field(options.count, n_groups) + cdef list coscheduled = _broadcast_field(options.coscheduled_sm_count, n_groups) + cdef list preferred = _broadcast_field(options.preferred_coscheduled_sm_count, n_groups) + cdef int i + + for i in range(n_groups): + memset(¶ms[i], 0, sizeof(cydriver.CU_DEV_SM_RESOURCE_GROUP_PARAMS)) + params[i].smCount = _to_sm_count(counts[i]) + if coscheduled[i] is not None: + params[i].coscheduledSmCount = (coscheduled[i]) + if preferred[i] is not None: + params[i].preferredCoscheduledSmCount = (preferred[i]) + params[i].flags = 0 + 
return 0 + + + cdef object _split_with_general_api(SMResource sm, SMResourceOptions options, bint dry_run): + cdef int n_groups = _resolve_group_count(options) + cdef cydriver.CUdevResource* result = NULL + cdef cydriver.CUdevResource remaining + cdef cydriver.CUdevResource synth + cdef cydriver.CU_DEV_SM_RESOURCE_GROUP_PARAMS* params = NULL + cdef list groups = [] + cdef int i + + params = malloc( + n_groups * sizeof(cydriver.CU_DEV_SM_RESOURCE_GROUP_PARAMS) + ) + if params == NULL: + raise MemoryError() + + try: + _fill_group_params(params, n_groups, options) + + if not dry_run: + result = malloc( + n_groups * sizeof(cydriver.CUdevResource) + ) + if result == NULL: + raise MemoryError() + + memset(&remaining, 0, sizeof(cydriver.CUdevResource)) + with nogil: + HANDLE_RETURN(cydriver.cuDevSmResourceSplit( + result, + (n_groups), + &sm._resource, + &remaining, + 0, + params, + )) + + if result != NULL: + for i in range(n_groups): + groups.append(SMResource._from_split_resource(result[i], sm, True)) + return groups, SMResource._from_split_resource(remaining, sm, True) + + for i in range(n_groups): + memset(&synth, 0, sizeof(cydriver.CUdevResource)) + synth.type = cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM + synth.sm.smCount = params[i].smCount + groups.append(SMResource._from_split_resource(synth, sm, False)) + return groups, SMResource._from_split_resource(remaining, sm, False) + finally: + if params != NULL: + free(params) + if result != NULL: + free(result) +ELSE: + cdef object _split_with_general_api(SMResource sm, SMResourceOptions options, bint dry_run): + raise RuntimeError( + "SMResource.split() requires cuda.core to be built with CUDA 13.x bindings" + ) + + +cdef object _split_with_count_api(SMResource sm, SMResourceOptions options, bint dry_run): + cdef object request = _resolve_split_by_count_request(options) + cdef unsigned int nb_groups = (request[0]) + cdef unsigned int min_count = (request[1]) + cdef unsigned int actual_groups = nb_groups + 
cdef cydriver.CUdevResource* result = NULL + cdef cydriver.CUdevResource remaining + cdef list groups = [] + cdef int i + + result = malloc(nb_groups * sizeof(cydriver.CUdevResource)) + if result == NULL: + raise MemoryError() + + try: + memset(&remaining, 0, sizeof(cydriver.CUdevResource)) + with nogil: + HANDLE_RETURN(cydriver.cuDevSmResourceSplitByCount( + result, + &actual_groups, + &sm._resource, + &remaining, + 0, + min_count, + )) + + for i in range(actual_groups): + if dry_run: + groups.append(SMResource._from_split_resource(result[i], sm, False)) + else: + groups.append(SMResource._from_split_resource(result[i], sm, True)) + if dry_run: + return groups, SMResource._from_split_resource(remaining, sm, False) + return groups, SMResource._from_split_resource(remaining, sm, True) + finally: + free(result) + + +cdef inline unsigned int _sm_resource_granularity(int device_id) except? 0: + cdef int major + + with nogil: + HANDLE_RETURN(cydriver.cuDeviceGetAttribute( + &major, + cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + (device_id), + )) + if major >= 9: + return 8 + return 2 + + +cdef inline unsigned int _fallback_if_zero(unsigned int value, unsigned int fallback) noexcept: + if value != 0: + return value + return fallback + + +cdef class SMResource: + """Represent an SM (streaming multiprocessor) resource partition. + + Instances are returned by :obj:`DeviceResources.sm` or + :meth:`SMResource.split` and cannot be instantiated directly. + """ + + def __init__(self, *args, **kwargs): + raise RuntimeError( + "SMResource cannot be instantiated directly. " + "Use dev.resources.sm or SMResource.split()." 
+ ) + + @staticmethod + cdef SMResource _from_dev_resource(cydriver.CUdevResource res, int device_id): + cdef SMResource self = SMResource.__new__(SMResource) + self._resource = res + self._sm_count = res.sm.smCount + IF CUDA_CORE_BUILD_MAJOR >= 13: + self._min_partition_size = res.sm.minSmPartitionSize + self._coscheduled_alignment = res.sm.smCoscheduledAlignment + self._flags = res.sm.flags + ELSE: + self._min_partition_size = _sm_resource_granularity(device_id) + self._coscheduled_alignment = self._min_partition_size + self._flags = 0 + self._is_usable = True + return self + + @staticmethod + cdef SMResource _from_split_resource(cydriver.CUdevResource res, SMResource parent, bint is_usable): + cdef SMResource self = SMResource.__new__(SMResource) + self._resource = res + self._sm_count = res.sm.smCount + IF CUDA_CORE_BUILD_MAJOR >= 13: + self._min_partition_size = _fallback_if_zero( + res.sm.minSmPartitionSize, + parent._min_partition_size, + ) + self._coscheduled_alignment = _fallback_if_zero( + res.sm.smCoscheduledAlignment, + parent._coscheduled_alignment, + ) + self._flags = res.sm.flags + ELSE: + self._min_partition_size = parent._min_partition_size + self._coscheduled_alignment = parent._coscheduled_alignment + self._flags = parent._flags + self._is_usable = is_usable + return self + + @property + def handle(self) -> int: + """Return the address of the underlying ``CUdevResource`` struct.""" + return (&self._resource) + + @property + def sm_count(self) -> int: + """Total SMs available in this resource.""" + return self._sm_count + + @property + def min_partition_size(self) -> int: + """Minimum SM count required to create a partition.""" + return self._min_partition_size + + @property + def coscheduled_alignment(self) -> int: + """Number of SMs guaranteed to be co-scheduled.""" + return self._coscheduled_alignment + + @property + def flags(self) -> int: + """Raw flags from the underlying SM resource.""" + return self._flags + + def split(self, options not 
None, *, bint dry_run=False): + """Split this SM resource into groups and a remainder. + + Parameters + ---------- + options : :obj:`SMResourceOptions` + Split configuration (count, co-scheduling constraints). + dry_run : bool, optional + If ``True``, return filled-in metadata without creating + usable resource objects. (Default to ``False``) + + Returns + ------- + tuple[list[:obj:`SMResource`], :obj:`SMResource`] + ``(groups, remainder)`` where each group holds a disjoint + SM partition and *remainder* holds any unassigned SMs. + """ + cdef SMResourceOptions opts = check_or_create_options( + SMResourceOptions, options, "SM resource options" + ) + _resolve_group_count(opts) + _check_green_ctx_support() + if _can_use_structured_sm_split(): + return _split_with_general_api(self, opts, dry_run) + # SplitByCount requires the same 12.4+ as green ctx support (already checked above) + return _split_with_count_api(self, opts, dry_run) + + +cdef class WorkqueueResource: + """Represent a workqueue resource for a device or green context. + + Merges ``CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG`` and + ``CU_DEV_RESOURCE_TYPE_WORKQUEUE`` under one user-facing type. + Instances are returned by :obj:`DeviceResources.workqueue` and + cannot be instantiated directly. + """ + + def __init__(self, *args, **kwargs): + raise RuntimeError( + "WorkqueueResource cannot be instantiated directly. " + "Use dev.resources.workqueue." + ) + + @staticmethod + cdef WorkqueueResource _from_dev_resources( + cydriver.CUdevResource wq_config, + cydriver.CUdevResource wq, + ): + cdef WorkqueueResource self = WorkqueueResource.__new__(WorkqueueResource) + self._wq_config_resource = wq_config + self._wq_resource = wq + return self + + @property + def handle(self) -> int: + """Return the address of the underlying config ``CUdevResource`` struct.""" + return (&self._wq_config_resource) + + def configure(self, options not None): + """Configure the workqueue resource in place. 
+ + Parameters + ---------- + options : :obj:`WorkqueueResourceOptions` + Configuration options (sharing scope, etc.). + """ + cdef WorkqueueResourceOptions opts = check_or_create_options( + WorkqueueResourceOptions, options, "Workqueue resource options" + ) + _check_green_ctx_support() + _check_workqueue_support() + if opts.sharing_scope is None: + return None + + IF CUDA_CORE_BUILD_MAJOR >= 13: + if opts.sharing_scope == "device_ctx": + self._wq_config_resource.wqConfig.sharingScope = ( + cydriver.CUdevWorkqueueConfigScope.CU_WORKQUEUE_SCOPE_DEVICE_CTX + ) + elif opts.sharing_scope == "green_ctx_balanced": + self._wq_config_resource.wqConfig.sharingScope = ( + cydriver.CUdevWorkqueueConfigScope.CU_WORKQUEUE_SCOPE_GREEN_CTX_BALANCED + ) + else: + raise ValueError( + f"Unknown sharing_scope: {opts.sharing_scope!r}. " + "Expected 'device_ctx' or 'green_ctx_balanced'." + ) + ELSE: + raise RuntimeError( + "WorkqueueResource requires cuda.core to be built with CUDA 13.x bindings" + ) + + +cdef class DeviceResources: + """Namespace for hardware resource queries. + + When obtained via :obj:`Device.resources`, queries return full device + resources. When obtained via :obj:`Context.resources` or + :obj:`Stream.resources`, queries return the resources provisioned for + that context. + + This class cannot be instantiated directly. + """ + + def __init__(self, *args, **kwargs): + raise RuntimeError( + "DeviceResources cannot be instantiated directly. " + "Use dev.resources or ctx.resources." 
+ ) + + @staticmethod + cdef DeviceResources _init(int device_id): + cdef DeviceResources self = DeviceResources.__new__(DeviceResources) + self._device_id = device_id + # _h_context is default empty — queries use cuDeviceGetDevResource + return self + + @staticmethod + cdef DeviceResources _init_from_ctx(ContextHandle h_context, int device_id): + cdef DeviceResources self = DeviceResources.__new__(DeviceResources) + self._device_id = device_id + self._h_context = h_context + return self + + cdef inline int _query_sm(self, cydriver.CUdevResource* res) except?-1 nogil: + """Query SM resource from either device or context.""" + cdef GreenCtxHandle h_green + if self._h_context: + h_green = get_context_green_ctx(self._h_context) + if h_green: + HANDLE_RETURN(cydriver.cuGreenCtxGetDevResource( + as_cu(h_green), res, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM, + )) + else: + HANDLE_RETURN(cydriver.cuCtxGetDevResource( + as_cu(self._h_context), res, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM, + )) + else: + HANDLE_RETURN(cydriver.cuDeviceGetDevResource( + (self._device_id), res, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM, + )) + return 0 + + @property + def sm(self) -> SMResource: + """Return the :obj:`SMResource` for this device or context.""" + _check_green_ctx_support() + cdef cydriver.CUdevResource res + with nogil: + self._query_sm(&res) + return SMResource._from_dev_resource(res, self._device_id) + + @property + def workqueue(self) -> WorkqueueResource: + """Return the :obj:`WorkqueueResource` for this device or context.""" + _check_green_ctx_support() + _check_workqueue_support() + cdef cydriver.CUdevResource _wq_config + cdef cydriver.CUdevResource _wq + + IF CUDA_CORE_BUILD_MAJOR >= 13: + cdef GreenCtxHandle h_green + if self._h_context: + h_green = get_context_green_ctx(self._h_context) + if h_green: + # Green context query + with nogil: + HANDLE_RETURN(cydriver.cuGreenCtxGetDevResource( + as_cu(h_green), + &_wq_config, + 
cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG, + )) + HANDLE_RETURN(cydriver.cuGreenCtxGetDevResource( + as_cu(h_green), + &_wq, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE, + )) + else: + # Primary context query + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetDevResource( + as_cu(self._h_context), + &_wq_config, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG, + )) + HANDLE_RETURN(cydriver.cuCtxGetDevResource( + as_cu(self._h_context), + &_wq, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE, + )) + else: + # Device-level query + with nogil: + HANDLE_RETURN(cydriver.cuDeviceGetDevResource( + (self._device_id), + &_wq_config, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG, + )) + HANDLE_RETURN(cydriver.cuDeviceGetDevResource( + (self._device_id), + &_wq, + cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE, + )) + return WorkqueueResource._from_dev_resources(_wq_config, _wq) + ELSE: + raise RuntimeError( + "WorkqueueResource requires cuda.core to be built with CUDA 13.x bindings" + ) diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 9e7307e821b..ade94beb94b 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -20,6 +20,7 @@ from cuda.bindings cimport cynvjitlink cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # Handle types ctypedef shared_ptr[const cydriver.CUcontext] ContextHandle + ctypedef shared_ptr[const cydriver.CUgreenCtx] GreenCtxHandle ctypedef shared_ptr[const cydriver.CUstream] StreamHandle ctypedef shared_ptr[const cydriver.CUevent] EventHandle ctypedef shared_ptr[const cydriver.CUmemoryPool] MemoryPoolHandle @@ -45,6 +46,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # as_cu() - extract the raw CUDA handle (inline C++) cydriver.CUcontext as_cu(ContextHandle h) noexcept nogil + cydriver.CUgreenCtx 
as_cu(GreenCtxHandle h) noexcept nogil cydriver.CUstream as_cu(StreamHandle h) noexcept nogil cydriver.CUevent as_cu(EventHandle h) noexcept nogil cydriver.CUmemoryPool as_cu(MemoryPoolHandle h) noexcept nogil @@ -61,6 +63,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # as_intptr() - extract handle as intptr_t for Python interop (inline C++) intptr_t as_intptr(ContextHandle h) noexcept nogil + intptr_t as_intptr(GreenCtxHandle h) noexcept nogil intptr_t as_intptr(StreamHandle h) noexcept nogil intptr_t as_intptr(EventHandle h) noexcept nogil intptr_t as_intptr(MemoryPoolHandle h) noexcept nogil @@ -78,6 +81,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # as_py() - convert handle to Python wrapper object (inline C++; requires GIL) object as_py(ContextHandle h) + object as_py(GreenCtxHandle h) object as_py(StreamHandle h) object as_py(EventHandle h) object as_py(MemoryPoolHandle h) @@ -107,6 +111,12 @@ cdef void clear_last_error() noexcept nogil # Context handles cdef ContextHandle create_context_handle_ref(cydriver.CUcontext ctx) except+ nogil +cdef ContextHandle create_context_handle_from_green_ctx(const GreenCtxHandle& h_green_ctx) except+ nogil +cdef GreenCtxHandle get_context_green_ctx(const ContextHandle& h) noexcept nogil +cdef GreenCtxHandle create_green_ctx_handle( + cydriver.CUdevResource* resources, unsigned int nbResources, + cydriver.CUdevice dev, unsigned int flags) except+ nogil +cdef GreenCtxHandle create_green_ctx_handle_ref(cydriver.CUgreenCtx ctx) except+ nogil cdef ContextHandle get_primary_context(int device_id) except+ nogil cdef ContextHandle get_current_context() except+ nogil @@ -115,6 +125,7 @@ cdef StreamHandle create_stream_handle( const ContextHandle& h_ctx, unsigned int flags, int priority) except+ nogil cdef StreamHandle create_stream_handle_ref(cydriver.CUstream stream) except+ nogil cdef StreamHandle create_stream_handle_with_owner(cydriver.CUstream stream, object owner) except+ 
nogil +cdef ContextHandle get_stream_context(const StreamHandle& h) noexcept nogil cdef StreamHandle get_legacy_stream() except+ nogil cdef StreamHandle get_per_thread_stream() except+ nogil diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 2090f5026d0..59e47f23462 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -20,6 +20,7 @@ from cuda.bindings cimport cynvjitlink from ._resource_handles cimport ( ContextHandle, + GreenCtxHandle, StreamHandle, EventHandle, MemoryPoolHandle, @@ -55,6 +56,15 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # Context handles ContextHandle create_context_handle_ref "cuda_core::create_context_handle_ref" ( cydriver.CUcontext ctx) except+ nogil + ContextHandle create_context_handle_from_green_ctx "cuda_core::create_context_handle_from_green_ctx" ( + const GreenCtxHandle& h_green_ctx) except+ nogil + GreenCtxHandle get_context_green_ctx "cuda_core::get_context_green_ctx" ( + const ContextHandle& h) noexcept nogil + GreenCtxHandle create_green_ctx_handle "cuda_core::create_green_ctx_handle" ( + cydriver.CUdevResource* resources, unsigned int nbResources, + cydriver.CUdevice dev, unsigned int flags) except+ nogil + GreenCtxHandle create_green_ctx_handle_ref "cuda_core::create_green_ctx_handle_ref" ( + cydriver.CUgreenCtx ctx) except+ nogil ContextHandle get_primary_context "cuda_core::get_primary_context" ( int device_id) except+ nogil ContextHandle get_current_context "cuda_core::get_current_context" () except+ nogil @@ -66,6 +76,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUstream stream) except+ nogil StreamHandle create_stream_handle_with_owner "cuda_core::create_stream_handle_with_owner" ( cydriver.CUstream stream, object owner) except+ nogil + ContextHandle get_stream_context "cuda_core::get_stream_context" ( + const StreamHandle& h) noexcept nogil StreamHandle 
get_legacy_stream "cuda_core::get_legacy_stream" () except+ nogil StreamHandle get_per_thread_stream "cuda_core::get_per_thread_stream" () except+ nogil @@ -223,6 +235,11 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": void* p_cuDevicePrimaryCtxRetain "reinterpret_cast(cuda_core::p_cuDevicePrimaryCtxRetain)" void* p_cuDevicePrimaryCtxRelease "reinterpret_cast(cuda_core::p_cuDevicePrimaryCtxRelease)" void* p_cuCtxGetCurrent "reinterpret_cast(cuda_core::p_cuCtxGetCurrent)" + void* p_cuGreenCtxCreate "reinterpret_cast(cuda_core::p_cuGreenCtxCreate)" + void* p_cuGreenCtxDestroy "reinterpret_cast(cuda_core::p_cuGreenCtxDestroy)" + void* p_cuCtxFromGreenCtx "reinterpret_cast(cuda_core::p_cuCtxFromGreenCtx)" + void* p_cuDevResourceGenerateDesc "reinterpret_cast(cuda_core::p_cuDevResourceGenerateDesc)" + void* p_cuGreenCtxStreamCreate "reinterpret_cast(cuda_core::p_cuGreenCtxStreamCreate)" # Stream void* p_cuStreamCreateWithPriority "reinterpret_cast(cuda_core::p_cuStreamCreateWithPriority)" @@ -288,10 +305,23 @@ cdef void* _get_driver_fn(str name): capsule = cydriver.__pyx_capi__[name] return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) + +cdef void* _get_optional_driver_fn(str name): + try: + capsule = cydriver.__pyx_capi__[name] + except KeyError: + return NULL + return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) + # Context p_cuDevicePrimaryCtxRetain = _get_driver_fn("cuDevicePrimaryCtxRetain") p_cuDevicePrimaryCtxRelease = _get_driver_fn("cuDevicePrimaryCtxRelease") p_cuCtxGetCurrent = _get_driver_fn("cuCtxGetCurrent") +p_cuGreenCtxCreate = _get_optional_driver_fn("cuGreenCtxCreate") +p_cuGreenCtxDestroy = _get_optional_driver_fn("cuGreenCtxDestroy") +p_cuCtxFromGreenCtx = _get_optional_driver_fn("cuCtxFromGreenCtx") +p_cuDevResourceGenerateDesc = _get_optional_driver_fn("cuDevResourceGenerateDesc") +p_cuGreenCtxStreamCreate = _get_optional_driver_fn("cuGreenCtxStreamCreate") # Stream p_cuStreamCreateWithPriority = 
_get_driver_fn("cuStreamCreateWithPriority") diff --git a/cuda_core/cuda/core/_stream.pyx b/cuda_core/cuda/core/_stream.pyx index ca13811cd3c..e3865bcc542 100644 --- a/cuda_core/cuda/core/_stream.pyx +++ b/cuda_core/cuda/core/_stream.pyx @@ -21,6 +21,7 @@ from dataclasses import dataclass from typing import Protocol from cuda.core._context cimport Context +from cuda.core._device_resources cimport DeviceResources from cuda.core._event import Event, EventOptions from cuda.core._resource_handles cimport ( ContextHandle, @@ -31,8 +32,10 @@ from cuda.core._resource_handles cimport ( create_stream_handle, create_stream_handle_with_owner, get_current_context, + get_last_error, get_legacy_stream, get_per_thread_stream, + get_stream_context, as_intptr, as_cu, as_py, @@ -96,7 +99,7 @@ cdef class Stream: """Create a Stream from an existing StreamHandle (cdef-only factory).""" cdef Stream s = cls.__new__(cls) s._h_stream = h_stream - # _h_context is default-initialized to empty ContextHandle by C++ + s._h_context = get_stream_context(h_stream) s._device_id = -1 # lazy init'd (invalid sentinel) s._nonblocking = -1 # lazy init'd s._priority = INT32_MIN # lazy init'd @@ -142,8 +145,15 @@ cdef class Stream: else cydriver.CUstream_flags.CU_STREAM_DEFAULT) # TODO: we might want to consider memoizing high/low per CUDA context and avoid this call cdef int high, low + cdef cydriver.CUresult res_code with nogil: - HANDLE_RETURN(cydriver.cuCtxGetStreamPriorityRange(&high, &low)) + res_code = cydriver.cuCtxGetStreamPriorityRange(&high, &low) + if res_code != cydriver.CUresult.CUDA_SUCCESS: + if res_code == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT: + raise RuntimeError( + "No current CUDA context. Call dev.set_current() before creating streams." 
+ ) + HANDLE_RETURN(res_code) cdef int prio if priority is not None: prio = priority @@ -152,10 +162,25 @@ cdef class Stream: else: prio = high - # C++ creates the stream and returns owning handle with context dependency + # C++ creates the stream and returns owning handle with context dependency. + # For green contexts, the C++ layer auto-dispatches to cuGreenCtxStreamCreate. h_stream = create_stream_handle(h_context, flags, prio) if not h_stream: - raise RuntimeError("Failed to create CUDA stream") + res_code = get_last_error() + if not nonblocking and res_code == cydriver.CUresult.CUDA_ERROR_INVALID_VALUE: + # cuGreenCtxStreamCreate rejects CU_STREAM_DEFAULT; + # no need to check is_green since primary streams don't fail this way + raise ValueError( + "Green context streams must be non-blocking. " + "Use StreamOptions(nonblocking=True) or omit the option (True is the default)." + ) + elif res_code == cydriver.CUresult.CUDA_ERROR_NOT_SUPPORTED: + raise RuntimeError( + "cuGreenCtxStreamCreate is not available. " + "Green context stream creation requires CUDA 12.5 or newer." + ) + else: + HANDLE_RETURN(res_code) self = Stream._from_handle(cls, h_stream) self._nonblocking = int(nonblocking) self._priority = prio @@ -322,6 +347,18 @@ cdef class Stream: Stream_ensure_ctx_device(self) return Context._from_handle(Context, self._h_context, self._device_id) + @property + def resources(self): + """Query the hardware resources provisioned for this stream's context. + + For streams created from a green context, returns the resources + that context was provisioned with. For streams on the primary + context, returns the full device resources. + """ + Stream_ensure_ctx(self) + Stream_ensure_ctx_device(self) + return DeviceResources._init_from_ctx(self._h_context, self._device_id) + @staticmethod def from_handle(handle: int) -> Stream: """Create a new :obj:`~_stream.Stream` object from a foreign stream handle. 
@@ -406,7 +443,11 @@ cdef inline int Stream_ensure_ctx(Stream self) except?-1 nogil: """Ensure the stream's context handle is populated.""" cdef cydriver.CUcontext ctx if not self._h_context: - HANDLE_RETURN(cydriver.cuStreamGetCtx(as_cu(self._h_stream), &ctx)) + self._h_context = get_stream_context(self._h_stream) + if self._h_context: + return 0 + HANDLE_RETURN(cydriver.cuStreamGetCtx(as_cu(self._h_stream), &ctx)) + if ctx != NULL: with gil: self._h_context = create_context_handle_ref(ctx) return 0 @@ -416,13 +457,15 @@ cdef inline int Stream_ensure_ctx_device(Stream self) except?-1: """Ensure the stream's context and device_id are populated.""" cdef cydriver.CUcontext ctx cdef cydriver.CUdevice target_dev + cdef ContextHandle current_context cdef bint switch_context if self._device_id < 0: with nogil: # Get device ID from context, switching context temporarily if needed Stream_ensure_ctx(self) - switch_context = (get_current_context() != self._h_context) + current_context = get_current_context() + switch_context = (as_cu(current_context) != as_cu(self._h_context)) if switch_context: HANDLE_RETURN(cydriver.cuCtxPushCurrent(as_cu(self._h_context))) HANDLE_RETURN(cydriver.cuCtxGetDevice(&target_dev)) diff --git a/cuda_core/cuda/core/typing.py b/cuda_core/cuda/core/typing.py index a66ab1881fb..922e6b0ae6e 100644 --- a/cuda_core/cuda/core/typing.py +++ b/cuda_core/cuda/core/typing.py @@ -4,10 +4,12 @@ """Public type aliases and protocols used in cuda.core API signatures.""" +from cuda.core._context import DeviceResourcesT from cuda.core._memory._buffer import DevicePointerT from cuda.core._stream import IsStreamT __all__ = [ "DevicePointerT", + "DeviceResourcesT", "IsStreamT", ] diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst index 88780732d54..8d591316f91 100644 --- a/cuda_core/docs/source/api.rst +++ b/cuda_core/docs/source/api.rst @@ -26,12 +26,18 @@ Devices and execution Stream Event + Context + SMResource + WorkqueueResource :template: 
dataclass.rst StreamOptions EventOptions LaunchConfig + ContextOptions + SMResourceOptions + WorkqueueResourceOptions .. data:: LEGACY_DEFAULT_STREAM diff --git a/cuda_core/docs/source/api_private.rst b/cuda_core/docs/source/api_private.rst index 141773967e8..de100c7152e 100644 --- a/cuda_core/docs/source/api_private.rst +++ b/cuda_core/docs/source/api_private.rst @@ -17,6 +17,7 @@ CUDA runtime :toctree: generated/ typing.DevicePointerT + typing.DeviceResourcesT _memory._virtual_memory_resource.VirtualMemoryAllocationTypeT _memory._virtual_memory_resource.VirtualMemoryLocationTypeT _memory._virtual_memory_resource.VirtualMemoryGranularityT @@ -30,6 +31,7 @@ CUDA runtime :template: autosummary/cyclass.rst _device.DeviceProperties + _device_resources.DeviceResources _memory._ipc.IPCAllocationHandle _memory._ipc.IPCBufferDescriptor diff --git a/cuda_core/tests/test_green_context.py b/cuda_core/tests/test_green_context.py new file mode 100644 index 00000000000..8eb32f7c1aa --- /dev/null +++ b/cuda_core/tests/test_green_context.py @@ -0,0 +1,465 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + + +import contextlib + +import numpy as np +import pytest + +from cuda.core import ( + ContextOptions, + DeviceResources, + LaunchConfig, + LegacyPinnedMemoryResource, + Program, + ProgramOptions, + SMResource, + SMResourceOptions, + WorkqueueResource, + WorkqueueResourceOptions, + launch, +) +from cuda.core._utils.cuda_utils import CUDAError + +# --------------------------------------------------------------------------- +# Kernel source +# --------------------------------------------------------------------------- + +_FILL_KERNEL = r""" +extern "C" __global__ void fill(int* out, int value, int n) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + out[tid] = value; + } +} +""" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sm_resource(init_cuda): + """Query SM resources from the device, skip if unsupported.""" + try: + return init_cuda.resources.sm + except (RuntimeError, ValueError, CUDAError) as exc: + pytest.skip(str(exc)) + + +@pytest.fixture +def wq_resource(init_cuda): + """Query workqueue resources from the device, skip if unsupported.""" + try: + return init_cuda.resources.workqueue + except (RuntimeError, ValueError, CUDAError) as exc: + pytest.skip(str(exc)) + + +@pytest.fixture +def green_ctx(init_cuda, sm_resource): + """Create a single-group green context with proper teardown.""" + groups, _ = sm_resource.split(SMResourceOptions(count=None)) + try: + ctx = init_cuda.create_context(ContextOptions(resources=[groups[0]])) + except CUDAError as exc: + pytest.skip(str(exc)) + yield ctx + ctx.close() + + +@pytest.fixture +def fill_kernel(init_cuda): + """Compile the fill kernel for the current device.""" + dev = init_cuda + opts = ProgramOptions(std="c++17", arch=f"sm_{dev.arch}") + prog = Program(_FILL_KERNEL, code_type="c++", options=opts) + 
mod = prog.compile("cubin") + return mod.get_kernel("fill") + + +def _aligned_half(sm): + """Compute half the SM count, rounded down to min_partition_size alignment.""" + min_size = sm.min_partition_size + half = (sm.sm_count // 2 // min_size) * min_size + return half + + +@contextlib.contextmanager +def _use_green_ctx(dev, ctx): + """Context manager: set green ctx current, restore previous on exit.""" + prev = dev.set_current(ctx) + try: + yield + finally: + dev.set_current(prev) + + +# --------------------------------------------------------------------------- +# Construction / type tests +# --------------------------------------------------------------------------- + + +def test_not_user_constructible(): + with pytest.raises(RuntimeError): + DeviceResources() + with pytest.raises(RuntimeError): + SMResource() + with pytest.raises(RuntimeError): + WorkqueueResource() + + +def test_create_context_requires_resources(init_cuda): + with pytest.raises(ValueError, match="resources must be provided"): + init_cuda.create_context() + with pytest.raises(ValueError, match="resources must be provided"): + init_cuda.create_context(ContextOptions(resources=None)) + with pytest.raises(TypeError): + init_cuda.create_context(object()) + + +# --------------------------------------------------------------------------- +# SM resource query +# --------------------------------------------------------------------------- + + +class TestSMResourceQuery: + def test_properties(self, sm_resource): + assert sm_resource.handle != 0 + assert sm_resource.sm_count > 0 + assert sm_resource.min_partition_size > 0 + assert sm_resource.coscheduled_alignment > 0 + assert isinstance(sm_resource.flags, int) + + def test_no_memory_node_id_in_v1(self, sm_resource): + """memory_node_id is deferred to v1.1 (CUDA 13.4).""" + assert not hasattr(sm_resource, "memory_node_id") + + def test_arch_constraints_pre_hopper(self, init_cuda, sm_resource): + if init_cuda.compute_capability >= (9, 0): + 
pytest.skip("Test is for pre-Hopper architectures") + assert sm_resource.min_partition_size >= 2 + assert sm_resource.coscheduled_alignment >= 2 + + def test_arch_constraints_hopper_plus(self, init_cuda, sm_resource): + if init_cuda.compute_capability < (9, 0): + pytest.skip("Test is for Hopper+ architectures") + assert sm_resource.min_partition_size >= 8 + assert sm_resource.coscheduled_alignment >= 8 + + +# --------------------------------------------------------------------------- +# Workqueue resource +# --------------------------------------------------------------------------- + + +class TestWorkqueueResource: + def test_query(self, wq_resource): + assert wq_resource.handle != 0 + + def test_configure_none_is_noop(self, wq_resource): + assert wq_resource.configure(WorkqueueResourceOptions(sharing_scope=None)) is None + + def test_configure_valid_scope(self, wq_resource): + wq_resource.configure(WorkqueueResourceOptions(sharing_scope="green_ctx_balanced")) + + def test_configure_invalid_scope_raises(self, wq_resource): + with pytest.raises(ValueError, match="Unknown sharing_scope"): + wq_resource.configure(WorkqueueResourceOptions(sharing_scope="bogus")) + + +# --------------------------------------------------------------------------- +# SM resource split — validation +# --------------------------------------------------------------------------- + + +class TestSMResourceSplitValidation: + def test_scalar_count_with_sequence_field_raises(self, sm_resource): + count = sm_resource.min_partition_size + with pytest.raises(ValueError, match="count is scalar"): + sm_resource.split( + SMResourceOptions( + count=count, + coscheduled_sm_count=(count, count), + ) + ) + + def test_sequence_length_mismatch_raises(self, sm_resource): + count = sm_resource.min_partition_size + with pytest.raises(ValueError, match="expected 2"): + sm_resource.split( + SMResourceOptions( + count=(count, count), + coscheduled_sm_count=(count, count, count), + ) + ) + + def 
test_negative_count_raises(self, sm_resource): + with pytest.raises(ValueError, match="count must be non-negative"): + sm_resource.split(SMResourceOptions(count=-1)) + + def test_dry_run_cannot_create_context(self, init_cuda, sm_resource): + groups, _ = sm_resource.split(SMResourceOptions(count=None), dry_run=True) + assert len(groups) == 1 + with pytest.raises(ValueError, match="dry-run SMResource"): + init_cuda.create_context(ContextOptions(resources=[groups[0]])) + + +# --------------------------------------------------------------------------- +# SM resource split — functional +# --------------------------------------------------------------------------- + + +class TestSMResourceSplit: + def test_single_group_counts(self, sm_resource): + """Single-group split: group gets at least requested SMs.""" + requested = sm_resource.min_partition_size + groups, rem = sm_resource.split(SMResourceOptions(count=requested)) + + assert len(groups) == 1 + assert groups[0].sm_count >= requested + assert groups[0].sm_count + rem.sm_count <= sm_resource.sm_count + + def test_discovery_mode(self, sm_resource): + """count=None auto-detects a valid SM count.""" + groups, _ = sm_resource.split(SMResourceOptions(count=None)) + + assert len(groups) == 1 + assert groups[0].sm_count >= sm_resource.min_partition_size + + def test_discovery_respects_alignment(self, sm_resource): + groups, _ = sm_resource.split(SMResourceOptions(count=None)) + + if sm_resource.coscheduled_alignment > 0: + assert groups[0].sm_count % sm_resource.coscheduled_alignment == 0 + + def test_two_groups(self, sm_resource): + """Two-group split with explicit aligned counts.""" + half = _aligned_half(sm_resource) + if half < sm_resource.min_partition_size: + pytest.skip("Not enough SMs for a 2-group split") + + groups, rem = sm_resource.split(SMResourceOptions(count=(half, half))) + + assert len(groups) == 2 + assert groups[0].sm_count > 0 + assert groups[1].sm_count > 0 + total = groups[0].sm_count + 
groups[1].sm_count + rem.sm_count + assert total <= sm_resource.sm_count + + def test_two_groups_each_meets_request(self, sm_resource): + min_size = sm_resource.min_partition_size + half = _aligned_half(sm_resource) + if half < min_size: + pytest.skip("Not enough SMs for a 2-group split") + + groups, _ = sm_resource.split(SMResourceOptions(count=(min_size, min_size))) + + assert len(groups) == 2 + assert groups[0].sm_count >= min_size + assert groups[1].sm_count >= min_size + + def test_dry_run_matches_real(self, sm_resource): + """Dry-run reports the same SM counts as a real split.""" + opts = SMResourceOptions(count=None) + + dry_groups, _ = sm_resource.split(opts, dry_run=True) + real_groups, _ = sm_resource.split(opts, dry_run=False) + + assert len(dry_groups) == len(real_groups) + for dg, rg in zip(dry_groups, real_groups): + assert dg.sm_count == rg.sm_count + + +# --------------------------------------------------------------------------- +# Green context lifecycle +# --------------------------------------------------------------------------- + + +class TestGreenContextLifecycle: + def test_is_green(self, green_ctx): + assert green_ctx.is_green + assert green_ctx.handle is not None + + def test_create_stream_on_primary_raises(self, init_cuda): + """create_stream is only for green contexts.""" + # The init_cuda fixture sets the primary context + # Get the primary context via device internals + ctx = init_cuda._context + with pytest.raises(RuntimeError, match="only supported on green contexts"): + ctx.create_stream() + + def test_create_stream_blocking_raises(self, green_ctx): + """Green context streams must be non-blocking.""" + from cuda.core import StreamOptions + + with pytest.raises(ValueError, match="must be non-blocking"): + green_ctx.create_stream(StreamOptions(nonblocking=False)) + + def test_create_stream_explicit(self, green_ctx): + """Create a stream directly from the green context (no set_current).""" + stream = green_ctx.create_stream() + assert 
stream is not None + assert stream.context.is_green + assert stream.context == green_ctx + + def test_stream_and_event_track_green_context(self, green_ctx): + stream = green_ctx.create_stream() + event = stream.record() + assert stream.context.is_green + assert stream.context == green_ctx + assert event.context.is_green + assert event.context == green_ctx + stream.sync() + event.sync() + + def test_close_while_current_raises(self, init_cuda, green_ctx): + """close() on a current context raises — test via set_current.""" + dev = init_cuda + with _use_green_ctx(dev, green_ctx), pytest.raises(RuntimeError, match="while it is current"): + green_ctx.close() + + def test_set_current_swap_regression(self, init_cuda, green_ctx): + """set_current still works (backward compat) and preserves identity.""" + dev = init_cuda + with _use_green_ctx(dev, green_ctx): + pass # just verify push/pop works + # Swap again and check identity round-trip + prev = dev.set_current(green_ctx) + try: + assert prev is not None + finally: + restored = dev.set_current(prev) + assert restored is green_ctx + assert restored.is_green + + +# --------------------------------------------------------------------------- +# Context.resources +# --------------------------------------------------------------------------- + + +class TestContextResources: + def test_green_ctx_sm_resources(self, green_ctx, sm_resource): + """Green context's SM resources should be a subset of device SMs.""" + ctx_sm = green_ctx.resources.sm + assert ctx_sm.sm_count > 0 + assert ctx_sm.sm_count <= sm_resource.sm_count + + def test_green_ctx_resources_reflect_partition(self, init_cuda, sm_resource): + """Two green contexts should have disjoint SM partitions.""" + half = _aligned_half(sm_resource) + if half < sm_resource.min_partition_size: + pytest.skip("Not enough SMs for a 2-group split") + + groups, _ = sm_resource.split(SMResourceOptions(count=(half, half))) + + ctx_a = ctx_b = None + try: + ctx_a = 
init_cuda.create_context(ContextOptions(resources=[groups[0]])) + ctx_b = init_cuda.create_context(ContextOptions(resources=[groups[1]])) + + sm_a = ctx_a.resources.sm.sm_count + sm_b = ctx_b.resources.sm.sm_count + assert sm_a > 0 + assert sm_b > 0 + assert sm_a + sm_b <= sm_resource.sm_count + finally: + if ctx_b is not None: + ctx_b.close() + if ctx_a is not None: + ctx_a.close() + + def test_stream_resources_match_context(self, green_ctx, sm_resource): + """stream.resources should return the same as ctx.resources.""" + stream = green_ctx.create_stream() + + stream_sm = stream.resources.sm + ctx_sm = green_ctx.resources.sm + assert stream_sm.sm_count == ctx_sm.sm_count + assert stream_sm.sm_count > 0 + assert stream_sm.sm_count <= sm_resource.sm_count + + try: + stream_wq = stream.resources.workqueue + ctx_wq = green_ctx.resources.workqueue + assert stream_wq.handle != 0 + assert ctx_wq.handle != 0 + except (RuntimeError, ValueError, CUDAError): + pass # workqueue not available on this driver/build + + +# --------------------------------------------------------------------------- +# Kernel launch in green context (explicit model) +# --------------------------------------------------------------------------- + + +def _launch_fill_and_verify(dev, stream, kernel, n, value): + """Launch the fill kernel and verify results on host.""" + dev_buf = dev.allocate(n * np.dtype(np.int32).itemsize, stream=stream) + + config = LaunchConfig(grid=(n + 31) // 32, block=32) + launch(stream, config, kernel, dev_buf, np.int32(value), np.int32(n)) + + host_mr = LegacyPinnedMemoryResource() + host_buf = host_mr.allocate(n * np.dtype(np.int32).itemsize) + host_arr = np.from_dlpack(host_buf).view(np.int32) + host_arr[:] = 0 + + dev_buf.copy_to(host_buf, stream=stream) + stream.sync() + + np.testing.assert_array_equal(host_arr, np.full(n, value, dtype=np.int32)) + + +class TestGreenContextKernelLaunch: + def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel): + """Launch 
kernel via ctx.create_stream (explicit model, no set_current).""" + stream = green_ctx.create_stream() + _launch_fill_and_verify(init_cuda, stream, fill_kernel, n=64, value=42) + + def test_two_green_contexts_independent(self, init_cuda, sm_resource, fill_kernel): + """Two SM groups -> two green contexts -> two independent kernels.""" + dev = init_cuda + half = _aligned_half(sm_resource) + if half < sm_resource.min_partition_size: + pytest.skip("Not enough SMs for a 2-group split") + + groups, _ = sm_resource.split(SMResourceOptions(count=(half, half))) + assert len(groups) == 2 + + ctx_a = ctx_b = None + try: + ctx_a = dev.create_context(ContextOptions(resources=[groups[0]])) + ctx_b = dev.create_context(ContextOptions(resources=[groups[1]])) + + for ctx, value in [(ctx_a, 10), (ctx_b, 20)]: + stream = ctx.create_stream() + _launch_fill_and_verify(dev, stream, fill_kernel, n=64, value=value) + finally: + if ctx_b is not None: + ctx_b.close() + if ctx_a is not None: + ctx_a.close() + + def test_with_workqueue_resource(self, init_cuda, sm_resource, wq_resource, fill_kernel): + """Green context with SM + workqueue resources can launch a kernel.""" + dev = init_cuda + groups, _ = sm_resource.split(SMResourceOptions(count=None)) + + try: + ctx = dev.create_context(ContextOptions(resources=[groups[0], wq_resource])) + except CUDAError as exc: + pytest.skip(str(exc)) + + assert ctx.is_green + + try: + stream = ctx.create_stream() + _launch_fill_and_verify(dev, stream, fill_kernel, n=32, value=99) + finally: + ctx.close()