Training speed regression

Complete training runs seem to have gotten slower. See the discussion here.

  • Write a small benchmark for a single training iteration (see the benchmark sketch after this list)
  • Git bisect with that benchmark as the git bisect run script to isolate the regressing commit
  • Repeat the git bisect on a GPU machine
  • Check whether pure array indexing works better (see the indexing sketch below)
  • Steal other ideas from torch's data loading? https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading (a prefetch sketch follows this list)
    • prefetch
    • memory pinning
    • multiple workers
  • It seems JAX just doesn't do memory views (https://jax.readthedocs.io/en/latest/jax.numpy.html); the extra copies probably get optimized out under JIT. Is there a way to JIT the training loop? Could be interesting (see the fori_loop sketch below).
    • if an inner fori_loop is used, clean up dataset._samples
    • clean up the dataloader
    • fix the prediction dataloader: right now cumlengths lets the +pred_offset wrap into the next array (see the index-validity sketch below)
  • Check test_synth_data; it seems to have pulled in an old version
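
A minimal sketch of the per-iteration benchmark, written so it can double as the git bisect run script (exit 0 = fast enough, exit 1 = regressed). The import path mypkg.train.make_state_and_step, the train_step signature, and the threshold value are assumptions, not the real code:

```python
import sys
import time

import jax

THRESHOLD_S = 0.05  # hypothetical cutoff; measure it on a known-good commit


def time_step(train_step, state, batch, n_iters=100):
    """Average wall-clock seconds per training step, excluding compilation."""
    state = train_step(state, batch)        # warm-up: triggers JIT compilation
    jax.block_until_ready(state)
    start = time.perf_counter()
    for _ in range(n_iters):
        state = train_step(state, batch)
    jax.block_until_ready(state)            # flush JAX's async dispatch queue
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    from mypkg.train import make_state_and_step  # hypothetical project import
    state, batch, train_step = make_state_and_step()
    dt = time_step(train_step, state, batch)
    print(f"{dt * 1e3:.2f} ms/step")
    sys.exit(0 if dt < THRESHOLD_S else 1)  # exit code drives git bisect run
```

Once a known-good commit is identified, git bisect start, git bisect bad, git bisect good <sha>, then git bisect run python bench.py automates the search; running the same script on the GPU box covers the GPU bisect item as well.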
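
For the pure-array-index question, a sketch of what that could look like: keep all samples resident in one device array and gather a batch with integer-array indexing instead of assembling it on the host. samples and batch_idx are placeholders:

```python
import numpy as np

import jax.numpy as jnp

samples = jnp.zeros((10_000, 32))     # all samples resident on device
rng = np.random.default_rng(0)
batch_idx = jnp.asarray(rng.integers(0, 10_000, size=128))
batch = samples[batch_idx]            # a single device-side gather, no host copy
```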
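
On the torch ideas: jax.device_put is asynchronous, so a small wrapper that keeps a couple of batches in flight gets much of the benefit of torch's prefetching (flax.jax_utils.prefetch_to_device does roughly this for pmap'd inputs). There is no obvious public pin_memory analogue in JAX, and multiple workers would have to come from the host-side loader. A minimal sketch:

```python
import collections
import itertools

import jax


def prefetch_to_device(batch_iter, size=2):
    """Keep `size` batches in flight so host-to-device copies overlap compute."""
    queue = collections.deque()

    def enqueue(n):
        for batch in itertools.islice(batch_iter, n):
            queue.append(jax.device_put(batch))  # async transfer starts here

    enqueue(size)                                # fill the pipeline
    while queue:
        yield queue.popleft()
        enqueue(1)                               # top the pipeline back up
```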
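
For JIT-ing the training loop, a sketch assuming an epoch's batches fit on device as one stacked pytree with a leading batch axis, and that train_step is pure (train_step here is a placeholder for the real step function):

```python
import jax


def train_step(state, batch):  # placeholder for the real step function
    return state


@jax.jit
def train_epoch(state, batches):
    n = jax.tree_util.tree_leaves(batches)[0].shape[0]

    def body(i, state):
        # dynamic-slice batch i out of the stacked arrays; under jit this
        # copy is a candidate for fusion, which is the "mem copies get
        # optimized out" hope from the note above
        batch = jax.tree_util.tree_map(lambda x: x[i], batches)
        return train_step(state, batch)

    return jax.lax.fori_loop(0, n, body, state)
```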
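
And for the prediction dataloader, one way to stop the +pred_offset wrap: only emit flat indices i where both i and i + pred_offset fall inside the same underlying array. lengths (the per-array sample counts behind cumlengths) and the function name are assumptions:

```python
import numpy as np


def valid_indices(lengths, pred_offset):
    """Flat indices i such that i and i + pred_offset share an array."""
    cumlengths = np.cumsum(lengths)
    starts = cumlengths - lengths              # first flat index of each array
    keep = [np.arange(s, s + max(l - pred_offset, 0))
            for s, l in zip(starts, lengths)]
    return np.concatenate(keep) if keep else np.empty(0, dtype=int)


# lengths=[5, 3], pred_offset=2 -> [0, 1, 2, 5]; flat indices 3 and 4 are
# dropped because i + 2 would wrap into the second array.
```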