diff --git a/dataclass_creator/utils_dataclass/real_data_gen.py b/dataclass_creator/utils_dataclass/real_data_gen.py
index 29b6600694685a71365b506120cfdc9d0a97b173..b362c843cc77742f714fd4ad3f1638dd3527fa1d 100644
--- a/dataclass_creator/utils_dataclass/real_data_gen.py
+++ b/dataclass_creator/utils_dataclass/real_data_gen.py
@@ -144,13 +144,7 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     #curr_dir = Path(os.path.dirname(os.path.abspath(__file__)))
     #root = os.fspath(Path(curr_dir.parent.parent, 'data/test').resolve())
-    root = 'experiments/01_faston_converging/data/with_search_strategy/test'
-    data_anom = load_dataset(root=root, name='FAILURE')
-    data_nom = load_dataset(root=root, name='110_SUCCESS')
-    visualize_anomaly_bags(data_anom[0], data_nom[0],  "040_FAILURE")
-    visualize_anomaly_bags(data_anom[1], data_nom[0], "050_FAILURE")
-    visualize_anomaly_bags(data_anom[2], data_nom[0],"214_FAILURE")
-    visualise_bags(data_anom[0], "040_FAILURE")
-    visualise_bags(data_anom[1], "050_FAILURE")
-    visualise_bags(data_anom[2], "214_FAILURE")
+    root = 'experiments/01_faston_converging/data/with_search_strategy/new_train_data'
+    data = load_dataset(root=root, name='SUCCESS_1')
+    visualise_bags(data[0], "SUCCESS_1")
     plt.show()
diff --git a/experiments/01_faston_converging/cable_insert.yaml b/experiments/01_faston_converging/cable_insert.yaml
index c367685c416d17d62a8aac2ca3c56a76ddc547fc..d320d39ddde2fc132f133e5ae3e40b5b200465fb 100644
--- a/experiments/01_faston_converging/cable_insert.yaml
+++ b/experiments/01_faston_converging/cable_insert.yaml
@@ -52,22 +52,39 @@ approach_cable_grasp_pose:
   - -0.47043495040290134
 grasp_cable:
   pose:
-  - 0.48940379215784663
-  - 0.2933004241716297
-  - 0.006126541950159006
-  - 0.7138474046364929
-  - 0.6996129957181882
-  - -0.02948294706752113
-  - 0.009710558592539402
+  - 0.4826879115012523
+  - 0.29947575674521476
+  - 0.005842341083710317
+  - 0.7149426825384871
+  - 0.6987647735477951
+  - -0.011155675090129779
+  - 0.021454669576905745
   joints:
-  - 1.027455213042836
-  - 0.7278512652033298
-  - -0.4511681262404776
-  - -2.030587917329981
-  - 0.5263882677981989
-  - 2.584431212981766
-  - -0.5594717313261109
+  - 0.6776813118403903
+  - 0.6438322385959123
+  - -0.10709574380422854
+  - -2.093987303803397
+  - 0.0439472730841957
+  - 2.7205313886315206
+  - -0.21129641588076004
   gripper: 0.0
+up_cable:
+  pose:
+  - 0.4818379247707044
+  - 0.3001726937630979
+  - 0.036077533601895584
+  - 0.7180804817530207
+  - 0.695682210565033
+  - -0.0006025150732353679
+  - 0.019655040175116118
+  joints:
+  - 0.6776552141724952
+  - 0.5647839719885451
+  - -0.10776315144603148
+  - -2.133596273117624
+  - 0.04394766067097792
+  - 2.7013461448420375
+  - -0.20672177773000044
 retract:
   pose:
   - 0.48902761437621317
@@ -87,87 +104,71 @@ retract:
   - -0.5966199124958602
 approach_pin:
   pose:
-  - 0.503204932767814
-  - 0.18344931962619784
-  - 0.07611620811279404
-  - 0.02536483565897197
-  - 0.999280347869091
-  - 0.010536075855952514
-  - 0.026161089048040268
+  - 0.4949115240627787
+  - 0.1543959742899967
+  - 0.05083832725493995
+  - 0.038325764577640586
+  - 0.9989499401493926
+  - -0.023576509821620607
+  - 0.008619804504469737
   joints:
-  - 0.793868678156873
-  - 0.43361186193178913
-  - -0.4021309203493717
-  - -2.2919933968744184
-  - 0.4040916941025861
-  - 2.7213993197062702
-  - -2.242867455877782
+  - 0.5182812716865103
+  - 0.4435976145560723
+  - -0.18954777847173723
+  - -2.315438278059576
+  - 0.08054654846850062
+  - 2.7491899913358795
+  - -2.0073420958939794
 insert:
   pose:
-  - 0.5021642363693928
-  - 0.18493170055063776
-  - 0.025945862350009202
-  - 0.03703467298417344
-  - 0.9990214405637642
-  - -0.023448636511955834
-  - 0.005895399612154127
+  - 0.4977208359877092
+  - 0.15338349769552737
+  - 0.027148958653066824
+  - 0.0036238208975668033
+  - -0.9993648007097787
+  - 0.035100159936021724
+  - 0.0049841545794879595
   joints:
-  - 0.7981065009488251
-  - 0.5799657608079216
-  - -0.38623049173387863
-  - -2.220131238555052
-  - 0.405322595285676
-  - 2.722419941876007
-  - -2.187715059698592
+  - 0.5204929499336249
+  - 0.516401936254395
+  - -0.19073320814802766
+  - -2.2701458141558173
+  - 0.08062917972846706
+  - 2.744095154752464
+  - -2.085770595120282
 push:
   pose:
-  - 0.5013547709761427
-  - 0.1852945687153377
-  - 0.019663997174158274
-  - 0.03386252087059072
-  - 0.9990237703114663
-  - -0.02812015403999394
-  - 0.003754060430846745
+  - 0.4980202107190786
+  - 0.1533976634253666
+  - 0.01803492100331841
+  - 0.018056243425659542
+  - -0.9991220119330607
+  - 0.036399973786148336
+  - 0.010208783108377755
   joints:
-  - 0.8019034296719748
-  - 0.5978566783946364
-  - -0.38618090190516735
-  - -2.2122215767819005
-  - 0.40530622425727153
-  - 2.7232870630171937
-  - -2.188345344529137
+  - 0.519169196367411
+  - 0.5427766405437066
+  - -0.18821371965541292
+  - -2.2551019047621064
+  - 0.0903627840442835
+  - 2.7450764061240123
+  - -2.12108451506054
   gripper: 0.007
-retract_from_push:
-  pose:
-  - 0.4989785893822214
-  - 0.18580438264442287
-  - 0.1962621341922801
-  - 0.02739044930762703
-  - 0.9992503532602677
-  - -0.022328776472840685
-  - 0.015808875266201387
-  joints:
-  - 0.8195480770575153
-  - 0.14708330806624798
-  - -0.45307220286235117
-  - -2.27542286520298
-  - 0.016285074296123053
-  - 2.419610465832382
-  - -1.9424311684979763
-unplug_cable:
+up:
   pose:
-  - 0.4989795825312483
-  - 0.18580622320881063
-  - 0.1962621525028913
-  - 0.027387833851976018
-  - 0.9992503749963118
-  - -0.022329889332139356
-  - 0.015810460747280566
+  - 0.49785197434972334
+  - 0.14577086834013042
+  - 0.08665672666075616
+  - 0.0031382251923061937
+  - 0.998566478886717
+  - -0.044666015614166193
+  - 0.029327220045228912
   joints:
-  - 0.8195498694019174
-  - 0.1470867623851681
-  - -0.4530719673858672
-  - -2.275421783860349
-  - 0.016281149566510472
-  - 2.419610268455857
-  - -1.9424262078010301
+  - 0.672089091365226
+  - 0.3621044548052303
+  - -0.35064736127663026
+  - -2.3504683522148
+  - 0.04364192547046914
+  - 2.716781466988086
+  - -2.045407101554594
+
diff --git a/experiments/01_faston_converging/exp_cable_insert.py b/experiments/01_faston_converging/exp_cable_insert.py
index 07eb0107fd9f350a4c103264513814f7c358a2e1..dfb0070a46dd2b871e4417a97e5a3c6dfcd8ccc0 100644
--- a/experiments/01_faston_converging/exp_cable_insert.py
+++ b/experiments/01_faston_converging/exp_cable_insert.py
@@ -1,3 +1,4 @@
+import subprocess, shlex, psutil
 import json
 import yaml
 import sys
@@ -263,13 +264,8 @@ def cable_push(rob, waypoints_cable, STEPWISE= False):
     if STEPWISE:
       uinput = input("Press [Enter] to start moving: ")
       if uinput == "q": return
-    print("retract_from_push")
-    move_fb = rob.move_cart(waypoints_cable["retract_from_push"]["pose"])
-    if STEPWISE:
-      uinput = input("Press [Enter] to start moving: ")
-      if uinput == "q": return
-    print("unplug_cable")
-    move_fb = rob.move_cart(waypoints_cable["unplug_cable"]["pose"])
+    print("up")
+    move_fb = rob.move_cart(waypoints_cable["up"]["pose"])
     if STEPWISE:
       uinput = input("Press [Enter] to start moving: ")
       if uinput == "q": return
@@ -296,6 +292,17 @@ def cable_reset(rob, waypoints_cable, STEPWISE= False):
     move_fb = rob.move_joint(waypoints_cable["start_joint"]["joints"])
     print("cable reset [ DONE ]")
     return 
+  
+def start_record(self):
+  print("Spacemouse moved, Start recording!")
+  
+  
+  
+
+def stop_record(self):
+  
+  print("Space mouse stopped, Stop recording!")
+    
 
     
 if __name__ == "__main__":
@@ -304,10 +311,24 @@ if __name__ == "__main__":
     
     waypoints_cable_insert = import_waypoints_yaml('cable_insert.yaml')
     waypoints_cable_push = import_waypoints_yaml('cable_insert.yaml')
-   
     uinput = input("Press [Enter] to insert: ")
-    cable_insert(rob, waypoints_cable_insert, STEPWISE= True)
-    uinput = input("Press [Enter] to push: ")
-    cable_push(rob, waypoints_cable_push, STEPWISE= True)
+    
+    command = "rosbag record /franka_state_controller/joint_states /joint_states /franka_state_controller/F_ext /tf -O" + "/media/ipk410/Samsung_980_1TB/faston_converging_bags/SUCCESS_11.bag"
+    command = shlex.split(command)
+    rosbag_proc = subprocess.Popen(command)
+    
+    cable_insert(rob, waypoints_cable_insert, STEPWISE= False)
+    cable_push(rob, waypoints_cable_push, STEPWISE= False)
+    
+    for proc in psutil.process_iter():
+      if "record" in proc.name() and set(command[2:]).issubset(proc.cmdline()):
+          proc.send_signal(subprocess.signal.SIGINT)
+    rosbag_proc.send_signal(subprocess.signal.SIGINT)
+    print("Recording stopped")
+    try:
+      rosbag_proc.wait(timeout=1)
+
+    except subprocess.TimeoutExpired:
+      rosbag_proc.kill()
     #cable_reset(rob, waypoints_cable_reset)
     
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/model.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/model.dill
new file mode 100644
index 0000000000000000000000000000000000000000..784b804d56e2185e92f8c6ad2ea57735ab55088c
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/model.dill differ
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/threshold.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/threshold.dill
new file mode 100644
index 0000000000000000000000000000000000000000..ea031ee2cfd9df7ccd198dce7eb635a99895c081
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/threshold.dill differ
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/model.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/model.dill
new file mode 100644
index 0000000000000000000000000000000000000000..f6cedfbafb8c129e54b6ab66c9c395fc2f740c61
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/model.dill differ
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/threshold.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/threshold.dill
new file mode 100644
index 0000000000000000000000000000000000000000..d0c391708bffd898ab4817971379cfbfc37ebad4
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/threshold.dill differ
diff --git a/ros_adetector.py b/ros_adetector.py
index e51d79f7d7c3f32657ec13fb7797385a054783d3..fa8aa0563251b417aea9fbf6c7a3b92bfaffaccc 100644
--- a/ros_adetector.py
+++ b/ros_adetector.py
@@ -61,13 +61,14 @@ class AnomalyDetector(RosDatahandler):
          - makes a classification based on anomaly_bound
     
     """
-    def __init__(self, data_streams):
-        self.ae_jit, logger_params = load_trained_model_jit(path_to_models=RepoPaths.trained_models_val, test_name="test_w100_b50_e40")
+    def __init__(self, data_streams, model_name: str):
+        self.ae_jit, logger_params = load_trained_model_jit(path_to_models=RepoPaths.trained_models_val, test_name=model_name)
         super().__init__(data_streams,
-                         sync_stream="jt", # when a message arrives on this topic, all the last msgs are added to the data buffer
+                         sync_stream="force", # when a message arrives on this topic, all the last msgs are added to the data buffer
                          window_size=logger_params.window_size)
-        self.threshold = from_dill(path=os.path.join(RepoPaths.trained_models_val, Path("test_w100_b50_e40")),
+        self.threshold = from_dill(path=os.path.join(RepoPaths.trained_models_val, Path(model_name)),
                                    file_name="threshold.dill")
+        print("Model: ", model_name)
         self.start = True # True means get_data() and False nothing, currently for testing
         self.start_datastreams(ros_version = "ros1") # This sets up subscribers for the data_streams we have in the super().__init__
 
@@ -86,6 +87,7 @@ class AnomalyDetector(RosDatahandler):
                                                  Bool,
                                                  self.start_callback)
         self.est_force_pub = rospy.Publisher('/F_ext_recon', WrenchStamped, queue_size=1)
+        self.force_pub = rospy.Publisher('/F_comp', WrenchStamped, queue_size=1)
         self.anomaly_pub = rospy.Publisher('/anomaly_score', WrenchStamped, queue_size=1)
         
         while not self.start:
@@ -102,6 +104,7 @@ class AnomalyDetector(RosDatahandler):
         if not self.start: return
         if not self.all_streams_recieved(): return
         data = self.get_data()
+        self.force_pub.publish(build_wrench_msg(data.force[-1]))
         recon_data = self.ae_jit(data)
         recon_loss = jax.tree_util.tree_map(lambda w, rw: jnp.mean(jnp.abs(w - rw), axis=0), data, recon_data)
         self.est_force_pub.publish(build_wrench_msg(recon_data.force[-1]))
@@ -116,48 +119,46 @@ class AnomalyDetector(RosDatahandler):
         # -> So no jax.tree_map at this point
         # anomaly_counter = jax.tree_map(lambda val, cut_off: val > cut_off, loss, self.threshold)
 
-        loss = jax.tree_util.tree_leaves(loss)
+        leaves_with_paths = jax.tree_util.tree_leaves_with_path(loss)
+        leaves_loss = [l[1] for l in leaves_with_paths]
+        paths = [l[0] for l in leaves_with_paths]
         cut_off = jax.tree_util.tree_leaves(self.threshold)
 
-        for loss_leaf, cut_leaf in zip(loss, cut_off):
-            if jnp.any(loss_leaf > cut_leaf):
-                print("Anomaly Detected!")
+        for name, leaf_loss, cut_leaf in zip(paths, leaves_loss, cut_off):
+            anomalies = jnp.where(leaf_loss > cut_leaf)
+            if len(anomalies[0]) != 0:
+                print("Anomaly Detected in feature: ", anomalies[0], jax.tree_util.keystr(name))
 
 def start_node():
     # we init with a initialzied FrankaData b/c the default values of the attributes in FrankaData
     # have the right topic names, etc.
-    anomaly_detector = AnomalyDetector(FrankaData(jt = JointPosData('/joint_states', JointState),
-                                                      force = ForceData('/franka_state_controller/F_ext', WrenchStamped)))
+    anomaly_detector = AnomalyDetector(data_streams=FrankaData(jt = JointPosData('/joint_states', JointState),
+                                                      force = ForceData('/franka_state_controller/F_ext', WrenchStamped)),
+                                       model_name="test_w100_b20_e50_l15")
     rospy.spin()
 
 def test(len_test: int=15, window_size: int=10):
-    #roscore = subprocess.Popen(["roscore"])
     try:
-        anomaly_detector = AnomalyDetector(FrankaData(jt = JointPosData('/joint_states', JointState),
-                                                      force = ForceData('/franka_state_controller/F_ext', WrenchStamped)))
+        anomaly_detector = AnomalyDetector(data_streams=FrankaData(jt = JointPosData('/joint_states', JointState),
+                                                      force = ForceData('/franka_state_controller/F_ext', WrenchStamped)),
+                                           model_name="test_w100_b20_e50_l15")
+
         rospub_process = subprocess.Popen( ['rostopic', 'pub', '/start_anomaly', 'std_msgs/Bool', 'True', '--once'] )
-        rosbag_process = subprocess.Popen( ['rosbag', 'play', f'-u {len_test}', 'experiments/01_faston_converging/data/with_search_strategy/test/214_FAILURE_+0,0033_not_inserted.bag'] )
-        #echo_process = subprocess.Popen( ['rostopic', 'echo', 'joint_states/position[0]'] )
-        
+        #rosbag_process = subprocess.Popen( ['rosbag', 'play', f'-u {len_test}', 'experiments/01_faston_converging/data/with_search_strategy/test/214_FAILURE_+0,0033_not_inserted.bag'] )
+
         assert type(anomaly_detector.get_data()) == FrankaData
 
         sleep(len_test)
     finally:
         print("Terminating ROS bag process...")
         rospub_process.send_signal(signal.SIGINT)
-        rosbag_process.send_signal(signal.SIGINT)
-        #echo_process.send_signal(signal.SIGINT)
-        #roscore.send_signal(signal.SIGINT)
+        #rosbag_process.send_signal(signal.SIGINT)
         try:
-            rosbag_process.wait(timeout=1)
+            #rosbag_process.wait(timeout=1)
             rospub_process.wait(timeout=1)
-            #echo_process.wait(timeout=1)
-            #roscore.wait(timeout=1)
         except subprocess.TimeoutExpired:
             rospub_process.kill()
-            rosbag_process.kill()
-            #echo_process.kill()
-            #roscore.kill()
+            #rosbag_process.kill()
 
 
 if __name__ == "__main__":
diff --git a/run/construct_thresholds.py b/run/construct_thresholds.py
index 0976bcf4bf06081db5b0b7fa7b1cc0b6a8208734..9c5951beb8ac09c0ec5b04880b258247fc927d7b 100644
--- a/run/construct_thresholds.py
+++ b/run/construct_thresholds.py
@@ -15,9 +15,9 @@ def construct_threshold(name):
     from dataclass_creator.dyn_slice_dataloader import DynSliceDataloader
     from run.parameter import RepoPaths
 
-    print(f"Validation of model {name}")
+    print(f"Threshold calculation of model {name}")
     # load test bags based on bag_name
-    dataset = load_dataset(root=RepoPaths.threshold, name="010_SUCCESS")
+    dataset = load_dataset(root=RepoPaths.threshold, name="SUCCESS")
 
     ae_params, model_params, logger_params = load_trained_model(path_to_models=RepoPaths.trained_models_val,
                                                                 test_name=name)
@@ -37,16 +37,20 @@ def construct_threshold(name):
     recon_windows = jax.vmap(ae_params.apply, in_axes=(None, 0))(model_params, windows)
     loss = tree_map(lambda w, rw: jnp.mean(jnp.abs(w - rw), axis=1),windows, recon_windows)
     threshold = tree_map(lambda l: jnp.max(l, axis=0), loss)
+    threshold = tree_map(lambda l: l + 0.6*l, threshold)
     thresholds = [tree_map(lambda l: l + i*l, threshold) for i in np.arange(-0.9, 0.95, 0.05).tolist()]
-    test_thresholds = [tree_map(lambda l: l + i*l, threshold) for i in np.arange(-0.9, 0.95, 0.3).tolist()]
-    threshold_test = tree_map(lambda l: l + (-0.5)*l, threshold)
-    to_dill(tuple(thresholds), logger_params.path_to_model, "thresholds_roc.dill")
-    to_dill(tuple(test_thresholds), logger_params.path_to_model, "thresholds_testing.dill")
     to_dill(threshold, logger_params.path_to_model, "threshold.dill")
 
 
 if __name__ == "__main__":
-    models = ["test_w100_b50_e40"]
+    models = ["test_w100_b20_e30_l15",
+              "test_w100_b20_e30_l20",
+              "test_w100_b20_e50_l15",
+              "test_w100_b20_e50_l20",
+              "test_w100_b50_e30_l15",
+              "test_w100_b50_e30_l20",
+              "test_w100_b50_e50_l15",
+              "test_w100_b50_e50_l20"]
 
     for name in models:
         construct_threshold(name)
diff --git a/run/parameter.py b/run/parameter.py
index 3caf9448a6288defdaace90216afb432afddc765..51964c49b29b94a03ec5f7414b23490564deefc8 100644
--- a/run/parameter.py
+++ b/run/parameter.py
@@ -49,12 +49,13 @@ class TerminalColour:
 
 class LoggerParams:
 
-    def __init__(self, window_size, batch_size, epochs, path_to_models):
+    def __init__(self, window_size, batch_size, epochs, latent_dim, path_to_models):
         self.window_size = window_size
         self.batch_size = batch_size
         self.epochs = epochs
+        self.latent_dim = latent_dim
         self.path_to_models = path_to_models
-        self.path_to_model = os.path.join(self.path_to_models, Path(f"test_w{window_size}_b{batch_size}_e{epochs}"))
+        self.path_to_model = os.path.join(self.path_to_models, Path(f"test_w{window_size}_b{batch_size}_e{epochs}_l{latent_dim}"))
         self.time_stamp = datetime.now().strftime("%d_%m_%Y-%H_%M_%S")
 
         logging.basicConfig(level=logging.INFO)
@@ -68,7 +69,7 @@ class LoggerParams:
 @dataclass
 class AEParams:
     c_hid: int = 50  # parameters in hidden layer
-    bottleneck_size: int = 15  # latent_dim
+    bottleneck_size: int = 20  # latent_dim
 
     #@kevin 10.8.24: a major advantage of dataclass is the automatic __init__,
     #                by defining __init__ by hand, this is overwritten so we cant
@@ -85,7 +86,7 @@ class RepoPaths:
     """Path names for 01_faston_converging anomaly detection."""
     example_data_train = Path('experiments/example_data/train')
     example_data_test = Path('experiments/example_data/test')
-    data_train = Path('experiments/01_faston_converging/data/with_search_strategy/train')
-    unseen_data = Path('experiments/01_faston_converging/data/with_search_strategy/test')
-    trained_models_val = Path('experiments/01_faston_converging/trained_models_val/with_search_strategy')
-    threshold = Path('experiments/01_faston_converging/data/with_search_strategy/threshold')
+    data_train = Path('experiments/01_faston_converging/data/cable_insert/train')
+    unseen_data = Path('experiments/01_faston_converging/data/cable_insert/test')
+    trained_models_val = Path('experiments/01_faston_converging/trained_models_val/cable_insert_model')
+    threshold = Path('experiments/01_faston_converging/data/cable_insert/test')
diff --git a/run/train_model.py b/run/train_model.py
index 94a87854d524ca7c5e10b26c784dcd401379ba4f..18baac1737cfe06936a9f3f743c5f8a6f1209fbe 100644
--- a/run/train_model.py
+++ b/run/train_model.py
@@ -23,7 +23,7 @@ def train_model(train_loader: DynSliceDataloader,
 
     # --------- initialize autoencoder ---------------------------------------------------------------------------------
 
-    ae = AutoEncoder(AEParams.c_hid, AEParams.bottleneck_size, train_loader.out_size)
+    ae = AutoEncoder(AEParams.c_hid, logger_params.latent_dim, train_loader.out_size)
 
     # --------- initialize trainer -------------------------------------------------------------------------------------
     config = TrainerConfig(
@@ -43,14 +43,16 @@ def train_model(train_loader: DynSliceDataloader,
     return ae, optimized_params, random_threshold
 
 
-def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, purpose: str):
-    train_data = load_dataset(root=path_to_data, name="SUCCESS")
+def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, latent_dim: int, purpose: str, train_data=None):
+    if not train_data:
+        train_data = load_dataset(root=path_to_data, name="SUCCESS")
     train_loader = DynSliceDataloader(train_data, window_size=window_size, batch_size=batch_size)
     val_loader = DynSliceDataloader(train_data, window_size=window_size, batch_size=batch_size)
 
     logger_params = LoggerParams(window_size=window_size,
                                  batch_size=batch_size,
                                  epochs=epochs,
+                                 latent_dim=latent_dim,
                                  path_to_models=RepoPaths.trained_models_val)
 
     ae, optimized_params, threshold = train_model(train_loader,
@@ -64,7 +66,7 @@ def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, pu
     assert isinstance(optimized_params, dict)
 
 
-def train_loop(path_to_data: Path, window_size: list, batch_size: list, epochs: list, purpose: str):
+def train_loop(path_to_data: Path, window_size: list, batch_size: list, epochs: list, latent_dim:list, purpose: str):
     """
     Train multiple models by setting different window, batch sizes and number of epochs.
 
@@ -74,20 +76,21 @@ def train_loop(path_to_data: Path, window_size: list, batch_size: list, epochs:
           batch_size: List of different batch sizes.
           epochs: List of numbers of epochs the model will be trained for.
     """
-    model_params = product(window_size, batch_size, epochs)
-
+    model_params = product(window_size, batch_size, epochs, latent_dim)
+    train_data = load_dataset(root=path_to_data, name="SUCCESS")
     for hparams in model_params:
-        window_size, batch_size, epochs = hparams
-        train(path_to_data=path_to_data, window_size=window_size, batch_size=batch_size, epochs=epochs, purpose=purpose)
+        window_size, batch_size, epochs, latent_dim = hparams
+        train(path_to_data=path_to_data, window_size=window_size, batch_size=batch_size, epochs=epochs, latent_dim=latent_dim, purpose=purpose, train_data=train_data)
 
 if __name__ == "__main__":
     train_config = {
-        "path_to_data": RepoPaths.example_data_train,
+        "path_to_data": RepoPaths.data_train,
         "window_size": [100],
-        "batch_size": [50],
-        "epochs": [20],
+        "batch_size": [20,50],
+        "epochs": [30,50],
+        "latent_dim": [15,20],
         "purpose": "Reconstruction"
     }
-
+    
     train_loop(**train_config)
     os._exit(1)
diff --git a/run/validate_model.py b/run/validate_model.py
index ab0373bb065ee12ba9b96d57680ad962abc1fbe9..97ecba62e2717370517de8559003adf1d734dae2 100644
--- a/run/validate_model.py
+++ b/run/validate_model.py
@@ -6,10 +6,11 @@ import matplotlib.pyplot as plt
 from dataclass_creator.utils_dataclass.real_data_gen import load_dataset
 from run.load_save_model import from_dill
 from run.valmodel import ValModelRecon, save_multi_image
+from run.parameter import RepoPaths
 
 
 def val_model(model_name: str, bag_key: str, bag_names: list, anomalies: list[dict]):
-    from run.parameter import RepoPaths
+
     logging.basicConfig(level=logging.INFO)
     dataset = load_dataset(root=RepoPaths.unseen_data, name=bag_key)
 
@@ -34,19 +35,28 @@ def val_model(model_name: str, bag_key: str, bag_names: list, anomalies: list[di
             val.visualize_scores_over_batches(threshold_name=str(i), threshold= threshold, anomalies=anomaly)
 
         val.get_roc_over_batches(thresholds=thresholds_roc, anomalies=anomaly)
-        save_multi_image(f"val_{bag_name}_{name}.pdf")
+        save_multi_image(f"val_{bag_name}_{model_name}.pdf")
         plt.close('all')
 
 
+def show_some_plots(model_name: str, bag_name: str):
+    dataset = load_dataset(root=RepoPaths.threshold, name=bag_name)
+    val = ValModelRecon(path_to_models=RepoPaths.trained_models_val,
+                        name_model=model_name,
+                        test_traj=dataset[0],
+                        test_name=bag_name)
+
+    val.visualize_windows()
+    val.visualize_loss()
+    plt.show()
+
+
 if __name__ == "__main__":
-    models = ["test_w10_b50_e20",
-              "test_w50_b50_e20",
-              "test_w100_b50_e20",
-              "test_w500_b50_e20",
-              "test_w1000_b50_e20"]
+    model = ["test_w80_b50_e30_l15"]
+    show_some_plots(model[0], bag_name="SUCCESS")
 
-    for name in models:
+    """for name in model:
         val_model(name,
                   "FAILURE",
                   ["040_FAILURE", "050_FAILURE", "214_FAILURE"],
-                  anomalies=[list(range(236, 280)), list(range(249,288)), list(range(258, 296))])
+                  anomalies=[list(range(236, 280)), list(range(249,288)), list(range(258, 296))])"""