GPU Kernels with Nvidia Warp
Why this matters
Writing GPU kernels usually requires CUDA C++ and a separate compilation step. Nvidia Warp lets you write GPU kernels in Python with a decorator (@wp.kernel) and JIT-compiles them on first use. The result: GPU-speed parallelism with Python-developer ergonomics.
The critical gotcha for Python codebases: import warp will fail on any machine without CUDA installed — including your CI runners, your teammates' MacBooks, and any CPU-only deployment host. If you put import warp at the top of a module, that entire module becomes unimportable everywhere without a GPU. The solution is deferred imports: every Warp import goes inside a function body, never at module level.
Install
uv add 'warp-lang>=1.8' numpy
Declare Warp as an optional extra in any library that ships a CPU fallback:
[project.optional-dependencies]
cuda = ["warp-lang>=1.8"]
pip install 'acme-voxel[cuda]' # opt-in GPU install
pip install 'acme-voxel' # plain install stays CPU-only
Module Skeleton — Lazy import warp
The whole point: a module that defines GPU kernels must still import on a non-CUDA host. Never import warp at module top — every Warp import goes inside a function body.
warp_backend.py:
"""GPU voxel kernels using NVIDIA Warp. All warp imports are deferred."""
from __future__ import annotations
import numpy as np
def _define_kernels() -> None:
"""Import warp and JIT-define every kernel once. Called lazily on first GPU use."""
import warp as wp # type: ignore[import]
global _kernel_boolean, _kernel_subtract_capsule, _kernel_voxelize
global _kernels_defined
if _kernels_defined:
return
# --- kernel definitions go here (see Kernel Patterns below) ---
_kernel_boolean = _boolean
_kernel_subtract_capsule = _subtract_capsule
_kernel_voxelize = _voxelize
_kernels_defined = True
# Module-level sentinel + placeholders (one per kernel).
_kernels_defined: bool = False
_kernel_boolean = None # type: ignore[assignment]
_kernel_subtract_capsule = None # type: ignore[assignment]
_kernel_voxelize = None # type: ignore[assignment]Why this shape:
- Nested kernel functions inside `_define_kernels()` keep `import warp` deferred while still being addressable as module-level globals.
- The `_kernels_defined` guard makes JIT compilation happen exactly once.
Kernel Patterns
Flat 1-D elementwise
# Elementwise boolean combine of two flat bool arrays, one thread per element.
# `op` selects the operation: 0 = union, 1 = intersect, any other value = subtract.
@wp.kernel # type: ignore[misc]
def _boolean(
    a: wp.array(dtype=wp.bool), # type: ignore[valid-type]
    b: wp.array(dtype=wp.bool), # type: ignore[valid-type]
    out: wp.array(dtype=wp.bool), # type: ignore[valid-type]
    op: int,
) -> None:
    # Each thread reads a[i] and b[i] and writes only out[i].
    i = wp.tid()
    va = a[i]
    vb = b[i]
    if op == 0: out[i] = va or vb # union
    elif op == 1: out[i] = va and vb # intersect
    else: out[i] = va and not vb # subtract (a minus b)

Pass op as a plain int from the host: {"union": 0, "intersect": 1, "subtract": 2}[op].
Flat-index decode for a 3-D grid (Z, Y, X), with atomic counter
# Clear every set voxel whose centre lies within `radius` of the segment a->b.
# `data` is the flat (Z, Y, X) grid; `count[0]` is incremented once per voxel
# that this launch actually cleared. `nz` is accepted but not read here.
@wp.kernel # type: ignore[misc]
def _subtract_capsule(
    data: wp.array(dtype=wp.bool), # type: ignore[valid-type]
    count: wp.array(dtype=wp.int32), # type: ignore[valid-type]
    nz: int, ny: int, nx: int,
    ox: float, oy: float, oz: float, vs: float,
    ax: float, ay: float, az: float,
    bx: float, by: float, bz: float,
    radius: float,
) -> None:
    idx = wp.tid()
    # Decode the flat index into (iz, iy, ix); ix varies fastest.
    iz = idx // (ny * nx)
    rem = idx - iz * ny * nx
    iy = rem // nx
    ix = rem - iy * nx
    # World-space voxel centre
    px = ox + (float(ix) + 0.5) * vs
    py = oy + (float(iy) + 0.5) * vs
    pz = oz + (float(iz) + 0.5) * vs
    # Distance from voxel centre to segment a->b
    seg_x = bx - ax; seg_y = by - ay; seg_z = bz - az
    diff_x = px - ax; diff_y = py - ay; diff_z = pz - az
    dot_seg = seg_x * seg_x + seg_y * seg_y + seg_z * seg_z
    # t is the parameter of the closest point on the segment. It stays 0 for
    # a (near) zero-length segment, which reduces the test to a sphere at a.
    t = 0.0
    if dot_seg > 1e-12:
        t = (diff_x * seg_x + diff_y * seg_y + diff_z * seg_z) / dot_seg
        t = wp.clamp(t, 0.0, 1.0)
    dx = px - (ax + t * seg_x)
    dy = py - (ay + t * seg_y)
    dz = pz - (az + t * seg_z)
    # Compare squared distances; only voxels that are currently set are counted.
    if dx * dx + dy * dy + dz * dz <= radius * radius and data[idx]:
        data[idx] = wp.bool(False)
        wp.atomic_add(count, 0, 1)

3-D indexed kernel for a Z-slab (wp.array3d + 3-tuple wp.tid())
# Slab variant of the capsule subtraction: `grid` is a 3-D slab indexed
# (iz, iy, ix), and `z_offset` shifts the slab-local iz back to the global Z
# index before the world-space Z coordinate is computed.
@wp.kernel # type: ignore[misc]
def _subtract_capsule_offset(
    grid: wp.array3d(dtype=wp.bool), # type: ignore[valid-type]
    ax: float, ay: float, az: float,
    bx: float, by: float, bz: float,
    radius: float,
    ox: float, oy: float, oz: float, vs: float,
    z_offset: int,
) -> None:
    # 3-tuple thread index; matches a launch with dim=(nz, ny, nx).
    iz, iy, ix = wp.tid()
    px = ox + (float(ix) + 0.5) * vs
    py = oy + (float(iy) + 0.5) * vs
    pz = oz + (float(iz + z_offset) + 0.5) * vs # GLOBAL Z
    # ... same SDF math as above

Mesh voxelization with a BVH (wp.Mesh)
# Mark every voxel whose centre the mesh sign query reports as inside.
# `out` is the flat (Z, Y, X) grid. This kernel only ever writes True, so the
# caller must supply a zero-initialised `out`. `nz` is accepted but not read here.
@wp.kernel # type: ignore[misc]
def _voxelize(
    mesh_id: wp.uint64, # type: ignore[valid-type]
    out: wp.array(dtype=wp.bool), # type: ignore[valid-type]
    min_x: float, min_y: float, min_z: float,
    pitch: float,
    nz: int, ny: int, nx: int,
) -> None:
    idx = wp.tid()
    # Decode the flat index into (iz, iy, ix); ix varies fastest.
    iz = idx // (ny * nx)
    rem = idx - iz * ny * nx
    iy = rem // nx
    ix = rem - iy * nx
    # World-space voxel centre, half a pitch in from the voxel's min corner.
    cx = min_x + (float(ix) + 0.5) * pitch
    cy = min_y + (float(iy) + 0.5) * pitch
    cz = min_z + (float(iz) + 0.5) * pitch
    pt = wp.vec3(cx, cy, cz)
    # NOTE(review): 1e6 is presumably the query's maximum search distance;
    # confirm against the Warp reference for this builtin.
    query = wp.mesh_query_point_sign_normal(mesh_id, pt, 1e6)
    # A negative sign is treated as "inside the mesh".
    if query.sign < 0.0:
        out[idx] = wp.bool(True)

Sign-query voxelization can leave hollow shells for non-watertight meshes — post-process the downloaded numpy result with scipy.ndimage.binary_fill_holes(out).
Backend Class — Launch + Memory
import warp only inside the constructor and launching methods. Resolve device once in __init__, warm the JIT, then every method follows the same shape: numpy in → wp.array on device → wp.launch → .numpy() back.
class WarpBackend:
name: str = "cuda"
def __init__(self) -> None:
    """Bind the backend to the first CUDA device and pre-build the kernels.

    Raises:
        RuntimeError: if Warp reports no CUDA devices.
    """
    import warp as wp  # type: ignore[import]

    cuda_devices = wp.get_cuda_devices()
    if cuda_devices:
        # Resolve the device once; every launch reuses this object.
        self._device = cuda_devices[0]
    else:
        raise RuntimeError("WarpBackend: no CUDA devices found.")

    wp.init()
    # Define the kernels now so the first method call is already warm.
    _define_kernels()
def boolean(self, op, a, b=None):
    """Combine two boolean grids elementwise on the GPU.

    Args:
        op: One of "union", "intersect" or "subtract" (``a`` minus ``b``).
        a: ndarray grid; converted to bool.
        b: ndarray grid with the same shape as ``a``. It is required; the
            ``None`` default is kept only for signature compatibility.

    Returns:
        A new bool ndarray with the shape of ``a``.

    Raises:
        KeyError: if ``op`` is not one of the three supported names.
        ValueError: if ``b`` is missing or its shape differs from ``a``.
    """
    # Validate on the host before importing warp or allocating device
    # memory, so bad arguments fail fast and with a clear message.
    op_code = {"union": 0, "intersect": 1, "subtract": 2}[op]
    if b is None:
        raise ValueError(f"boolean({op!r}) needs a second operand, got b=None")
    arr_a = a.astype(bool)
    arr_b = b.astype(bool)
    if arr_a.shape != arr_b.shape:
        # The kernel indexes both arrays with the same thread id, so a
        # shorter `b` would be read out of bounds.
        raise ValueError(
            f"boolean({op!r}) shape mismatch: a{arr_a.shape} vs b{arr_b.shape}"
        )

    import warp as wp  # type: ignore[import]

    n = arr_a.size
    d_a = wp.array(arr_a.ravel(), dtype=wp.bool, device=self._device)
    d_b = wp.array(arr_b.ravel(), dtype=wp.bool, device=self._device)
    d_out = wp.zeros(n, dtype=wp.bool, device=self._device)
    wp.launch(
        _kernel_boolean,
        dim=n,
        inputs=[d_a, d_b, d_out, int(op_code)],
        device=self._device,
    )
    # .numpy() copies the result back to the host.
    return d_out.numpy().reshape(arr_a.shape)
def subtract_swept_capsule(self, data, p0, p1, radius_mm, origin, vs):
import warp as wp # type: ignore[import]
nz, ny, nx = data.shape
d_data = wp.array(data.ravel(), dtype=wp.bool, device=self._device)
d_count = wp.zeros(1, dtype=wp.int32, device=self._device)
ax, ay, az = p0; bx, by, bz = p1; ox, oy, oz = origin
wp.launch(_kernel_subtract_capsule,
dim=nz * ny * nx,
inputs=[d_data, d_count, nz, ny, nx, ox, oy, oz, vs,
ax, ay, az, bx, by, bz, radius_mm],
device=self._device)
new_data = d_data.numpy().reshape(data.shape)
removed = int(d_count.numpy()[0])
return new_data, removedLaunch shape:
- `dim` is the total thread count for flat kernels (`arr.size` or `nz*ny*nx`), or an `(nz, ny, nx)` tuple for 3-D `wp.array3d` kernels.
- `inputs=[...]` is positional and must match the kernel signature exactly. Cast scalars: `float(x0)`, `int(z_lo)`.
- `.numpy()` does the device→host copy (implicit sync). Treat it as the one round-trip point per op.
Slabbed Work — When Grid > VRAM
For a 14 GB grid on a 16 GB card, process in Z-slabs sized to a fixed VRAM budget. One bool voxel is 1 byte on the GPU.
from dataclasses import dataclass, field
_SLAB_BYTES: int = 2 * 1024**3 # 2 GB slab budget
@dataclass
class SimSession:
"""GPU-resident batch session — queue sweeps, apply on flush() / __exit__."""
arr: "np.ndarray"
origin: tuple[float, float, float]
voxel_size: float
device: str = "cuda:0"
_pending: list = field(default_factory=list, init=False, repr=False)
def __enter__(self): return self
def __exit__(self, *_): self.flush()
def subtract_swept_capsule(self, path, radius_mm) -> None:
self._pending.append((path, radius_mm)) # queue, do not run yet
def flush(self) -> None:
if not self._pending:
return
_define_kernels()
nz, ny, nx = self.arr.shape
nz_slab = max(1, _SLAB_BYTES // (ny * nx))
for z_lo in range(0, nz, nz_slab):
z_hi = min(nz, z_lo + nz_slab)
slab = self.arr[z_lo:z_hi].astype(bool, copy=False)
_subtract_capsules_in_slab(slab, self._pending, z_lo,
self.voxel_size, self.origin, self.device)
self.arr[z_lo:z_hi] = slab
self._pending.clear() # second flush() is a no-opPer-slab launch loop:
snz, sny, snx = slab_np.shape
wp_slab = wp.array(slab_np, dtype=wp.bool, device=device)
wp.launch(kernel=_kernel_subtract_capsule_offset,
dim=(snz, sny, snx),
inputs=[wp_slab,
float(x0), float(y0), float(z0),
float(x1), float(y1), float(z1),
float(radius_mm),
float(ox), float(oy), float(oz), float(voxel_size),
int(z_lo)],
device=device)
slab_np[:] = wp_slab.numpy() # download in place
del wp_slab # free VRAM before next slabBackend Selector — auto / cuda / cpu
Pair the Warp backend with a scipy CPU fallback behind a common Backend Protocol. "auto" tries Warp, falls back silently; "cuda" raises on no GPU; "cpu" always returns scipy.
# backend.py
from __future__ import annotations
from typing import Protocol, runtime_checkable
import numpy as np
@runtime_checkable
class Backend(Protocol):
    """Structural interface shared by the GPU and CPU voxel backends.

    NOTE(review): the ``WarpBackend.subtract_swept_capsule`` shown in this
    document takes two extra arguments (``origin``, ``vs``) and returns a
    tuple, which does not match the signature declared here; confirm which
    one is intended.
    """

    # Short backend identifier, e.g. "cuda" on the Warp backend.
    name: str

    def boolean(self, op: str, a, b=None): ...
    def dilate(self, grid, kernel: np.ndarray): ...
    def voxelize_mesh(self, mesh, bbox, pitch_mm: float): ...
    def subtract_swept_capsule(self, grid, p0, p1, radius_mm: float) -> int: ...

# backend_select.py
from __future__ import annotations
import os
from typing import Literal
from backend import Backend
BackendName = Literal["auto", "cuda", "cpu"]
def get_backend(name: BackendName | None = None) -> Backend:
    """Return the backend selected by ``name`` or by ``ACME_BACKEND``.

    Args:
        name: "auto", "cuda" or "cpu". When ``None`` the ``ACME_BACKEND``
            environment variable is consulted. An unset, empty or
            whitespace-only variable means "auto", and the value is
            matched case-insensitively.

    Returns:
        "cpu" gives the scipy backend. "cuda" gives the Warp backend or
        raises. "auto" tries Warp and silently falls back to scipy.

    Raises:
        ValueError: if the resolved name is not a known backend.
        RuntimeError: for "cuda" when Warp or a CUDA device is unavailable.
    """
    if name is None:
        # Normalise only the environment value: `ACME_BACKEND=` and
        # `ACME_BACKEND=CPU` are common in shell profiles and CI configs
        # and should not abort start-up.
        env_value = os.environ.get("ACME_BACKEND", "").strip().lower()
        name = env_value or "auto"  # type: ignore[assignment]
    if name == "cpu":
        from scipy_backend import ScipyBackend
        return ScipyBackend()
    if name == "cuda":
        return _load_warp_backend(require_cuda=True)
    if name == "auto":
        try:
            return _load_warp_backend(require_cuda=False)
        except (RuntimeError, ImportError):
            # Deliberate best-effort: any GPU setup failure means CPU.
            from scipy_backend import ScipyBackend
            return ScipyBackend()
    raise ValueError(f"Unknown backend: {name!r}. Choose 'auto', 'cuda', or 'cpu'.")
def _load_warp_backend(*, require_cuda: bool) -> Backend:
try:
import warp as wp # type: ignore[import]
except ImportError as e:
raise RuntimeError(
"warp-lang is not installed. Install the CUDA extra:\n"
" pip install 'acme-voxel[cuda]'"
) from e
devices = wp.get_cuda_devices()
if not devices:
raise RuntimeError(
"No CUDA devices found. Cannot use WarpBackend.\n"
"Set ACME_BACKEND=cpu to use the scipy fallback."
)
from warp_backend import WarpBackend
return WarpBackend()Per-op CPU fallback inside the GPU backend is fine while you're writing kernels:
def dilate(self, grid, kernel):
    """Binary-dilate ``grid`` by ``kernel`` on the CPU with scipy.ndimage.

    Both arguments are converted to bool before the call.
    """
    import scipy.ndimage as ndi  # type: ignore[import]

    mask = grid.astype(bool)
    footprint = kernel.astype(bool)
    return ndi.binary_dilation(mask, structure=footprint)

Testing
Register the cuda marker
[tool.pytest.ini_options]
markers = ["cuda: tests that require a CUDA GPU (skipped on non-CUDA hosts)"]
Module-level skip for non-CUDA hosts
tests/test_warp_backend.py — collects everywhere, skips on machines without a GPU:
from __future__ import annotations
import numpy as np
import pytest
def _has_cuda() -> bool:
try:
import warp as wp # type: ignore[import]
return len(wp.get_cuda_devices()) > 0
except Exception:
return False
# Tag every test in this module with the registered `cuda` marker.
pytestmark = pytest.mark.cuda


@pytest.fixture(scope="session", autouse=True)
def require_cuda() -> None:
    """Skip the requesting test when no CUDA device is available."""
    if _has_cuda():
        return
    pytest.skip("No CUDA devices found", allow_module_level=True)
@pytest.fixture(scope="module")
def warp_backend():
from warp_backend import WarpBackend
return WarpBackend()
@pytest.fixture(scope="module")
def scipy_backend():
from scipy_backend import ScipyBackend
return ScipyBackend()Cross-validate against the CPU reference
Cross-validate the GPU kernel against the scipy CPU implementation voxel-for-voxel. This catches a kernel that uses diameter where radius was intended (removes ~2x too much):
@pytest.mark.cuda
@pytest.mark.parametrize("p0,p1,radius_mm,desc", [
((16.0, 16.0, 0.0), (16.0, 16.0, 32.0), 3.0, "z-axis radius=3"),
((0.0, 16.0, 16.0), (32.0, 16.0, 16.0), 4.0, "x-axis radius=4"),
((16.0, 16.0, 4.0), (16.0, 16.0, 28.0), 6.35, "cutter-sized radius=6.35"),
])
def test_subtract_swept_capsule_matches_scipy(warp_backend, scipy_backend, p0, p1, radius_mm, desc):
arr_w = np.ones((32, 32, 32), dtype=bool)
arr_s = arr_w.copy()
warp_backend.subtract_swept_capsule(arr_w, p0, p1, radius_mm)
scipy_backend.subtract_swept_capsule(arr_s, p0, p1, radius_mm)
diff = int((arr_w != arr_s).sum())
if diff != 0:
ratio = (~arr_w).sum() / max(1, (~arr_s).sum())
raise AssertionError(
f"{desc}: {diff} voxels differ. ratio={ratio:.3f} "
f"(~2 means kernel uses diameter where radius is intended)"
)AST isolation test — lock down the lazy import
A grep can't tell a module-level import from a lazy one inside a function. Parse the AST and inspect only top-level body. The whole point of deferring import warp is defeated if anyone imports the GPU module at module scope.
# tests/test_gpu_isolation.py
from __future__ import annotations
import ast, importlib, re
from pathlib import Path
def test_only_dispatch_imports_warp_backend() -> None:
module = importlib.import_module("acme.stages.dispatch")
src_path = Path(module.__file__)
src = src_path.read_text(encoding="utf-8")
occurrences = re.findall(r"warp_backend", src)
tree = ast.parse(src)
violations: list[str] = []
for node in ast.iter_child_nodes(tree): # TOP-LEVEL only
if isinstance(node, ast.Import):
for alias in node.names:
if "warp_backend" in (alias.name or ""):
violations.append(f"line {node.lineno}: import {alias.name}")
elif isinstance(node, ast.ImportFrom):
mod = node.module or ""
if "warp_backend" in mod:
violations.append(f"line {node.lineno}: from {mod} import ...")
assert not violations, (
"Module-level warp_backend import found.\n"
+ "\n".join(f" {v}" for v in violations)
)
assert len(occurrences) >= 1, "Expected the lazy import inside the dispatch fn."Rules to Keep
- `@wp.kernel` → `# type: ignore[misc]`; each `wp.array(...)` / `wp.array3d(...)` arg → `# type: ignore[valid-type]`
- Kernel return type is always `-> None` — kernels mutate array args in place
- Inside a kernel use only Warp builtins: `wp.bool(True)`, `wp.clamp`, `wp.vec3`, `wp.atomic_add`, `wp.round`. No Python imports, no Python `True`/`False`
- Grids are `(Z, Y, X)`. Flat decode is always: `iz = idx // (ny*nx); rem = idx - iz*ny*nx; iy = rem // nx; ix = rem - iy*nx`
- Cast scalars at launch sites — `float(x0)`, `int(z_lo)`, never bare Python ints/floats
Common Pitfalls
- `import warp` at module top breaks every non-GPU host. Lock it down with the AST isolation test
- Forgetting `del wp_slab` between slab launches keeps VRAM allocated and OOMs on the next slab
- Half-voxel GPU/CPU mismatch — without snapping to nearest voxel centre, the GPU boundary can disagree with an index-rounded CPU disc by up to half a voxel. Use `wp.round((closest_x - ox) / vs - 0.5)` then re-multiply
- Hollow shells for non-watertight meshes — sign-query voxelization can leave holes; post-process with `scipy.ndimage.binary_fill_holes`
- `device=` accepts either a resolved device object or a string (`"cuda:0"` / `"cpu"`) — passing the wrong type silently picks a different device