Re-write of train_ae using TrainState and checking on vmap
Did a quick re-write of the training loop, checking about jit and vmap usage. I pulled the training step definition out of the loop, used a TrainState which helps package things nicer, and used simple lists outside the calcs b/c there we can just use python native things without too much performance penalty. Pretty sure it's reproducing the original training curve, but have only used autoencoder_testing to validate. Tried to add comments, but tag me here in a comment if something doesn't make sense.