Cuda: it just works
Add CUDA optionally with JAX, so we can use one pyproject.toml with and without CUDA, and roll out the system on ~arbitrary CUDA-capable machines.
Known limitations:
- NVIDIA System driver must be installed and loaded. Version >555.58.
-
Pyproject.toml
needs both the Jax cuda deps to be included, at the same version, and marked as optional. Additionally, an extras group for cuda should expose these two libs. Seejax/pyproject.toml
here oropt_tools/pyproject.toml
. - Only tested with jax 4.31. Maybe other versions work, not tested.
Closes #2