
bumped ros_datahandler, modified dynslicedataloader, and made a test with...

Kevin Haninger requested to merge 49-training-speed-regression into main

Fix training speed. High-level changelog:

  • change the internal data structure of dyn_slice_dataset from a list of trajs to a single concatenated traj. With a list, we needed a jax.lax.switch to pick out the right trajectory, which worked but got really slow as the number of bags increased (see here). A rough sketch of the new indexing is below this list.
  • make the train loop a jax.lax.fori_loop. This puts batch fetching and the training step into the same jit block, so the XLA compiler can optimize away intermediate memory copies (second sketch below the list).
    • This drops epoch time from 13 s to 6 s with window_size=10, batch=50.
    • I left this as a separate train_jit() because the fori_loop makes it harder to do certain things within an epoch. If we never need to act mid-epoch (e.g. early stopping or checkpointing always happens after an epoch), we can merge them and keep just one train fn. A quick note from either of you would help here :)
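
For reference, here is a minimal sketch of the concatenated-trajectory indexing. It is an illustration under assumptions, not the actual dyn_slice_dataset code: build_concat_dataset, make_get_batch, trajs, and window_size are made-up names, and each trajectory is assumed to be a (T_i, obs_dim) array. The point is that precomputing the window start indices that stay inside one trajectory lets a single jax.lax.dynamic_slice replace the per-trajectory jax.lax.switch.

```python
import numpy as np
import jax
import jax.numpy as jnp

def build_concat_dataset(trajs, window_size):
    """Concatenate a list of (T_i, obs_dim) trajectories into one array and
    precompute window start indices that do not cross a trajectory boundary."""
    lengths = np.array([len(t) for t in trajs])
    offsets = np.concatenate([[0], np.cumsum(lengths)[:-1]])
    data = jnp.concatenate([jnp.asarray(t) for t in trajs], axis=0)
    valid_starts = np.concatenate(
        [off + np.arange(max(l - window_size + 1, 0))
         for off, l in zip(offsets, lengths)]
    )
    return data, jnp.asarray(valid_starts)

def make_get_batch(window_size):
    def get_window(data, start):
        # one dynamic_slice instead of a lax.switch over trajectories
        return jax.lax.dynamic_slice_in_dim(data, start, window_size, axis=0)
    # vmap over start indices -> (batch, window_size, obs_dim)
    return jax.vmap(get_window, in_axes=(None, 0))
```

Windows that would cross a boundary are simply never sampled, so no branching on trajectory index is needed at slice time.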

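And a rough sketch of the fori_loop epoch, again under assumptions rather than the actual train_jit(): the optax Adam optimizer, the placeholder loss_fn, and the parameter pytree are stand-ins for whatever the project really uses.

```python
import jax
import jax.numpy as jnp
import optax
from functools import partial

WINDOW_SIZE = 10                 # assumed; matches the numbers quoted above
optimizer = optax.adam(1e-3)     # assumed optimizer

def _get_window(data, start):
    return jax.lax.dynamic_slice_in_dim(data, start, WINDOW_SIZE, axis=0)

_get_batch = jax.vmap(_get_window, in_axes=(None, 0))

@partial(jax.jit, static_argnames=("batch_size", "num_batches"))
def train_epoch_jit(params, opt_state, data, valid_starts, rng,
                    *, batch_size, num_batches):
    """One epoch as a fori_loop: batch fetching and the gradient step share
    one jit block, so XLA can fuse them and drop intermediate copies."""

    def loss_fn(p, batch):
        # placeholder objective; stands in for the real model loss
        pred = batch[:, :-1] * p["w"]
        return jnp.mean((pred - batch[:, 1:]) ** 2)

    def body(i, carry):
        params, opt_state, rng, loss_sum = carry
        rng, key = jax.random.split(rng)
        # sample window start indices, slice the batch, take one gradient step
        starts = jax.random.choice(key, valid_starts, shape=(batch_size,))
        batch = _get_batch(data, starts)
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, rng, loss_sum + loss

    params, opt_state, rng, loss_sum = jax.lax.fori_loop(
        0, num_batches, body, (params, opt_state, rng, jnp.float32(0.0)))
    return params, opt_state, rng, loss_sum / num_batches
```

Roughly, you would call build_concat_dataset(trajs, WINDOW_SIZE) once and then train_epoch_jit(...) once per epoch; anything like early stopping or checkpointing then happens between those calls, which is the trade-off mentioned in the last bullet.
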
Closes #49
