Training speed regression
Full training runs seem to have regressed in speed. See discussion here
Write a small benchmark for a single training iteration -
Use git bisect to isolate the commit that introduced the regression -
Repeat the git bisect on GPU -
Check whether pure array indexing is faster? -
Borrow other tricks from torch's DataLoader? -
prefetch -
memory pinning -
multiple workers
It seems JAX doesn't support memory views, so slicing copies data; those copies probably get optimized out under JIT. Is there a way to JIT the whole training loop? Could be worth trying. -
if an inner fori_loop turns out to be the way to go, clean up dataset._samples -
clean up the dataloader -
fix the prediction dataloader: currently, cumlengths lets index + pred_offset wrap into the next array.
check test_synth_data; it seems to have pulled an old version.
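For the benchmark item, a minimal timing sketch is below. It assumes a JAX training step; the linear model, `loss_fn`, `train_step`, `benchmark_step` and `LR` are hypothetical stand-ins, not the project's code. The two points it illustrates are excluding compilation via warm-up calls and calling `jax.block_until_ready`, since JAX dispatches work asynchronously and the timer would otherwise stop before the device finishes.

```python
import time

import jax
import jax.numpy as jnp

LR = 1e-3  # hypothetical learning rate


def loss_fn(params, x, y):
    # Hypothetical linear model standing in for the real network.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # One plain SGD step: loss, gradients, parameter update.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss


def benchmark_step(step_fn, params, x, y, n_warmup=3, n_iters=50):
    """Return the mean wall-clock seconds per call of step_fn."""
    # Warm-up calls trigger compilation so it is excluded from the timing.
    for _ in range(n_warmup):
        params, _ = step_fn(params, x, y)
    jax.block_until_ready(params)

    start = time.perf_counter()
    for _ in range(n_iters):
        params, _ = step_fn(params, x, y)
    # Wait for the device to finish before stopping the clock.
    jax.block_until_ready(params)
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    kx, ky = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(kx, (256, 32))
    y = jax.random.normal(ky, (256, 1))
    params = {"w": jnp.zeros((32, 1)), "b": jnp.zeros((1,))}
    print(f"{benchmark_step(train_step, params, x, y) * 1e3:.3f} ms/iter")
```

If the script exits with a nonzero status when the time per iteration exceeds some threshold, `git bisect run` can drive the bisection automatically.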
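One possible answer to "JIT the whole training loop" combined with pure array indexing and an inner fori_loop is sketched below. It assumes the full dataset fits on the device as arrays `x` and `y`; `train_epoch`, `loss_fn` and `LR` are hypothetical. Each batch is selected by slicing a shuffled index vector with `lax.dynamic_slice` and gathering rows, so one jitted call runs a whole epoch with no Python dataloader in between.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

LR = 0.05  # hypothetical learning rate


def loss_fn(params, x, y):
    # Hypothetical linear model standing in for the real network.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@partial(jax.jit, static_argnames=("batch_size",))
def train_epoch(params, x, y, key, batch_size):
    """Run one shuffled epoch entirely inside a single jitted call."""
    n_batches = x.shape[0] // batch_size  # static; leftover samples are dropped
    perm = jax.random.permutation(key, x.shape[0])

    def body(i, carry):
        params, total_loss = carry
        # Pure array indexing: slice this batch's indices, then gather rows.
        idx = lax.dynamic_slice(perm, (i * batch_size,), (batch_size,))
        loss, grads = jax.value_and_grad(loss_fn)(params, x[idx], y[idx])
        params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
        return params, total_loss + loss

    params, total_loss = lax.fori_loop(0, n_batches, body, (params, jnp.zeros(())))
    return params, total_loss / n_batches
```

The trade-off is that the data must live on the device, and the batch size must be static because it sets the slice shape.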
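For the prediction dataloader item, one way to avoid the wrap is to sample only start indices whose target `i + pred_offset` stays in the same array. This sketch assumes the dataset concatenates several arrays and that cumlengths holds their cumulative lengths (the actual dataloader code isn't shown here); `locate` and `valid_start_indices` are hypothetical helpers, not existing functions.

```python
import numpy as np


def locate(flat_idx, lengths):
    """Map flat indices into the concatenated data to (array id, local index)."""
    lengths = np.asarray(lengths)
    cumlengths = np.cumsum(lengths)
    array_id = np.searchsorted(cumlengths, flat_idx, side="right")
    starts = cumlengths - lengths  # flat index where each array begins
    return array_id, flat_idx - starts[array_id]


def valid_start_indices(lengths, pred_offset):
    """Flat indices i whose target i + pred_offset lies in the same array as i."""
    lengths = np.asarray(lengths)
    starts = np.cumsum(lengths) - lengths
    return np.concatenate([
        np.arange(s, s + max(n - pred_offset, 0))
        for s, n in zip(starts, lengths)
    ])
```

With lengths `[5, 3]` and `pred_offset=2`, `valid_start_indices` returns `[0, 1, 2, 5]`. Index 3 is excluded because its target, flat index 5, is the first element of the second array, which is exactly the wrap described above.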
Edited by Kevin Haninger