Optimize Computation Time in Training
Hey @hanikevi,
this issue is related to 35-out-of-memory and is about optimizing the computation time in training.
Current state is:
- jitted `build_view_batch_fori` function with `fori_loop()` (see the sketch below)
- jitted `get_traj_idx` function
- testing is added in `dyn_slice_dataset.py`
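For reference, a minimal sketch of what the jitted `fori_loop` batch builder could look like; the signature, the flat `data` layout, and the fixed `view_len` are assumptions for illustration, not the actual code in `dyn_slice_dataset.py`:

```python
import jax
import jax.numpy as jnp
from functools import partial

# Illustrative sketch only: builds a batch of fixed-length views from a flat
# data array. lax.fori_loop keeps the whole gather inside one jitted kernel.
@partial(jax.jit, static_argnames=("view_len",))
def build_view_batch_fori(data, start_idxs, view_len):
    batch_size = start_idxs.shape[0]
    out = jnp.zeros((batch_size, view_len) + data.shape[1:], data.dtype)

    def body(i, acc):
        # dynamic_slice accepts traced start indices, so this stays jittable
        view = jax.lax.dynamic_slice_in_dim(data, start_idxs[i], view_len)
        return acc.at[i].set(view)

    return jax.lax.fori_loop(0, batch_size, body, out)
```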
Looking forward to working together on this issue. Validation activities are not related to this one.
- Pad arrays, so that `view_batch` and `view_batch_fori` are working
- Write a minimal benchmark including slicing and memory transfer with trivial computation on the GPU (see the timing sketch below)
- Can we just `vmap` to get batches? (see the `vmap` sketch below)
- Move `cumsum` into the default `_get_idx` (see the index sketch below)
- Possible to do slicing on the GPU? If it's JIT-ted, is this possible? (the `vmap` sketch below slices inside a jitted function)
- If we have a pytree on the CPU, generate a batch, then feed that into the network, when does the data go to the GPU? (see the timing sketch below)
- Why not faster on smaller batches? `Prep_batches` does a small batch in 13 sec.
- Check control flow: I suspect calling `get_batch(idxs)` in the training loop might be less efficient than using an iterator / generator (see the dataloader sketch below)
  - If so, split out the dataset and dataloader
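On the `vmap` question: mapping a single-view slice over the batch of start indices is probably the simplest option, and since `lax.dynamic_slice` is jittable, the slicing runs on whatever device `data` lives on, which also answers the slicing-on-GPU question. A sketch under the same assumed layout as above:

```python
import jax
import jax.numpy as jnp
from functools import partial

# Sketch: same assumed flat layout as above. vmap maps a one-view slice over
# all start indices at once; XLA lowers this to a gather, no Python loop.
@partial(jax.jit, static_argnames=("view_len",))
def build_view_batch_vmap(data, start_idxs, view_len):
    def one_view(start):
        return jax.lax.dynamic_slice_in_dim(data, start, view_len)
    return jax.vmap(one_view)(start_idxs)
```

If `data` is committed to the GPU once with `jax.device_put`, only the indices cross from host to device per batch.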
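On moving the `cumsum` into the default `_get_idx`, a sketch of one way it could look; the layout (unequal-length trajectories concatenated back to back, with `traj_lens` holding their lengths) is an assumption:

```python
import jax.numpy as jnp

# Assumed layout: trajectories of unequal length stored back to back.
# cumsum turns the lengths into start offsets; searchsorted then maps a
# flat sample index to (trajectory index, offset within that trajectory).
def _get_idx(flat_idx, traj_lens):
    starts = jnp.concatenate([jnp.zeros(1, dtype=traj_lens.dtype),
                              jnp.cumsum(traj_lens)])
    traj_idx = jnp.searchsorted(starts, flat_idx, side="right") - 1
    offset = flat_idx - starts[traj_idx]
    return traj_idx, offset
```

If `traj_lens` is fixed, hoisting the `cumsum` out and caching `starts` would avoid recomputing it per lookup.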
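On the pytree question and the minimal benchmark: a CPU-resident pytree is copied host-to-device when the jitted network is called on it (or earlier, if `jax.device_put` is called explicitly), and dispatch is asynchronous, so timings are only meaningful behind `block_until_ready`. A sketch that separates transfer from compute; `trivial_net` and the batch shape are made up:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def trivial_net(batch):
    # Stand-in for the real network: just enough compute to occupy the device.
    return jnp.tanh(batch["obs"]).sum()

batch_cpu = {"obs": np.ones((256, 128), dtype=np.float32)}  # pytree in host memory

# Transfer: device_put moves the whole pytree to the default device.
# Without it, the copy happens implicitly when the jitted function is called.
t0 = time.perf_counter()
batch_dev = jax.device_put(batch_cpu)
jax.block_until_ready(batch_dev)
t1 = time.perf_counter()

jax.block_until_ready(trivial_net(batch_dev))  # warm-up run, excludes compile time

# Compute: dispatch is asynchronous, so block_until_ready is needed before
# reading the clock, otherwise only the dispatch is measured.
t2 = time.perf_counter()
jax.block_until_ready(trivial_net(batch_dev))
t3 = time.perf_counter()

print(f"transfer {t1 - t0:.4f}s  compute {t3 - t2:.4f}s")
```

Sweeping the batch size in this benchmark should also speak to the small-batch question: if per-call dispatch and transfer overhead dominates, shrinking the batch will not shrink the wall time proportionally.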
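And on the control-flow point, a sketch of what a dataset / dataloader split could look like; `dataset.get_batch` and the training-loop names are assumptions. A generator keeps the index sampling and host-side slicing out of the training loop, which then only consumes ready batches:

```python
import numpy as np
import jax

def dataloader(dataset, batch_size, seed=0):
    """Generator yielding device-resident batches for the training loop."""
    rng = np.random.default_rng(seed)
    n = len(dataset)
    while True:
        idxs = rng.choice(n, size=batch_size, replace=False)
        batch = dataset.get_batch(idxs)  # host-side slicing stays in one place
        yield jax.device_put(batch)      # transfer kicked off before the step

# Usage: the training loop just iterates.
# loader = dataloader(dataset, batch_size=64)
# for step in range(num_steps):
#     batch = next(loader)
#     state, loss = train_step(state, batch)
```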