diff --git a/dataclass_creator/franka_dataclass.py b/dataclass_creator/franka_dataclass.py
index 4d4e1df5d58c562a1d450f6c64172ee8a066f010..c3a277a671bc909a4687647c25f652283a60cde6 100644
--- a/dataclass_creator/franka_dataclass.py
+++ b/dataclass_creator/franka_dataclass.py
@@ -8,3 +8,8 @@ from opt_tools import SysVarSet
 class FrankaData(SysVarSet):
     jt: jnp.ndarray = jnp.zeros(7)
     force: jnp.ndarray = jnp.zeros(3)
+
+# NOTE: like FrankaData above, URData needs no @dataclass decorator (this module does not import it)
+class URData(SysVarSet):
+    jt: jnp.ndarray = jnp.zeros(7)
+    force: jnp.ndarray = jnp.zeros(3)
diff --git a/dataclass_creator/utils_dataclass/real_data_gen.py b/dataclass_creator/utils_dataclass/real_data_gen.py
index 29b6600694685a71365b506120cfdc9d0a97b173..6029d305495a32281dda0fb3dbe84001042aa2d8 100644
--- a/dataclass_creator/utils_dataclass/real_data_gen.py
+++ b/dataclass_creator/utils_dataclass/real_data_gen.py
@@ -16,7 +16,7 @@ from typing import Any, Iterable
 from opt_tools import SysVarSet, Traj, Batch, tree_map_relaxed
 from ros_datahandler import get_bags_iter
 from ros_datahandler import JointPosData, ForceData, JointState, WrenchStamped
-from dataclass_creator.franka_dataclass import FrankaData
+from dataclass_creator.franka_dataclass import FrankaData, URData
 
 
 def min_max_normalize_data(data: jnp.ndarray):
@@ -115,15 +115,20 @@ def load_dataset(root='experiments/example_data/train', name: str = 'SUCCESS')
 
     Returns:
         FormattedDataclass: data containing the modalities corresponding to the attributes initialized for RosDatahandler
     """
-
-    bags = glob.glob(join(root, f'*{name}*.bag'))
+    if 'SUCCESS' in name or 'FAILURE' in name:
+        bags = glob.glob(join(root, f'*{name}*'))  # should give the folders which contain metadata.yaml and the .db3 files
+    else:
+        bags = glob.glob(join(root, f'*{name}*.bag'))
     assert len(bags) > 0, f"I expected bags! No bags on {root}"
     bags = [Path(bag) for bag in bags]
     print(f'Containing # {len(bags)} rosbags, named {name}')
 
+    # init_urdata = URData(jt = JointPosData('/joint_states', JointState),
+    #                      force = ForceData('/force_torque_sensor_broadcaster/wrench', WrenchStamped))
+
     init_frankadata = FrankaData(jt = JointPosData('/joint_states', JointState),
-                                 force = ForceData('/franka_state_controller/F_ext', WrenchStamped))
+                                 force = ForceData('/franka_state_controller/F_ext', WrenchStamped))
 
     # @kevin 2.7.24: data is now a list of Trajs; leaves are np arrays for each bags
     data = get_bags_iter(bags, init_frankadata, sync_stream='force')
@@ -144,13 +149,18 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     #curr_dir = Path(os.path.dirname(os.path.abspath(__file__)))
     #root = os.fspath(Path(curr_dir.parent.parent, 'data/test').resolve())
-    root = 'experiments/01_faston_converging/data/with_search_strategy/test'
-    data_anom = load_dataset(root=root, name='FAILURE')
-    data_nom = load_dataset(root=root, name='110_SUCCESS')
-    visualize_anomaly_bags(data_anom[0], data_nom[0], "040_FAILURE")
-    visualize_anomaly_bags(data_anom[1], data_nom[0], "050_FAILURE")
-    visualize_anomaly_bags(data_anom[2], data_nom[0],"214_FAILURE")
-    visualise_bags(data_anom[0], "040_FAILURE")
-    visualise_bags(data_anom[1], "050_FAILURE")
-    visualise_bags(data_anom[2], "214_FAILURE")
+    # root = 'experiments/01_faston_converging/data/with_search_strategy/test'
+    # root = 'experiments/03_kraftverlaeufe_klipsen'
+    # data = load_dataset(root=root, name='rosbag')
+
+
+    # data_anom = load_dataset(root=root, name='FAILURE')
+    # data_nom = load_dataset(root=root, name='110_SUCCESS')
+
+    # visualize_anomaly_bags(data_anom[0], data_nom[0], "040_FAILURE")
+    # visualize_anomaly_bags(data_anom[1], data_nom[0], "050_FAILURE")
+    # visualize_anomaly_bags(data_anom[2], data_nom[0], "214_FAILURE")
+    # visualise_bags(data_anom[0], "040_FAILURE")
+    # visualise_bags(data_anom[1], "050_FAILURE")
+    # visualise_bags(data_anom[2], "214_FAILURE")
     plt.show()
diff --git a/model/__init__.py b/model/__init__.py
index 4910b8ddc274e18380f3c794d0c390ebc1a9ec0f..10e08a6ea350f5916ba73a03bfd70b8d358fedd8 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -1,2 +1,3 @@
+from .conv_ae import ConvAE
 from .autoencoder import AutoEncoder
 from .trainer import Trainer, TrainerConfig
diff --git a/model/conv_ae.py b/model/conv_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f6952fa917de8b3d1a221880c9d706e2e1177e5
--- /dev/null
+++ b/model/conv_ae.py
@@ -0,0 +1,96 @@
+import jax.numpy as jnp
+
+from jax.flatten_util import ravel_pytree
+from flax import linen as nn
+
+from opt_tools import SysVarSet
+
+# Check some aspects -->
+# Do we really train with all the data? batch_size is currently not initialized (compare the dims).
+# Is the num of features correct? 9 / 12 ?
+
+class Encoder(nn.Module):
+    latent_dim: int  # z latent dimension
+    T: int  # Time window size
+    d_x: int  # Input feature dimension
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
+        # Reshape input to include a channel dimension for Conv2D (no batch dimension for now, not sure if required)
+        x = x.reshape((self.T, self.d_x, 1))  # (T, d_x, channels)
+        print("Shape after Reshape:", x.shape)
+
+        # Conv2D layers
+        x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        x = nn.selu(x)
+        print("Shape after Conv1:", x.shape)
+        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        x = nn.selu(x)
+        print("Shape after Conv2:", x.shape)
+
+        # Flatten the output
+        x = x.reshape((-1,))
+        print("Shape after Flatten:", x.shape)
+
+        # Dense layers for latent variables
+        mu_z = nn.Dense(features=self.latent_dim)(x)
+        sigma_z = nn.Dense(features=self.latent_dim)(x)
+
+        return mu_z, sigma_z
+
+
+class Decoder(nn.Module):
+    latent_dim: int
+    T: int
+    d_x: int
+
+    @nn.compact
+    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
+        # Dense layer to expand latent vector to match flattened Conv2D feature map
+        x = nn.Dense(features=self.T * self.d_x * 32)(z)
+        x = nn.selu(x)
+        print("Shape after Dense:", x.shape)
+
+        # Reshape to Conv2D input format
+        x = x.reshape((self.T, self.d_x, 32))
+        print("Shape after Reshape:", x.shape)
+
+        # Transposed Conv2D layers to reconstruct the input
+        x = nn.ConvTranspose(features=16, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        x = nn.selu(x)
+        print("Shape after ConvTranspose1:", x.shape)
+        x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        print("Shape after ConvTranspose2:", x.shape)
+
+        # Reshape back to time series format (T, d_x)
+        x = x.reshape((self.T, self.d_x))
+        print("Shape after Reshape to Output:", x.shape)
+
+        return x
+
+
+class ConvAE(nn.Module):
+    latent_dim: int
+    T: int
+    d_x: int
+
+    def setup(self):
+        self.encoder = Encoder(latent_dim=self.latent_dim, T=self.T, d_x=self.d_x)
+        self.decoder = Decoder(latent_dim=self.latent_dim, T=self.T, d_x=self.d_x)
+
+    def __call__(self, x: SysVarSet) -> SysVarSet:
+        array, unravel_fn = ravel_pytree(x)  # ravel_pytree returns (flat_array, unravel_fn)
+        array = array.reshape((self.T, self.d_x))  # Reshape input to (T, d_x)
+        print("Input shape:", array.shape)
+
+        mu_z, sigma_z = self.encoder(array)
+        z = mu_z  # Use the mean latent vector for simplicity
+
+        x_hat = self.decoder(z)
+        print("Output shape before reconstruction:", x_hat.shape)
+
+        # Flatten back and unravel into the original SysVarSet structure
+        x_hat_flat = x_hat.reshape((-1,))
+        x_hat = unravel_fn(x_hat_flat)
+
+        return x_hat
\ No newline at end of file
diff --git a/model/trainer.py b/model/trainer.py
index 621c18f079310c054434013989515899c848fcd4..527de0ad161fc1cba6ddb9a1fb8d9125a4b90e16 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -4,6 +4,7 @@ from pathlib import Path
 from dataclasses import dataclass
 from functools import partial
 import logging
+import matplotlib.pyplot as plt
 
 import optax
 #import tensorflow as tf
@@ -15,6 +16,7 @@ from tqdm.auto import tqdm
 from .autoencoder import AutoEncoder
 from run.parameter import LoggerParams, OptimizerHParams
+from utils.visualize_fn import plot_encoder_decoder_layers
 from dataclass_creator.dyn_slice_dataloader import DynSliceDataloader
 from utils.loss_calculation import LossCalculator, ReconstructionLoss, PredictionLoss
 from opt_tools import SysVarSet, Batch
 
@@ -38,7 +40,6 @@ class TrainerConfig:
     n_epochs: int = 50
     check_val_every_n_epoch: int = 1
 
-
 class Trainer:
     """
     Train the model to the dataset, using optimizer and logger settings from config.
@@ -87,6 +88,9 @@ class Trainer:
                                             n_steps_per_epoch)
 
         params = config.model.init(jax.random.PRNGKey(config.seed), init_vec)
+
+        # plot_encoder_decoder_layers(params, headline="Encoder and Decoder Weights (Initial)")
+
         return TrainState.create(apply_fn=config.model.apply,
                                  params=params,
                                  tx=optimizer)
@@ -157,6 +161,8 @@ class Trainer:
 
             if i % self.check_val_every_n_epoch == 0:
                 print(f"Training step {i}: {loss_array[i-1]}")
 
+        # plot_encoder_decoder_layers(self.state.params, headline="Encoder and Decoder Weights (After Training)")
+
         return self.state.params, loss_array
 
diff --git a/run/train_model.py b/run/train_model.py
index 94a87854d524ca7c5e10b26c784dcd401379ba4f..5e7ea527dc6c5447d5dbd9b75c8b1b083dfdfffd 100644
--- a/run/train_model.py
+++ b/run/train_model.py
@@ -1,12 +1,13 @@
 import os
 import jax
+import matplotlib.pyplot as plt
 
 from itertools import product
 from pathlib import Path
 
 from dataclass_creator.dyn_slice_dataloader import DynSliceDataloader
 from dataclass_creator.utils_dataclass.real_data_gen import load_dataset
-from model import AutoEncoder, Trainer, TrainerConfig
+from model import AutoEncoder, Trainer, TrainerConfig, ConvAE
 from run.parameter import LoggerParams, AEParams, OptimizerHParams, RepoPaths
 from run.load_save_model import to_dill
 
@@ -22,9 +23,11 @@ def train_model(train_loader: DynSliceDataloader,
                 ):
 
     # --------- initialize autoencoder ---------------------------------------------------------------------------------
     ae = AutoEncoder(AEParams.c_hid, AEParams.bottleneck_size, train_loader.out_size)
-
+    # ae = ConvAE(AEParams.bottleneck_size, 100, 12)
+
+
 
     # --------- initialize trainer -------------------------------------------------------------------------------------
     config = TrainerConfig(
         purpose=purpose,
diff --git a/utils/visualize_fn.py b/utils/visualize_fn.py
index 7fcb429ca444d6cbd77a6343330784114e573973..ca8b23d217fe9ef5ecface9641e5c150e4f4a165 100644
--- a/utils/visualize_fn.py
+++ b/utils/visualize_fn.py
@@ -3,11 +3,13 @@ import jax.numpy as jnp
 import jax
 
 # Types
+from flax.core import unfreeze
 from typing import Iterable
 from jax import Array
 
 from opt_tools import Batch
 
+
 def plot_loss(loss_array: jnp.ndarray,
               title: str,
               x_label: str,
@@ -156,3 +158,41 @@ def plot_ROC_curve(tpr: list, fpr: list, auc: float):
     plt.ylim([0, 1])
     plt.ylabel('True Positive Rate')
     plt.xlabel('False Positive Rate')
+
+def plot_encoder_decoder_layers(params, headline="Encoder and Decoder Layer Weights"):
+    # NOTE: imshow expects 2D kernels (Dense layers); 4D Conv kernels would need reshaping first
+    encoder_weights = unfreeze(params)['params']['encoder']
+    decoder_weights = unfreeze(params)['params']['decoder']
+
+    num_encoder_layers = len([key for key in encoder_weights.keys() if 'kernel' in encoder_weights[key]])
+    num_decoder_layers = len([key for key in decoder_weights.keys() if 'kernel' in decoder_weights[key]])
+
+    num_columns = max(num_encoder_layers, num_decoder_layers)
+
+    fig, axes = plt.subplots(2, num_columns, figsize=(5 * num_columns, 10))
+    fig.suptitle(headline, fontsize=16)
+
+    # Plot encoder layers
+    for i, (layer_name, layer_params) in enumerate(encoder_weights.items()):
+        if 'kernel' in layer_params:
+            weights = layer_params['kernel']
+            ax = axes[0, i]
+            im = ax.imshow(weights, cmap='viridis', aspect='auto')
+            ax.set_title(f'Encoder: {layer_name}')
+            ax.set_xlabel('Neurons (Output Features)')
+            ax.set_ylabel('Input Features')
+            fig.colorbar(im, ax=ax)
+
+    # Plot decoder layers
+    for j, (layer_name, layer_params) in enumerate(decoder_weights.items()):
+        if 'kernel' in layer_params:
+            weights = layer_params['kernel']
+            ax = axes[1, j]
+            im = ax.imshow(weights, cmap='viridis', aspect='auto')
+            ax.set_title(f'Decoder: {layer_name}')
+            ax.set_xlabel('Neurons (Output Features)')
+            ax.set_ylabel('Input Features')
+            fig.colorbar(im, ax=ax)
+
+    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to fit the title
+    plt.show()
\ No newline at end of file
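
---

Note on the load_dataset() change: for names containing 'SUCCESS' or 'FAILURE' the
glob now matches rosbag2 recording folders (metadata.yaml plus *.db3 files) instead
of classic *.bag files. A minimal call, sketched with the default root from the
function signature (paths are illustrative, not part of the patch):

    from dataclass_creator.utils_dataclass.real_data_gen import load_dataset

    data_nom = load_dataset(root='experiments/example_data/train', name='SUCCESS')
    data_anom = load_dataset(root='experiments/example_data/train', name='FAILURE')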
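
ConvAE ravels a SysVarSet window to a flat vector, reshapes it to (T, d_x), and runs
it unbatched through the Conv encoder/decoder. A minimal smoke-test sketch, assuming
T=100 and d_x=12 as in the commented-out ConvAE call in run/train_model.py
(latent_dim=8 and the zero input are placeholders):

    import jax
    import jax.numpy as jnp

    from model import ConvAE

    T, d_x = 100, 12
    model = ConvAE(latent_dim=8, T=T, d_x=d_x)

    x = jnp.zeros(T * d_x)                         # stand-in for one ravelled window
    params = model.init(jax.random.PRNGKey(0), x)  # triggers the shape prints once
    x_hat = model.apply(params, x)                 # same flat shape as the input x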
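
plot_encoder_decoder_layers() indexes params['params']['encoder'] and ['decoder'],
so it fits models with top-level encoder/decoder submodules and 2D Dense kernels,
as in the commented-out calls in model/trainer.py. A usage sketch, assuming the
AutoEncoder constructor arguments mirror run/train_model.py (all values are
placeholders):

    import jax
    import jax.numpy as jnp

    from model import AutoEncoder
    from utils.visualize_fn import plot_encoder_decoder_layers

    ae = AutoEncoder(32, 8, 64)  # c_hid, bottleneck_size, out_size
    params = ae.init(jax.random.PRNGKey(0), jnp.zeros(64))
    plot_encoder_decoder_layers(params, headline="Encoder and Decoder Weights (Initial)")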