Skip to content

Cuda: it just works

Kevin Haninger requested to merge 2-cuda into master

Add CUDA optionally with JAX, so we can use one pyproject.toml with and without CUDA, and roll out the system on ~arbitrary CUDA-capable machines.

Known limitations:

  • NVIDIA System driver must be installed and loaded. Version >555.58.
  • Pyproject.toml needs both the Jax cuda deps to be included, at the same version, and marked as optional. Additionally, an extras group for cuda should expose these two libs. See jax/pyproject.toml here or opt_tools/pyproject.toml.
  • Only tested with jax 4.31. Maybe other versions work, not tested.

Closes #2

Merge request reports