diff --git a/dataclass_creator/franka_dataclass.py b/dataclass_creator/franka_dataclass.py
index 4d4e1df5d58c562a1d450f6c64172ee8a066f010..c3a277a671bc909a4687647c25f652283a60cde6 100644
--- a/dataclass_creator/franka_dataclass.py
+++ b/dataclass_creator/franka_dataclass.py
@@ -8,3 +8,10 @@ from opt_tools import SysVarSet
 class FrankaData(SysVarSet):
     jt: jnp.ndarray = jnp.zeros(7)
     force: jnp.ndarray = jnp.zeros(3)
+
+
+@dataclass
+class URData(SysVarSet):
+    # NOTE: UR arms typically have 6 joints; check that the 7-dim default is intended here.
+    jt: jnp.ndarray = jnp.zeros(7)
+    force: jnp.ndarray = jnp.zeros(3)
diff --git a/dataclass_creator/utils_dataclass/real_data_gen.py b/dataclass_creator/utils_dataclass/real_data_gen.py
index 29b6600694685a71365b506120cfdc9d0a97b173..6029d305495a32281dda0fb3dbe84001042aa2d8 100644
--- a/dataclass_creator/utils_dataclass/real_data_gen.py
+++ b/dataclass_creator/utils_dataclass/real_data_gen.py
@@ -16,7 +16,7 @@ from typing import Any, Iterable
 from opt_tools  import SysVarSet, Traj, Batch, tree_map_relaxed
 from ros_datahandler import get_bags_iter
 from ros_datahandler import JointPosData, ForceData, JointState, WrenchStamped
-from dataclass_creator.franka_dataclass import FrankaData
+from dataclass_creator.franka_dataclass import FrankaData, URData
 
 
 def min_max_normalize_data(data: jnp.ndarray):
@@ -115,15 +115,20 @@ def load_dataset(root='experiments/example_data/train', name: str = 'SUCCESS') -
     Returns:
         FormattedDataclass: data containing the modalities corresponding to the attributes initialized for RosDatahandler
     """
-
-    bags = glob.glob(join(root, f'*{name}*.bag'))
+    if 'SUCCESS' in name or 'FAILURE' in name:
+        bags = glob.glob(join(root, f'*{name}*'))  # ROS 2 bags: directories containing metadata.yaml and .db3 files
+    else:
+        bags = glob.glob(join(root, f'*{name}*.bag'))  # ROS 1 bags: single .bag files
     assert len(bags) > 0, f"I expected bags! No bags on {root}"
     bags = [Path(bag) for bag in bags]
 
     print(f'Containing # {len(bags)} rosbags, named {name}')
 
+    # init_urdata = URData(jt = JointPosData('/joint_states', JointState),
+    #                      force = ForceData('/force_torque_sensor_broadcaster/wrench', WrenchStamped))
+
     init_frankadata = FrankaData(jt = JointPosData('/joint_states', JointState),
                                  force = ForceData('/franka_state_controller/F_ext', WrenchStamped))
 
     # @kevin 2.7.24: data is now a list of Trajs; leaves are np arrays for each bags
     data = get_bags_iter(bags, init_frankadata, sync_stream='force')
@@ -144,13 +149,18 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     #curr_dir = Path(os.path.dirname(os.path.abspath(__file__)))
     #root = os.fspath(Path(curr_dir.parent.parent, 'data/test').resolve())
-    root = 'experiments/01_faston_converging/data/with_search_strategy/test'
-    data_anom = load_dataset(root=root, name='FAILURE')
-    data_nom = load_dataset(root=root, name='110_SUCCESS')
-    visualize_anomaly_bags(data_anom[0], data_nom[0],  "040_FAILURE")
-    visualize_anomaly_bags(data_anom[1], data_nom[0], "050_FAILURE")
-    visualize_anomaly_bags(data_anom[2], data_nom[0],"214_FAILURE")
-    visualise_bags(data_anom[0], "040_FAILURE")
-    visualise_bags(data_anom[1], "050_FAILURE")
-    visualise_bags(data_anom[2], "214_FAILURE")
+    # root = 'experiments/01_faston_converging/data/with_search_strategy/test'
+    # root = 'experiments/03_kraftverlaeufe_klipsen'
+    # data = load_dataset(root=root, name='rosbag')
+
+
+    # data_anom = load_dataset(root=root, name='FAILURE')
+    # data_nom = load_dataset(root=root, name='110_SUCCESS')
+
+    # visualize_anomaly_bags(data_anom[0], data_nom[0],  "040_FAILURE")
+    # visualize_anomaly_bags(data_anom[1], data_nom[0], "050_FAILURE")
+    # visualize_anomaly_bags(data_anom[2], data_nom[0],"214_FAILURE")
+    # visualise_bags(data_anom[0], "040_FAILURE")
+    # visualise_bags(data_anom[1], "050_FAILURE")
+    # visualise_bags(data_anom[2], "214_FAILURE")
     plt.show()
diff --git a/model/__init__.py b/model/__init__.py
index 4910b8ddc274e18380f3c794d0c390ebc1a9ec0f..10e08a6ea350f5916ba73a03bfd70b8d358fedd8 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -1,2 +1,3 @@
+from .conv_ae import ConvAE
 from .autoencoder import AutoEncoder
 from .trainer import Trainer, TrainerConfig
diff --git a/model/conv_ae.py b/model/conv_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f6952fa917de8b3d1a221880c9d706e2e1177e5
--- /dev/null
+++ b/model/conv_ae.py
@@ -0,0 +1,112 @@
+import jax.numpy as jnp
+
+from jax.flatten_util import ravel_pytree
+from flax import linen as nn
+
+from opt_tools import SysVarSet
+
+# Open questions:
+# Do we really train with all the data? batch_size is currently not initialized (compare the dims).
+# Is the number of features correct: 9 or 12?
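+# A batching sketch (an assumption, nothing below depends on it): apply the model per
+# window and let JAX vectorize over a leading batch axis, e.g.
+#   jax.vmap(lambda xi: model.apply(params, xi))(batch)  # batch: (B, T * d_x)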
+
+class Encoder(nn.Module):
+    latent_dim: int  # z latent dimension
+    T: int           # Time window size
+    d_x: int         # Input feature dimension
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
+        # Add a channel dimension for Conv2D. No batch dimension here: flax.linen.Conv accepts unbatched input and handles the batch axis internally.
+        x = x.reshape((self.T, self.d_x, 1))  # (T, d_x, channels)
+        print("Shape after Reshape:", x.shape)
+
+        # Conv2D layers
+        x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        x = nn.selu(x)
+        print("Shape after Conv1:", x.shape)
+        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        x = nn.selu(x)
+        print("Shape after Conv2:", x.shape)
+
+        # Flatten the output
+        x = x.reshape((-1,))
+        print("Shape after Flatten:", x.shape)
+
+        # Dense layers for latent variables
+        mu_z = nn.Dense(features=self.latent_dim)(x)
+        sigma_z = nn.Dense(features=self.latent_dim)(x)
+
+        return mu_z, sigma_z
+
+
+class Decoder(nn.Module):
+    latent_dim: int
+    T: int          
+    d_x: int
+
+    @nn.compact
+    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
+        # Dense layer to expand latent vector to match flattened Conv2D feature map
+        x = nn.Dense(features=self.T * self.d_x * 32)(z)
+        x = nn.selu(x)
+        print("Shape after Dense:", x.shape)
+
+        # Reshape to Conv2D input format
+        x = x.reshape((self.T, self.d_x, 32))
+        print("Shape after Reshape:", x.shape)
+
+        # Transposed Conv2D layers to reconstruct the input
+        x = nn.ConvTranspose(features=16, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        x = nn.selu(x)
+        print("Shape after ConvTranspose1:", x.shape)
+        x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
+        print("Shape after ConvTranspose2:", x.shape)
+
+        # Reshape back to time series format (T, d_x)
+        x = x.reshape((self.T, self.d_x))
+        print("Shape after Reshape to Output:", x.shape)
+
+        return x
+
+
+class ConvAE(nn.Module):
+    latent_dim: int
+    T: int
+    d_x: int
+
+    def setup(self):
+        self.encoder = Encoder(latent_dim=self.latent_dim, T=self.T, d_x=self.d_x)
+        self.decoder = Decoder(latent_dim=self.latent_dim, T=self.T, d_x=self.d_x)
+
+    def __call__(self, x: SysVarSet) -> SysVarSet:
+        array, unravel_fn = ravel_pytree(x)  # ravel_pytree returns (flat_array, unravel_fn)
+        array = array.reshape((self.T, self.d_x))  # Reshape input to (T, d_x)
+        print("Input shape:", array.shape)
+
+        mu_z, sigma_z = self.encoder(array)
+        z = mu_z  # Use the mean latent vector for simplicity (sigma_z is unused; no sampling/reparameterization yet)
+
+        x_hat = self.decoder(z)
+        print("Output shape before reconstruction:", x_hat.shape)
+
+        # Flatten back and unravel into the original SysVarSet structure
+        x_hat_flat = x_hat.reshape((-1,))
+        x_hat = unravel_fn(x_hat_flat)
+
+        return x_hat
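+
+
+if __name__ == "__main__":
+    # Minimal smoke test; a sketch under assumptions, not part of the training
+    # pipeline. T=100 and d_x=12 mirror the commented-out ConvAE call in
+    # run/train_model.py; a flat dummy array stands in for a ravelled SysVarSet.
+    import jax
+
+    model = ConvAE(latent_dim=8, T=100, d_x=12)
+    dummy = jnp.zeros(100 * 12)
+    params = model.init(jax.random.PRNGKey(0), dummy)
+    x_hat = model.apply(params, dummy)
+    assert x_hat.shape == dummy.shape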
diff --git a/model/trainer.py b/model/trainer.py
index 621c18f079310c054434013989515899c848fcd4..527de0ad161fc1cba6ddb9a1fb8d9125a4b90e16 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -15,6 +15,7 @@ from tqdm.auto import tqdm
 
 from .autoencoder import AutoEncoder
 from run.parameter import LoggerParams, OptimizerHParams
+from utils.visualize_fn import plot_encoder_decoder_layers
 from dataclass_creator.dyn_slice_dataloader import DynSliceDataloader
 from utils.loss_calculation import LossCalculator, ReconstructionLoss, PredictionLoss
 from opt_tools import SysVarSet, Batch
@@ -38,7 +39,7 @@ class TrainerConfig:
     n_epochs: int = 50
     check_val_every_n_epoch: int = 1
 
 
 class Trainer:
     """
     Train the model to the dataset, using optimizer and logger settings from config.
@@ -87,6 +88,9 @@ class Trainer:
                                                     n_steps_per_epoch)
 
         params = config.model.init(jax.random.PRNGKey(config.seed), init_vec)
+
+        # plot_encoder_decoder_layers(params, headline="Encoder and Decoder Weights (Initial)")
+
         return TrainState.create(apply_fn=config.model.apply,
                                  params=params,
                                  tx=optimizer)
@@ -157,6 +161,8 @@ class Trainer:
             if i % self.check_val_every_n_epoch == 0:
                 print(f"Training step {i}: {loss_array[i-1]}")
 
+        # plot_encoder_decoder_layers(self.state.params, headline="Encoder and Decoder Weights (After Training)")
+
         return self.state.params, loss_array
         
         
diff --git a/run/train_model.py b/run/train_model.py
index 94a87854d524ca7c5e10b26c784dcd401379ba4f..5e7ea527dc6c5447d5dbd9b75c8b1b083dfdfffd 100644
--- a/run/train_model.py
+++ b/run/train_model.py
@@ -1,12 +1,12 @@
 import os
 import jax
 
 from itertools import product
 from pathlib import Path
 
 from dataclass_creator.dyn_slice_dataloader import DynSliceDataloader
 from dataclass_creator.utils_dataclass.real_data_gen import load_dataset
-from model import AutoEncoder, Trainer, TrainerConfig
+from model import AutoEncoder, Trainer, TrainerConfig, ConvAE
 from run.parameter import LoggerParams, AEParams, OptimizerHParams, RepoPaths
 from run.load_save_model import to_dill
 
@@ -22,9 +22,12 @@ def train_model(train_loader: DynSliceDataloader,
                 ):
 
     # --------- initialize autoencoder ---------------------------------------------------------------------------------
-
     ae = AutoEncoder(AEParams.c_hid, AEParams.bottleneck_size, train_loader.out_size)
 
+    # ae = ConvAE(AEParams.bottleneck_size, 100, 12)
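+    # ConvAE args (latent_dim, T, d_x): bottleneck size, window length, feature dim.
+    # Assumption to verify before enabling: train_loader.out_size must equal T * d_x.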
+
     # --------- initialize trainer -------------------------------------------------------------------------------------
     config = TrainerConfig(
         purpose=purpose,
diff --git a/utils/visualize_fn.py b/utils/visualize_fn.py
index 7fcb429ca444d6cbd77a6343330784114e573973..ca8b23d217fe9ef5ecface9641e5c150e4f4a165 100644
--- a/utils/visualize_fn.py
+++ b/utils/visualize_fn.py
@@ -3,11 +3,12 @@ import jax.numpy as jnp
 import jax
 
 # Types
+from flax.core import unfreeze
 from typing import Iterable
 from jax import Array
 from opt_tools import Batch
 
 
 def plot_loss(loss_array: jnp.ndarray,
               title: str,
               x_label: str,
@@ -156,3 +157,44 @@ def plot_ROC_curve(tpr: list, fpr: list, auc: float):
     plt.ylim([0, 1])
     plt.ylabel('True Positive Rate')
     plt.xlabel('False Positive Rate')
+
+def plot_encoder_decoder_layers(params, headline="Encoder and Decoder Layer Weights"):
+    """Plot the kernel of each encoder/decoder layer as a heatmap.
+
+    Assumes 2-D (Dense) kernels; 4-D Conv kernels would need reshaping before imshow.
+    """
+    encoder_weights = unfreeze(params)['params']['encoder']
+    decoder_weights = unfreeze(params)['params']['decoder']
+
+    num_encoder_layers = len([key for key in encoder_weights.keys() if 'kernel' in encoder_weights[key]])
+    num_decoder_layers = len([key for key in decoder_weights.keys() if 'kernel' in decoder_weights[key]])
+
+    num_columns = max(num_encoder_layers, num_decoder_layers)
+
+    fig, axes = plt.subplots(2, num_columns, figsize=(5 * num_columns, 10), squeeze=False)  # squeeze=False keeps axes 2-D when num_columns == 1
+    fig.suptitle(headline, fontsize=16)
+
+    # Plot encoder layers
+    for i, (layer_name, layer_params) in enumerate(encoder_weights.items()):
+        if 'kernel' in layer_params:  # renamed from `params` to avoid shadowing the function argument
+            weights = layer_params['kernel']
+            ax = axes[0, i]
+            im = ax.imshow(weights, cmap='viridis', aspect='auto')
+            ax.set_title(f'Encoder: {layer_name}')
+            ax.set_xlabel('Neurons (Output Features)')
+            ax.set_ylabel('Input Features')
+            fig.colorbar(im, ax=ax)
+            
+    # Plot decoder layers
+    for j, (layer_name, layer_params) in enumerate(decoder_weights.items()):
+        if 'kernel' in layer_params:
+            weights = layer_params['kernel']
+            ax = axes[1, j]
+            im = ax.imshow(weights, cmap='viridis', aspect='auto')
+            ax.set_title(f'Decoder: {layer_name}')
+            ax.set_xlabel('Neurons (Output Features)')
+            ax.set_ylabel('Input Features')
+            fig.colorbar(im, ax=ax)
+
+    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to fit the title
+    plt.show()