Skip to content

Re-write of train_ae using TrainState and checking on vmap

Kevin Haninger requested to merge niki_dev_kev into niki_dev

Did a quick re-write of the training loop, checking about jit and vmap usage. I pulled the training step definition out of the loop, used a TrainState which helps package things nicer, and used simple lists outside the calcs b/c there we can just use python native things without too much performance penalty. Pretty sure it's reproducing the original training curve, but have only used autoencoder_testing to validate. Tried to add comments, but tag me here in a comment if something doesn't make sense.

Merge request reports

Loading