Skip to content
Snippets Groups Projects

Restructure data/model parameters

Merged Kevin Haninger requested to merge 65-single-point-of-truth-for-data-and-model into main
1 unresolved thread
Files
28
@@ -11,12 +11,10 @@ from os.path import join
from pathlib import Path
from typing import Any, Iterable
#from optree import tree_map
from opt_tools import SysVarSet, Traj, Batch, tree_map_relaxed
from opt_tools.jax_tools import SysVarSet, Traj, Batch, tree_map_relaxed
from ros_datahandler import get_bags_iter
from ros_datahandler import JointPosData, ForceData, JointState, WrenchStamped
from dataclass_creator.franka_dataclass import FrankaData
from dataclass_creator.custom_dataclasses import FrankaData
def min_max_normalize_data(data: jnp.ndarray):
@@ -105,7 +103,12 @@ def bag_info(root='data/without_search_strategy/', element: int = 0):
print('Topics of one Rosbag:', topics)
def load_dataset(root='experiments/example_data/train', name: str = 'SUCCESS') -> Iterable[Traj]:
def load_dataset(
datatype: SysVarSet,
sync_stream: str,
root: Path,
name: str = "SUCCESS",
) -> Iterable[Traj]:
"""collects rosbags in specified dir and return a formatted dataclass containg ragged list*3
Args:
@@ -116,26 +119,15 @@ def load_dataset(root='experiments/example_data/train', name: str = 'SUCCESS') -
FormattedDataclass: data containing the modalities corresponding to the attributes initialized for RosDatahandler
"""
bags = glob.glob(join(root, f'*{name}*.bag'))
assert len(bags) > 0, f"I expected bags! No bags on {root}"
bags = [Path(bag) for bag in bags]
print(f'Containing # {len(bags)} rosbags, named {name}')
init_frankadata = FrankaData(jt = JointPosData('/joint_states', JointState),
force = ForceData('/franka_state_controller/F_ext', WrenchStamped))
bags = glob.glob(join(root, f"*{name}*"))
assert len(bags) > 0, f"I expected bags! No bags on {root} with filter {name}"
# @kevin 2.7.24: data is now a list of Trajs; leaves are np arrays for each bags
data = get_bags_iter(bags, init_frankadata, sync_stream='force')
bags = [Path(bag) for bag in bags]
logging.info(f'Containing # {len(bags)} rosbags for filter {name}')
data = get_bags_iter(bags, datatype, sync_stream=sync_stream)
data = [d.cast_traj() for d in data]
assert len(data) == len(bags)
assert isinstance(data[0], FrankaData), 'data is somehow not a FrankaData? huh?'
assert isinstance(data[0], SysVarSet), 'data is somehow not a sysvarset? huh?'
assert isinstance(data[0], Traj), 'data should be a traj'
assert not isinstance(data[0], Batch), 'data should not be a batch yet'
return data
Loading