Fix training speed. High-level changelog:
- Changed `dyn_slice_dataset` from a list of trajs to a single concatenated traj. With a list, we needed a `jax.lax.switch` to get the right trajectory index, which worked but got really slow as the # of bags increased (see here).
- Moved the per-batch training loop into a `jax.lax.fori_loop`. This puts batch fetching and training into the same jit block, so the XLA compiler can optimize out any memory copies.
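A minimal sketch of the dataset change, with hypothetical shapes and names (`trajs`, `offsets`, and the slice helpers are illustrative, not the actual `dyn_slice_dataset` code): with a list of trajectories, `jax.lax.switch` needs one branch per bag, while the concatenated layout reduces batch fetching to a single `dynamic_slice`.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 3 "bags", each a (4, 2) trajectory.
trajs = [jnp.arange(8.0).reshape(4, 2) + 10 * i for i in range(3)]

# Old approach (sketch): pick the trajectory with lax.switch.
# One branch per bag, so cost grows with the number of bags.
def slice_switch(traj_idx, start, length=2):
    branches = [
        lambda s, t=t: jax.lax.dynamic_slice_in_dim(t, s, length)
        for t in trajs
    ]
    return jax.lax.switch(traj_idx, branches, start)

# New approach (sketch): concatenate once, remember per-traj offsets,
# and fetch any window with a single branch-free dynamic_slice.
concat = jnp.concatenate(trajs, axis=0)
offsets = jnp.array([0, 4, 8])  # start row of each trajectory in `concat`

def slice_concat(traj_idx, start, length=2):
    return jax.lax.dynamic_slice_in_dim(concat, offsets[traj_idx] + start, length)
```

Both helpers return the same window; only the concatenated version keeps the trace size independent of the number of bags.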
- Kept the epoch loop outside `train_jit()`, because merging it in makes it harder to do certain things within an epoch. If we never need to do anything mid-epoch (e.g. early stopping and checkpointing always happen after an epoch), we can merge them and have just one train fn. A quick note from either of you would help here :)

Closes #49
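The loop structure above can be sketched as follows, assuming a toy linear-regression setup (the data, `loss_fn`, and `train_epoch` names are hypothetical stand-ins for the real training code): the `fori_loop` body fetches a batch and applies the gradient step inside one jit block, while the epoch loop stays in Python.

```python
import jax
import jax.numpy as jnp

# Hypothetical dataset: already concatenated into one array.
data_x = jnp.linspace(0.0, 1.0, 64).reshape(64, 1)
data_y = 3.0 * data_x
batch_size, lr = 8, 0.1

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_epoch(w):
    # Batch fetching and the update share one jit block, so XLA can
    # fuse the dynamic_slice with the gradient step and avoid copies.
    def step(i, w):
        x = jax.lax.dynamic_slice_in_dim(data_x, i * batch_size, batch_size)
        y = jax.lax.dynamic_slice_in_dim(data_y, i * batch_size, batch_size)
        return w - lr * jax.grad(loss_fn)(w, x, y)
    return jax.lax.fori_loop(0, data_x.shape[0] // batch_size, step, w)

# Epoch loop stays in Python, so checkpointing / early stopping
# can still happen between epochs.
w = jnp.zeros((1, 1))
for _ in range(3):
    w = train_epoch(w)
```

If we later drop all mid-epoch logic, the outer `for` could move into the jitted function too, leaving a single train fn.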