Draft: JIT examples for diffusion step and data handling.
Hi @dingyuan.wan ,
This MR just wraps up issue #22. I'd actually not recommend merging it, because I made some changes so I could run without PyTorch. However, I'd suggest looking at scripts/inference_pushT_benchmark_data_opt. To JIT end-to-end, you just need two structural changes:
- I think it's easier to make a dataclass that holds your data; I called it a Databuffer. Updating the Databuffer on CPU (e.g. with numpy arrays or other Python types) in the callback for your ROS node, then casting to JAX (`databuffer.to_jax()`) when you're ready for the next diffusion step, is probably most efficient. When cast to JAX arrays, they land on the default device, which should be the GPU.
- I added an `inference` function, which goes from the Databuffer to inferred actions. This is what we JIT; we don't need to worry about jitting any functions it calls ahead of time, we just need `jit` at the highest level.
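To make the two changes above concrete, here's a minimal sketch. The names (`Databuffer`, its fields, `inference`, the toy `tanh` body) are all illustrative placeholders, not the actual repo API; the point is only the structure: plain numpy in the callback, one `to_jax()` cast, one `@jax.jit` on the top-level function.

```python
from dataclasses import dataclass

import numpy as np
import jax
import jax.numpy as jnp


@dataclass
class Databuffer:
    # Updated on CPU (e.g. in a ROS callback) with plain numpy arrays.
    # Field names here are hypothetical.
    obs: np.ndarray
    prev_actions: np.ndarray

    def to_jax(self) -> "Databuffer":
        # jnp.asarray copies onto the default device (GPU when available).
        return Databuffer(
            obs=jnp.asarray(self.obs),
            prev_actions=jnp.asarray(self.prev_actions),
        )


@jax.jit
def inference(obs, prev_actions):
    # Stand-in for the real databuffer -> actions diffusion step.
    # Only this top-level function is jitted; anything it calls is
    # traced and compiled along with it.
    return jnp.tanh(obs @ prev_actions.T)


# CPU-side update, then one cast + one jitted call per diffusion step.
buf = Databuffer(
    obs=np.ones((4, 8), dtype=np.float32),
    prev_actions=np.zeros((2, 8), dtype=np.float32),
)
jbuf = buf.to_jax()
actions = inference(jbuf.obs, jbuf.prev_actions)
print(actions.shape)  # (4, 2)
```

The first call pays compilation cost; subsequent calls with the same shapes reuse the compiled executable, so keep the buffer shapes fixed across steps.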
This branch also has a pyproject.toml and poetry.lock that work on CPU. If you want to take those, I'd suggest a git checkout to pull them into your branch.
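The checkout would look something like this ("data-opt" is a placeholder for this MR's actual branch name):

```shell
# From your branch, pull just the dependency files out of this MR's branch.
git checkout data-opt -- pyproject.toml poetry.lock
```

This stages the two files in your working tree without merging anything else from the branch.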
Closes #22