Training speed regression
Complete training seems to have regressed. See discussion here
- Write small benchmark for a training iteration
- Git bisect to isolate the regression commit
- Git bisect on GPU
- Check if a pure array index works better?
Steal other things from torch? https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading
- prefetch
- memory pinning
- multiple workers
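The prefetch idea above can be sketched without torch: a background thread fills a small queue while the consumer trains. Everything below is illustrative, nothing is from this repo:

```python
import queue
import threading

def prefetch(iterable, buffer_size=2):
    """Wrap an iterable so items are produced in a background thread,
    mimicking the prefetching behaviour of torch's DataLoader."""
    q = queue.Queue(maxsize=buffer_size)
    END = object()  # sentinel marking the end of the stream

    def producer():
        for item in iterable:
            q.put(item)  # blocks when the buffer is full
        q.put(END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is END:
            return
        yield item

batches = [[i, i + 1] for i in range(0, 6, 2)]
print(list(prefetch(batches)))  # → [[0, 1], [2, 3], [4, 5]]
```

Memory pinning and multi-worker loading have no pure-Python equivalent; those would come from torch's DataLoader (`pin_memory=True`, `num_workers=N`) directly.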
seems like jax just doesn't do memory views (https://jax.readthedocs.io/en/latest/jax.numpy.html); memory copies probably get optimized out in JIT. Is there a way to JIT the training loop? Could be interesting.
- if an inner fori should be used, clean up dataset._samples
- clean up dataloader
- fix prediction dataloader. right now cumlengths makes the +pred_offset wrap into the next trajectory's array.
- check test_synth_data? seems to have pulled an old version
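On JIT-ing the training loop with an inner fori: one way, sketched here with toy shapes and a stand-in update rule (nothing below is from this repo), is to run the per-batch loop as a `lax.fori_loop` inside a single `jit`, so XLA can fuse the batch slicing with the update and any copies can be optimized away:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_epoch(params, batches):
    # batches: (num_batches, batch, dim). The whole epoch stays inside
    # one compiled computation instead of one dispatch per batch.
    def body(i, p):
        batch = jax.lax.dynamic_index_in_dim(batches, i, keepdims=False)
        grad = jnp.mean(batch)      # stand-in for a real gradient step
        return p - 0.1 * grad
    return jax.lax.fori_loop(0, batches.shape[0], body, params)

params = jnp.array(1.0)
batches = jnp.ones((4, 50, 10))
print(train_epoch(params, batches))  # ≈ 0.6 after four 0.1-sized steps
```

A real version would carry optimizer state through the loop as part of the `fori_loop` carry, but the shape of the solution is the same.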
Activity
- Kevin Haninger assigned to @hanikevi
- Kevin Haninger mentioned in commit ce23ef3c
- Kevin Haninger created branch 49-training-speed-regression to address this issue
- Kevin Haninger (Author, Owner):
Benchmark laptop example_data_train ddb78c3b, switched to example_data, 5 bags loaded, w10_b50_e20
Edited by Kevin Haninger
- Kevin Haninger (Author, Owner):
056678b2, switch window/batches to match. 5 bags loaded, w10_b50_e20
Edited by Kevin Haninger
- Kevin Haninger (Author, Owner):
3c171020 has NaNs after the first iter, but the first iteration's time is comparable
- Kevin Haninger marked the checklist item Write small benchmark for a training iteration as completed
- Kevin Haninger marked the checklist item Git bisect to isolate the regression commit as completed
- Kevin Haninger (Author, Owner):
056678b2, adjust run.train_data to use right folder, window/batches/epochs
- Kevin Haninger (Author, Owner):
Testing on 605a7e9a as suggested by Lisa; at first I get the 20 sec/iter.
But now I realize this comes from the iterator being consumed in the first epoch; tqdm doesn't catch this, and its sec/iter estimate is made by dividing by the number of planned epochs. When I fix this bug we are back at 425 sec/iter :/
Edited by Kevin Haninger
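The failure mode described above (an iterator consumed in the first epoch, making later epochs free) is easy to reproduce in plain Python:

```python
def make_batches():
    # stand-in for a dataloader that returns a one-shot iterator
    return iter(range(3))

# Buggy pattern: one iterator shared across epochs. The first epoch
# consumes it, later epochs silently do zero work, and a progress bar
# averaging over planned epochs reports a far-too-low sec/iter.
batches = make_batches()
print([sum(1 for _ in batches) for _ in range(3)])  # → [3, 0, 0]

# Fix: build a fresh iterator each epoch.
print([sum(1 for _ in make_batches()) for _ in range(3)])  # → [3, 3, 3]
```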
- Kevin Haninger marked the checklist item Git bisect on GPU as completed
- Lisa-Marie Fenner (Maintainer):
Unfortunately I never tested the optimized computation time right before/after it was merged in 605a7e9a. Maybe we achieved the mentioned 20 sec per iteration in #36 (closed) by setting prep_batches to True?
Edited by Lisa-Marie Fenner
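I don't know the repo's actual prep_batches implementation; as a guess at the idea behind such a flag, it would materialize all batches once up front, so the epoch loop only indexes a preassembled array instead of gathering windows per step:

```python
import numpy as np

def prep_batches(data, batch_size):
    """Hypothetical sketch of a prep_batches=True mode: drop the
    ragged tail and reshape so that batches[i] is batch i, ready to
    be indexed (or scanned over) without per-step gathering."""
    n = (len(data) // batch_size) * batch_size
    return data[:n].reshape(-1, batch_size, *data.shape[1:])

data = np.arange(12.0).reshape(12, 1)
batches = prep_batches(data, batch_size=4)
print(batches.shape)  # → (3, 4, 1)
```

The trade-off matches the table further down: per-iteration cost drops, but preparation time and memory grow linearly with the dataset.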
- Kevin Haninger (Author, Owner):
Removed casting of idxs to array in get_batch on dyn_slice_dataset. They use a yield operator here: https://github.com/BirkhoffG/jax-dataloader/blob/1842b49428e37de721e821bc31d78dc21c6d7390/jax_dataloader/loaders/jax.py#L18
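A minimal numpy version of that yield-based pattern (not the repo's code, names are illustrative) looks like:

```python
import numpy as np

def iter_batches(X, batch_size, shuffle=False, rng=None):
    """Generator-style loader in the spirit of the linked
    jax-dataloader code: yield batches directly, rather than
    materializing a fresh index array per batch."""
    idx = np.arange(len(X))
    if shuffle:
        (rng or np.random.default_rng()).shuffle(idx)
    for start in range(0, len(X), batch_size):
        yield X[idx[start:start + batch_size]]

X = np.arange(10)
print([b.tolist() for b in iter_batches(X, 4)])
# → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```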
- Kevin Haninger (Author, Owner):
Total training time (in .train()) at w10, b50, 2 epochs. "no switch in dyn_slice" concatenates the trajectories, then just adjusts the index.

| # bags | prep_batches | train_jit | jit, no switch in dyn_slice |
|-------:|-------------:|----------:|----------------------------:|
| 5 | 1.9 | 3.6 | 3.5 |
| 25 | 5.6 | 5.2 | 5.3 |
| 50 | 10.5 | 7.8 | 6.8 |
| 75 | 16.7 | 10.4 | 9.0 |
| 100 | 20.1 | 83.6 | 10.6 |
| 125 | 25.5 | 229.4 | 12.6 |

Edited by Kevin Haninger
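One hypothesis (not verified against this repo) for the train_jit blow-up at 100+ bags is jit retracing/recompiling when input shapes change, e.g. if differently-sized trajectories reach the jitted step. Retraces can be counted with a Python side effect, which only runs at trace time:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # executes only when jit traces/compiles
    return x.sum()

step(jnp.ones(5))   # first call: traced and compiled
step(jnp.ones(5))   # same shape: cache hit, no retrace
step(jnp.ones(6))   # new shape: retrace + recompile
print(trace_count)  # → 2
```

If that were the cause here, padding trajectories to a common length (which is roughly what concatenating in "no switch in dyn_slice" achieves) would keep the compile cache to a single entry.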
- Kevin Haninger changed the description
- Kevin Haninger marked the checklist item seems like jax just doesn't do memory views. https://jax.readthedocs.io/en/latest/jax.numpy.html. Mem copies prob get optimized out in JIT. Way to JIT the training loop? Could be interesting. as completed
- Kevin Haninger changed the description
- Kevin Haninger marked the checklist item Check if a pure array index works better? as completed
- Kevin Haninger changed the description
- Kevin Haninger marked the checklist item if an inner fori should be used, clean up dataset._samples as completed
- Kevin Haninger changed the description
- Kevin Haninger marked the checklist item fix prediction dataloader. right now cumlengths makes the +pred_offset wrap to the next array. as completed
- Kevin Haninger changed the description
- Kevin Haninger marked the checklist item clean up dataloader as completed
- Kevin Haninger marked the checklist item check test_synth_data? seems to have pulled an old version as completed
- Kevin Haninger mentioned in merge request !36 (merged)
- Niklas Grambow closed with commit bcdba693
- Niklas Grambow mentioned in commit bcdba693