
Optimize Computation Time in Training

Hey @hanikevi,

this issue follows up on 35-out-of-memory and tracks optimizing the computation time in training.

The current state is:

  • jitted the build_view_batch_fori function with fori_loop() (a sketch of this pattern is below)
  • jitted the get_traj_idx function
  • tests added in dyn_slice_dataset.py
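For reference, here is a minimal sketch of this pattern; the signature, the `window`/`batch_size` arguments, and the array layout are assumptions for illustration, not the actual code in the repo:

```python
import jax
import jax.numpy as jnp
from functools import partial

# Sketch only: the real build_view_batch_fori may differ.
@partial(jax.jit, static_argnames=("window", "batch_size"))
def build_view_batch_fori(data, start_idxs, window, batch_size):
    """Gather batch_size windows of length `window` from `data`."""
    # jit needs a static output shape, hence window/batch_size are static args.
    batch = jnp.zeros((batch_size, window) + data.shape[1:], dtype=data.dtype)

    def body(i, acc):
        # dynamic_slice accepts a traced start index, unlike plain Python slicing.
        view = jax.lax.dynamic_slice_in_dim(data, start_idxs[i], window, axis=0)
        return acc.at[i].set(view)

    return jax.lax.fori_loop(0, batch_size, body, batch)

# e.g. four length-5 windows out of a 1-D array:
# build_view_batch_fori(jnp.arange(100.0), jnp.array([0, 10, 20, 30]), window=5, batch_size=4)
```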

Looking forward to working together on this issue. Validation activities are out of scope here.


Still to do:

  • Pad arrays to a common length, so that view_batch and view_batch_fori work on static shapes (see the padding sketch below)
  • Write a minimal benchmark covering slicing and memory transfer, with a trivial computation on the GPU (see the benchmark sketch below)
  • Can we just vmap to get batches? (see the vmap sketch below)
  • Move cumsum into the default _get_idx
  • Can the slicing be done on the GPU? Is it possible once the function is JIT-ted?
  • If we have a pytree on the CPU, generate a batch there, and then feed it into the network, when does the data actually move to the GPU? (see the device-placement sketch below)
  • Why is it not faster on smaller batches? prep_batches handles a small batch in 13 sec.
    • Check the control flow: I suspect calling get_batch(idxs) in the training loop might be less efficient than using an iterator/generator (see the dataloader sketch below)
    • If so, split out the dataset and the dataloader
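On padding (first item): jit recompiles for every new input shape, so variable-length trajectories generally have to be padded to one common length before a jitted view_batch / view_batch_fori can run on them without retracing. A minimal sketch, where pad_trajectories and the trajectory layout are assumptions:

```python
import numpy as np
import jax.numpy as jnp

def pad_trajectories(trajs, pad_value=0.0):
    """Pad variable-length trajectories to one static length and
    return the true lengths, so the padding can be masked out later."""
    max_len = max(t.shape[0] for t in trajs)
    out = np.full((len(trajs), max_len) + trajs[0].shape[1:],
                  pad_value, dtype=trajs[0].dtype)
    for i, t in enumerate(trajs):
        out[i, : t.shape[0]] = t
    lengths = np.array([t.shape[0] for t in trajs])
    return jnp.asarray(out), jnp.asarray(lengths)
```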
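On the minimal benchmark: JAX dispatches work asynchronously, so a fair timing has to block on the result and exclude the first, compiling call. A rough harness, assuming the function under test returns a single array:

```python
import time

def bench(fn, *args, repeats=10):
    """Average wall time of fn(*args), excluding jit compilation."""
    fn(*args).block_until_ready()    # warm-up call triggers compilation
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()          # flush JAX's async dispatch queue
    return (time.perf_counter() - t0) / repeats
```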
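On "can we just vmap": in principle yes, by vmapping a single dynamic slice over the start indices; under jit this lowers to a gather instead of a Python-level loop. A sketch under the same assumed layout as above:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames="window")
def view_batch_vmap(data, start_idxs, window):
    """Batched windows via vmap over one dynamic slice."""
    slice_one = lambda start: jax.lax.dynamic_slice_in_dim(data, start, window, axis=0)
    return jax.vmap(slice_one)(start_idxs)
```

Whether this beats the fori_loop version is exactly what the benchmark above should show.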
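On the CPU-pytree question: an array stays on the device it was created or device_put on, a jitted slice runs wherever its inputs live, and the move to the GPU happens either implicitly the first time a GPU computation consumes the batch, or explicitly via jax.device_put. A small probe (device names depend on the local setup; .devices() needs a reasonably recent JAX):

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
data = jax.device_put(jnp.arange(1_000_000.0), cpu)  # keep the full dataset on the CPU

batch = data[:256]       # the slice is computed on the CPU, where `data` lives
print(batch.devices())   # {CpuDevice(id=0)}

# Explicit transfer of just the batch to the default (e.g. GPU) device:
gpu_batch = jax.device_put(batch, jax.devices()[0])
print(gpu_batch.devices())
```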
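On the iterator/generator suspicion: one cheap test is a generator that pre-shuffles indices once per epoch and yields batches, which also lets JAX's async dispatch overlap batch prep with the previous training step. A hypothetical split, reusing get_batch(idxs) from above (dataset_len and train_step are placeholders):

```python
import numpy as np

def batch_iterator(dataset_len, batch_size, get_batch, seed=0):
    """Yield batches from one pre-shuffled index permutation per epoch."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(dataset_len)
    for i in range(0, dataset_len - batch_size + 1, batch_size):
        yield get_batch(perm[i : i + batch_size])

# usage in the training loop:
# for batch in batch_iterator(len(dataset), 256, dataset.get_batch):
#     state = train_step(state, batch)
```

If this is measurably faster, it argues for splitting the dataset (storage and slicing) from the dataloader (shuffling and batching), as in the last item.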