bumped ros_datahandler, modified dynslicedataloader, and made a test with...
Fix training speed. High-level changelog:

- Change the internal data structure of `dyn_slice_dataset` from a list of trajs to one concatenated traj. With a list, we needed `jax.lax.switch` to get the right trajectory index, which worked but got really slow as the # of bags increased (see here).
- Make the train loop a `jax.lax.fori_loop`. This puts the batch fetching and training into the same jit block, so the XLA compiler can optimize out any memory copies.
- Together these drop epoch time from 13 s to 6 s with `window_size=10`, `batch=50`.
- I left this as a separate `train_jit()` because the fori_loop makes it harder to do certain things within an epoch. If we won't need to do anything within an epoch (e.g. early stopping or checkpointing always happens after an epoch), then we can merge and just have one train fn. A quick note from either of you would help here :)
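To illustrate the data-structure change in the first bullet, here is a minimal sketch (not the actual `dyn_slice_dataset` code; the trajectory data, offsets, and function names are made up for illustration). With a list of trajectories, `jax.lax.switch` needs one branch per bag; with a concatenated array plus precomputed offsets, a sample is a single gather:

```python
import jax
import jax.numpy as jnp

# Hypothetical trajectories of different lengths (stand-ins for bag data).
trajs = [jnp.arange(4.0), jnp.arange(5.0) + 10, jnp.arange(3.0) + 100]

# Old approach (sketch): select a trajectory with jax.lax.switch, then index it.
# switch traces every branch, and the branch count grows with the number of bags.
def get_sample_switch(traj_idx, step_idx):
    branches = [lambda i, t=t: t[i] for t in trajs]
    return jax.lax.switch(traj_idx, branches, step_idx)

# New approach (sketch): concatenate once and precompute per-traj offsets,
# so every lookup is one index into one flat array, regardless of bag count.
flat = jnp.concatenate(trajs)
offsets = jnp.array([0, 4, 9])  # cumulative lengths of trajs above

def get_sample_concat(traj_idx, step_idx):
    return flat[offsets[traj_idx] + step_idx]
```

Both functions return the same sample, but the concatenated form traces a constant amount of work however many bags are loaded.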
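And a minimal sketch of the second bullet: a toy linear model where batch fetching and the update step sit inside one `jax.lax.fori_loop` under a single `jax.jit` (the data, learning rate, and function names here are illustrative, not our actual training code):

```python
import jax
import jax.numpy as jnp

# Toy dataset: y = 3x + 1, no noise.
data_x = jnp.linspace(0.0, 1.0, 100)
data_y = 3.0 * data_x + 1.0
batch = 10
n_batches = data_x.shape[0] // batch

def loss(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

@jax.jit
def train_epoch(params):
    def body(i, p):
        # Batch fetching happens inside the loop body, so XLA sees the
        # dynamic_slice and the gradient step as one fused program.
        x = jax.lax.dynamic_slice(data_x, (i * batch,), (batch,))
        y = jax.lax.dynamic_slice(data_y, (i * batch,), (batch,))
        grads = jax.grad(loss)(p, x, y)
        return tuple(pi - 0.1 * gi for pi, gi in zip(p, grads))
    return jax.lax.fori_loop(0, n_batches, body, params)

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params = train_epoch(params)
```

The trade-off is exactly the one in the last bullet: because the whole epoch is one traced loop, you cannot checkpoint or early-stop mid-epoch from Python, only between `train_epoch` calls.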
Closes #49