Error while working with to_batch
Hi Kevin,
You suggested using to_batch instead of the add function. I replaced it here, but when I run the test in real_data_test.py, the to_batch call raises an error in tree_map_relaxed.
Maybe I'm not using the function correctly, or something is wrong with it. Regardless, I still get an "out of memory" error when I use jnp.stack.
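For reference, here is roughly the jnp.stack approach I mean, as a minimal sketch. It assumes every sample is a pytree with the same structure and leaf shapes. `stack_pytrees` is a hypothetical helper name, not something from the repo, and I don't know whether to_batch is meant to do the same thing.

```python
import jax
import jax.numpy as jnp


def stack_pytrees(trees):
    """Stack a list of pytrees with identical structure into one batched pytree.

    Hypothetical helper for illustration: every leaf gets a new leading batch axis.
    """
    # tree_map with several trees passes the matching leaves of all trees to the function
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


# Usage: three samples, each a dict holding two array leaves
samples = [{"x": jnp.ones((4,)) * i, "y": jnp.arange(3) + i} for i in range(3)]
batch = stack_pytrees(samples)
print(batch["x"].shape)  # (3, 4)
print(batch["y"].shape)  # (3, 3)
```

As far as I understand, jnp.stack allocates a new buffer for the whole batch while the individual samples are still alive, so peak memory is roughly double the data size. That might be where the out-of-memory error comes from.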