Adapt shapes and use vmap
Old comment (from !14 (merged)): In that comment, you suggested ravel_pytree. However, with ravel_pytree the output is not the same, probably because it flattens everything into a single 1D array. So my approach was tree_flatten followed by concatenate; after the encode and decode steps, we get the pytree back with tree_unflatten. My question is about undoing the concatenate: I used a very hacky solution to revert it here. To use split easily, we need split_vals to recover the leaves, and the way I obtain those values here is also not great, so maybe we can discuss that.
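A minimal sketch of the flatten/concatenate/split round trip, assuming every leaf shares its leading dimensions and differs only in the size of its last axis. Instead of looking up split_vals separately, the split indices are computed from the leaf shapes at flatten time and returned together with the treedef. concat_pytree and unconcat_pytree are hypothetical helper names, and the encode/decode step is left out.

```python
import jax
import jax.numpy as jnp
import numpy as np


def concat_pytree(tree):
    """Flatten a pytree and concatenate its leaves along the last axis.

    Assumes all leaves share their leading dimensions. Returns the
    concatenated array plus everything needed to undo the operation:
    the treedef and the split indices derived from the leaf widths.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    widths = [leaf.shape[-1] for leaf in leaves]
    # Split points are the running sum of the widths, excluding the total.
    split_indices = np.cumsum(widths)[:-1].tolist()
    return jnp.concatenate(leaves, axis=-1), treedef, split_indices


def unconcat_pytree(x, treedef, split_indices):
    """Inverse of concat_pytree: split along the last axis, then rebuild the pytree."""
    leaves = jnp.split(x, split_indices, axis=-1)
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Usage with a toy pytree (hypothetical data).
tree = {"a": jnp.ones((8, 2)), "b": jnp.zeros((8, 3))}
x, treedef, split_indices = concat_pytree(tree)
# ... x would go through the autoencoder's encode and decode steps here ...
restored = unconcat_pytree(x, treedef, split_indices)
```

Because the split indices travel with the treedef, the reverse step only needs the output of concat_pytree, provided encode/decode preserve the concatenated last-axis size.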
What happened: Kevin pushed an example for testing in c16f2e7e.
What needs to be done: Restructure the code so that vmap runs outside the train_step and processes whole batches. vmap then passes a single window at a time to the AE architecture.
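One possible reading of the planned structure, as a sketch: the autoencoder is written for a single window, jax.vmap lifts it to a batch of windows, and train_step only ever receives whole batches. The toy tanh autoencoder, the parameter names enc/dec, the MSE loss, and the plain SGD update are all hypothetical stand-ins for the real architecture and optimizer.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # hypothetical value for the plain SGD update


def ae_apply(params, window):
    """Hypothetical toy autoencoder that sees ONE window of shape (window_len, n_features)."""
    z = jnp.tanh(window @ params["enc"])  # encode: (window_len, latent)
    return z @ params["dec"]              # decode: (window_len, n_features)


# vmap maps over axis 0 of the batch (one window per call); params are shared (None).
batched_ae_apply = jax.vmap(ae_apply, in_axes=(None, 0))


def loss_fn(params, batch):
    """Mean squared reconstruction error over a batch of windows."""
    recon = batched_ae_apply(params, batch)
    return jnp.mean((recon - batch) ** 2)


@jax.jit
def train_step(params, batch):
    """One SGD step on a whole batch; the per-window logic lives in ae_apply."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Usage with random data (hypothetical shapes).
key_enc, key_dec, key_data = jax.random.split(jax.random.PRNGKey(0), 3)
n_features, latent = 4, 2
params = {
    "enc": 0.1 * jax.random.normal(key_enc, (n_features, latent)),
    "dec": 0.1 * jax.random.normal(key_dec, (latent, n_features)),
}
batch = jax.random.normal(key_data, (16, 10, n_features))  # (batch, window_len, n_features)
params, loss = train_step(params, batch)
```

With this split, the model code only deals with the shape (window_len, n_features); the batch axis exists only in the vmap-ed wrapper that train_step calls.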