
Improve batch handling

Complaints:

  1. When generating u_traj for CEM, we need a leaf function which takes the mean and covariance. Right now, this isn't possible with generate_batch, which means we need to use tree_map and cast the result as a batch, which is a bit disappointing.
  2. When adding a traj to a trajbatch, we first need to cast the traj to a batch. That's also annoying.
  3. step, output, and cost need lots of case handling to distinguish between a single traj and a batch.
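One possible direction for complaints 2 and 3, sketched under assumptions: trajectories are pytrees of JAX arrays with time as the leading axis, and a batch is the same pytree with an extra leading batch axis. The names append_traj and cost, and the "x"/"u" keys, are hypothetical stand-ins; the project's actual traj/trajbatch classes aren't shown in this issue. The idea is to promote a single traj leaf-by-leaf when appending, and to write per-trajectory functions once and derive the batched versions with jax.vmap instead of hand-written case handling.

```python
import jax
import jax.numpy as jnp


def append_traj(batch, traj):
    """Hypothetical helper: append one trajectory pytree to a batch pytree.

    Each leaf of the single traj gets a new leading axis of size 1
    (t[None]) and is concatenated onto the matching batch leaf, so the
    caller never has to cast the traj to a batch first.
    """
    return jax.tree_util.tree_map(
        lambda b, t: jnp.concatenate([b, t[None]], axis=0), batch, traj
    )


def cost(traj):
    """Hypothetical per-trajectory cost, written only for the unbatched case."""
    return jnp.sum(traj["x"] ** 2) + 0.1 * jnp.sum(traj["u"] ** 2)


# vmap maps cost over the leading (batch) axis of every leaf,
# so no traj/batch branching is needed inside cost itself.
batch_cost = jax.vmap(cost)

# Usage: a batch of 4 zero trajectories (horizon 5, 3 states, 2 inputs)
traj = {"x": jnp.ones((5, 3)), "u": jnp.ones((5, 2))}
batch = {"x": jnp.zeros((4, 5, 3)), "u": jnp.zeros((4, 5, 2))}
batch = append_traj(batch, traj)
costs = batch_cost(batch)
```

The same jax.vmap pattern would apply to step and output, provided they can be written for a single trajectory.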

The goal is to adjust generate_batch, or add another function, so that it can handle such initialization cases.
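A minimal sketch of what such a function could look like, assuming the leaf function has the signature leaf_fn(key, *params) and that the batch is simply the stacked pytree output. The names generate_batch_from and gaussian_u_traj are hypothetical, and the real generate_batch signature isn't shown here. Each sample gets its own PRNG key, while the parameters (for CEM, the mean and covariance) are shared across the batch via in_axes=None.

```python
import jax
import jax.numpy as jnp


def generate_batch_from(leaf_fn, key, batch_size, *params):
    """Hypothetical helper: draw batch_size samples of leaf_fn(key, *params).

    Keys are mapped over (in_axes=0); params are broadcast (in_axes=None).
    Works for any pytree returned by leaf_fn, since vmap stacks every leaf
    along a new leading batch axis.
    """
    keys = jax.random.split(key, batch_size)
    in_axes = (0,) + (None,) * len(params)
    return jax.vmap(leaf_fn, in_axes=in_axes)(keys, *params)


def gaussian_u_traj(key, mean, cov):
    """Hypothetical CEM leaf function: one flattened u_traj ~ N(mean, cov)."""
    return jax.random.multivariate_normal(key, mean, cov)


# Usage: 16 control trajectories with horizon 5 and 2 inputs
horizon, n_u, n_samples = 5, 2, 16
mean = jnp.zeros(horizon * n_u)
cov = jnp.eye(horizon * n_u)
u_batch = generate_batch_from(
    gaussian_u_traj, jax.random.PRNGKey(0), n_samples, mean, cov
)
u_batch = u_batch.reshape(n_samples, horizon, n_u)  # (batch, horizon, n_u)
```

Wrapping the result in the project's batch type is left out, since that type isn't described in the issue.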

Edited by Kevin Haninger