Improve batch handling
Complaints:
- When generating u_traj for CEM, we need a leaf function that takes the mean and covariance. This isn't currently possible with generate_batch, so we have to use tree_map and cast the result to a batch, which is a bit disappointing.
- When adding a traj to a trajbatch, we first have to cast it to a batch. That's also annoying.
- step, output, and cost need a lot of case handling to cover both the single-traj and batch cases.
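One common way to avoid the traj/batch case handling is to write step, output, and cost for a single trajectory only and derive the batched versions with jax.vmap. The sketch below assumes states and inputs are plain JAX arrays and that a batch is the same structure with a leading batch axis. The toy dynamics, output, cost, and rollout are hypothetical placeholders, not the project's actual functions.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical time step for the toy dynamics below

# Hypothetical single-trajectory functions: no batch logic anywhere.
def step(x, u):
    # Toy double integrator: x = [position, velocity], u = [acceleration].
    return jnp.stack([x[0] + DT * x[1], x[1] + DT * u[0]])

def output(x):
    # Observe position only, kept as a length-1 vector.
    return x[:1]

def cost(x, u):
    # Quadratic state cost plus a small input penalty.
    return jnp.sum(x ** 2) + 0.01 * jnp.sum(u ** 2)

def rollout(x0, u_traj):
    # Roll out one trajectory; u_traj has shape (T, nu).
    def body(x, u):
        return step(x, u), (output(x), cost(x, u))
    x_final, (ys, costs) = jax.lax.scan(body, x0, u_traj)
    return x_final, ys, jnp.sum(costs)

# Batched versions: vmap maps each function over a leading batch axis.
batched_step = jax.vmap(step)
batched_cost = jax.vmap(cost)
batched_rollout = jax.vmap(rollout)
```

With this split, only the single-trajectory code needs maintaining. If the traj/batch containers are custom classes, they would need to be registered as pytrees (e.g. via jax.tree_util.register_pytree_node) for vmap to map over their fields.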
The goal is to adjust generate_batch, or add another function, so that such initialization cases can be handled directly.
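A possible shape for such a function, as a minimal sketch: it takes a mean and a covariance per leaf (as CEM needs for u_traj) and returns a batch directly, plus a helper that appends a single trajectory to a batch without an explicit cast. Both generate_batch_from_dist and add_to_batch are hypothetical names; the sketch assumes trajectories and batches are pytrees of JAX arrays, with batches carrying a leading batch axis, and that each covariance leaf is a (d, d) matrix over the flattened mean leaf of size d.

```python
import jax
import jax.numpy as jnp

def generate_batch_from_dist(key, mean, cov, num_samples):
    """Hypothetical: sample a batch with leaves drawn from N(mean_leaf, cov_leaf).

    mean: pytree of arrays; cov: pytree with the same structure whose leaves
    are (d, d) covariances over each flattened mean leaf of size d.
    Returns a pytree whose leaves gain a leading axis of size num_samples.
    """
    mean_leaves, treedef = jax.tree_util.tree_flatten(mean)
    cov_leaves = jax.tree_util.tree_leaves(cov)  # same structure -> same order
    keys = jax.random.split(key, len(mean_leaves))
    samples = []
    for k, m, c in zip(keys, mean_leaves, cov_leaves):
        # Sample flattened vectors, then restore the leaf's original shape.
        flat = jax.random.multivariate_normal(k, m.reshape(-1), c, shape=(num_samples,))
        samples.append(flat.reshape((num_samples,) + m.shape))
    return jax.tree_util.tree_unflatten(treedef, samples)

def add_to_batch(batch, traj):
    """Hypothetical: append one trajectory to a batch, adding its batch axis here."""
    return jax.tree_util.tree_map(
        lambda b, t: jnp.concatenate([b, t[None]], axis=0), batch, traj)
```

For u_traj of shape (T, nu), the covariance leaf would be (T*nu, T*nu), which also lets CEM keep correlations across time steps. A real version would likely wrap the result in the project's batch class instead of returning a bare pytree.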
Edited by Kevin Haninger