Training speed regression
Speed of a complete training run seems to have regressed. See discussion here
- Write small benchmark for a training iteration (see the benchmark sketch after this list)
- Git bisect to isolate the regression commit
- Git bisect on GPU
- Check if a pure array index works better?
- Steal other things from torch? https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading (see the DataLoader sketch after this list)
  - prefetch
  - memory pinning
  - multiple workers
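A minimal sketch of the per-iteration benchmark, assuming a jitted `train_step(state, batch)` and an iterable `dataloader`; both names and the warmup/iteration counts are placeholders, not the repo's actual API:

```python
import time

def benchmark_iterations(train_step, state, dataloader, n_iters=100, n_warmup=10):
    """Time individual training iterations, excluding JIT compile time."""
    it = iter(dataloader)
    # Warm-up: triggers compilation and any lazy initialization.
    for _ in range(n_warmup):
        state, loss = train_step(state, next(it))
    loss.block_until_ready()

    times = []
    for _ in range(n_iters):
        batch = next(it)
        t0 = time.perf_counter()
        state, loss = train_step(state, batch)
        loss.block_until_ready()  # JAX dispatch is async; force completion before stopping the clock
        times.append(time.perf_counter() - t0)

    times.sort()
    print(f"median iteration: {1e3 * times[len(times) // 2]:.2f} ms")
    return times
```

Wrapped in a small script that exits nonzero when the median exceeds a threshold, the same benchmark could also drive `git bisect run` for the bisect items above (including on GPU).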
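For the torch items (prefetch, memory pinning, multiple workers), a sketch of the corresponding `torch.utils.data.DataLoader` options from the linked docs; `ToyDataset` is just a stand-in for the real dataset:

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Placeholder stand-in for the repo's dataset."""
    def __init__(self, n=1024, dim=8):
        self.x = np.random.randn(n, dim).astype(np.float32)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx]

if __name__ == "__main__":  # workers use multiprocessing, so guard the entry point
    loader = DataLoader(
        ToyDataset(),
        batch_size=64,
        num_workers=4,            # multiple worker processes for loading/collating
        pin_memory=True,          # page-locked host memory, faster host-to-device copies
        prefetch_factor=2,        # batches prefetched per worker (needs num_workers > 0)
        persistent_workers=True,  # keep workers alive across epochs
    )
    for batch in loader:
        pass  # feed batches to the (jitted) train step
```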
Seems like JAX just doesn't do memory views (https://jax.readthedocs.io/en/latest/jax.numpy.html); memory copies probably get optimized out under JIT. Is there a way to JIT the training loop? Could be interesting; a rough sketch below.
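A rough sketch of JIT-ing the inner loop with `jax.lax.fori_loop`, assuming a whole epoch of batches is stacked into one on-device array; `train_step` here is a dummy placeholder, not the repo's:

```python
import jax
import jax.numpy as jnp

def train_step(state, batch):
    """Dummy stand-in for the real train step (grads + optimizer update)."""
    loss = jnp.mean(batch)
    return state + loss, loss

@jax.jit
def train_epoch(state, batches):
    """One epoch as a single jitted fori_loop over pre-stacked batches.

    `batches` has shape (num_batches, batch_size, ...) and must already be
    on device; indexing it inside the loop avoids per-step host round trips.
    """
    def body(i, carry):
        carry, _loss = train_step(carry, batches[i])
        return carry
    return jax.lax.fori_loop(0, batches.shape[0], body, state)

# Usage sketch with toy data.
state = jnp.zeros(())
batches = jnp.ones((10, 32, 4))   # (num_batches, batch_size, features)
state = train_epoch(state, batches)
```

This assumes the batched data fits in device memory; if not, per-step device puts with prefetching would be the usual alternative.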
- If an inner fori_loop is used, clean up dataset._samples
- Clean up the dataloader
- Fix the prediction dataloader: right now cumlengths makes the +pred_offset wrap to the next array (see the masking sketch after this list)
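A hedged sketch of one way to avoid the wrap, assuming `cumlengths` is the cumulative length over the concatenated source arrays and `pred_offset` is the prediction horizon; the names follow the note above, but the actual dataset layout is an assumption:

```python
import numpy as np

def valid_indices(cumlengths, pred_offset):
    """Indices i such that i and i + pred_offset lie in the same source array.

    `cumlengths` is the cumulative length over the concatenated arrays,
    e.g. np.cumsum([len(a) for a in arrays]).
    """
    total = cumlengths[-1]
    idx = np.arange(total - pred_offset)
    # searchsorted(side="right") maps a flat index to the array it belongs to.
    src = np.searchsorted(cumlengths, idx, side="right")
    tgt = np.searchsorted(cumlengths, idx + pred_offset, side="right")
    return idx[src == tgt]

# Example: two arrays of length 5 and 4, pred_offset of 2 drops indices 3 and 4.
print(valid_indices(np.array([5, 9]), 2))  # -> [0 1 2 5 6]
```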
- Check test_synth_data? It seems to have pulled an old version