Optimize Computation Time in Training
Hey @hanikevi,
this issue is related to 35-out-of-memory and covers optimizing the computation time in training.
Current state is:
- jitted `build_view_batch_fori` function with `fori_loop()`
- jitted `get_traj_idx` function
- testing is added in `dyn_slice_dataset.py`
Looking forward to working together on this issue. Validation activities are not related to this.
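
For reference, a minimal sketch of how the jitted pieces might fit together; the real implementations live in `dyn_slice_dataset.py`, so the signatures, the flat `data` layout, and `traj_starts` (an exclusive cumsum of trajectory lengths) are assumptions here:

```python
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def get_traj_idx(flat_idx, traj_starts):
    # traj_starts: exclusive cumsum of trajectory lengths (first entry 0);
    # maps a flat sample index to the trajectory it falls in.
    return jnp.searchsorted(traj_starts, flat_idx, side="right") - 1

@partial(jax.jit, static_argnames=("batch_size", "window"))
def build_view_batch_fori(data, start_idxs, batch_size, window):
    # Gather fixed-length windows with fori_loop so the whole batch
    # construction compiles into a single program (no Python loop).
    batch = jnp.zeros((batch_size, window) + data.shape[1:], data.dtype)

    def body(i, batch):
        view = jax.lax.dynamic_slice_in_dim(data, start_idxs[i], window)
        return batch.at[i].set(view)

    return jax.lax.fori_loop(0, batch_size, body, batch)
```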
- Pad arrays so that `view_batch` and `view_batch_fori` work
- Write a minimal benchmark including slicing and memory transfer with trivial computation on the GPU (see the benchmark sketch below)
- Can we just `vmap` to get batches? (see the `vmap` sketch below)
- Move the cumsum into the default `_get_idx` (cf. the `get_traj_idx` sketch above)
- Is it possible to do the slicing on the GPU? Is this still possible once it's jitted?
- If we have a pytree on the CPU, generate a batch, and then feed that into the network, when does it go to the GPU? (see the device-placement sketch below)
- Why is it not faster on smaller batches? `prep_batches` does a small batch in 13 s.
- Check control flow: I suspect calling `get_batch(idxs)` in the training loop might be less efficient than using an iterator / generator (see the dataloader sketch below)
- If so, split out the dataset and the dataloader
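
A minimal benchmark along these lines could time the sliced compute and the host-to-device transfer separately; `bench`, the array shapes, and the window length of 100 are all made up for illustration:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

def bench(fn, *args, n=20):
    fn(*args).block_until_ready()              # warm-up (includes compile)
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*args).block_until_ready()          # force async dispatch to finish
    return (time.perf_counter() - t0) / n

data = jnp.ones((100_000, 32))                 # stand-in dataset on device
starts = jnp.arange(64) * 100                  # 64 window start indices

@jax.jit
def slice_and_sum(data, starts):
    # slicing plus a trivial computation, all on the GPU
    views = jax.vmap(lambda i: jax.lax.dynamic_slice_in_dim(data, i, 100))(starts)
    return views.sum()

host = np.ones((64, 100, 32), dtype=np.float32)
print(f"slice+compute: {bench(slice_and_sum, data, starts) * 1e3:.2f} ms")
print(f"host->device:  {bench(jax.device_put, host) * 1e3:.2f} ms")
```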
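
On the `vmap` question: a sketch of gathering a batch of views with one `dynamic_slice` per start index (function name and shapes are hypothetical):

```python
import jax
import jax.numpy as jnp

def build_view_batch_vmap(data, start_idxs, window):
    # One dynamic slice per start index, vectorized with vmap;
    # `window` must be static so every view has the same shape.
    take_view = lambda i: jax.lax.dynamic_slice_in_dim(data, i, window)
    return jax.vmap(take_view)(start_idxs)

data = jnp.ones((10_000, 8))
starts = jnp.arange(64) * 100
batched = jax.jit(build_view_batch_vmap, static_argnames="window")
print(batched(data, starts, window=100).shape)  # (64, 100, 8)
```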
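
On the pytree question: as far as I understand JAX's behavior, NumPy leaves are copied to the device when the jitted function is called with them, not when the batch is generated. A small probe (shapes hypothetical):

```python
import numpy as np
import jax
import jax.numpy as jnp

host_batch = {"obs": np.ones((64, 100, 8)), "act": np.ones((64, 100, 2))}

@jax.jit
def loss(batch):
    return jnp.mean(batch["obs"]) + jnp.mean(batch["act"])

# NumPy leaves stay on the host until the jitted function is invoked;
# the host->device copy then happens at every call.
print(loss(host_batch).devices())

# An explicit device_put makes the transfer a single, visible step:
device_batch = jax.device_put(host_batch)
print(jax.tree_util.tree_map(lambda x: x.devices(), device_batch))
```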
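
On the control-flow point: a sketch of a dataloader split out from the dataset, written as a generator that keeps one batch in flight so the transfer of the next batch can overlap with computation on the current one (`dataset.get_batch` and `train_step` are assumed names):

```python
import jax

def batch_loader(dataset, idx_batches):
    # Dataloader split out from the dataset: yields device-resident batches,
    # keeping one batch in flight since JAX dispatch is asynchronous.
    it = iter(idx_batches)
    nxt = jax.device_put(dataset.get_batch(next(it)))   # assumes >= 1 batch
    for idxs in it:
        cur, nxt = nxt, jax.device_put(dataset.get_batch(idxs))
        yield cur
    yield nxt

# Hypothetical usage in the training loop:
# for batch in batch_loader(dataset, epoch_idx_batches):
#     state, metrics = train_step(state, batch)
```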