diff --git a/dataclass_creator/utils_dataclass/real_data_gen.py b/dataclass_creator/utils_dataclass/real_data_gen.py
index 29b6600694685a71365b506120cfdc9d0a97b173..b362c843cc77742f714fd4ad3f1638dd3527fa1d 100644
--- a/dataclass_creator/utils_dataclass/real_data_gen.py
+++ b/dataclass_creator/utils_dataclass/real_data_gen.py
@@ -144,13 +144,7 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     #curr_dir = Path(os.path.dirname(os.path.abspath(__file__)))
     #root = os.fspath(Path(curr_dir.parent.parent, 'data/test').resolve())
-    root = 'experiments/01_faston_converging/data/with_search_strategy/test'
-    data_anom = load_dataset(root=root, name='FAILURE')
-    data_nom = load_dataset(root=root, name='110_SUCCESS')
-    visualize_anomaly_bags(data_anom[0], data_nom[0], "040_FAILURE")
-    visualize_anomaly_bags(data_anom[1], data_nom[0], "050_FAILURE")
-    visualize_anomaly_bags(data_anom[2], data_nom[0], "214_FAILURE")
-    visualise_bags(data_anom[0], "040_FAILURE")
-    visualise_bags(data_anom[1], "050_FAILURE")
-    visualise_bags(data_anom[2], "214_FAILURE")
+    root = 'experiments/01_faston_converging/data/with_search_strategy/new_train_data'
+    data = load_dataset(root=root, name='SUCCESS_1')
+    visualise_bags(data[0], "SUCCESS_1")
     plt.show()
diff --git a/experiments/01_faston_converging/cable_insert.yaml b/experiments/01_faston_converging/cable_insert.yaml
index c367685c416d17d62a8aac2ca3c56a76ddc547fc..d320d39ddde2fc132f133e5ae3e40b5b200465fb 100644
--- a/experiments/01_faston_converging/cable_insert.yaml
+++ b/experiments/01_faston_converging/cable_insert.yaml
@@ -52,22 +52,39 @@ approach_cable_grasp_pose:
   - -0.47043495040290134
 grasp_cable:
   pose:
-  - 0.48940379215784663
-  - 0.2933004241716297
-  - 0.006126541950159006
-  - 0.7138474046364929
-  - 0.6996129957181882
-  - -0.02948294706752113
-  - 0.009710558592539402
+  - 0.4826879115012523
+  - 0.29947575674521476
+  - 0.005842341083710317
+  - 0.7149426825384871
+  - 0.6987647735477951
+  - -0.011155675090129779
+  - 0.021454669576905745
   joints:
-  - 1.027455213042836
-  - 0.7278512652033298
-  - -0.4511681262404776
-  - -2.030587917329981
-  - 0.5263882677981989
-  - 2.584431212981766
-  - -0.5594717313261109
+  - 0.6776813118403903
+  - 0.6438322385959123
+  - -0.10709574380422854
+  - -2.093987303803397
+  - 0.0439472730841957
+  - 2.7205313886315206
+  - -0.21129641588076004
   gripper: 0.0
+up_cable:
+  pose:
+  - 0.4818379247707044
+  - 0.3001726937630979
+  - 0.036077533601895584
+  - 0.7180804817530207
+  - 0.695682210565033
+  - -0.0006025150732353679
+  - 0.019655040175116118
+  joints:
+  - 0.6776552141724952
+  - 0.5647839719885451
+  - -0.10776315144603148
+  - -2.133596273117624
+  - 0.04394766067097792
+  - 2.7013461448420375
+  - -0.20672177773000044
 retract:
   pose:
   - 0.48902761437621317
@@ -87,87 +104,71 @@ retract:
   - -0.5966199124958602
 approach_pin:
   pose:
-  - 0.503204932767814
-  - 0.18344931962619784
-  - 0.07611620811279404
-  - 0.02536483565897197
-  - 0.999280347869091
-  - 0.010536075855952514
-  - 0.026161089048040268
+  - 0.4949115240627787
+  - 0.1543959742899967
+  - 0.05083832725493995
+  - 0.038325764577640586
+  - 0.9989499401493926
+  - -0.023576509821620607
+  - 0.008619804504469737
   joints:
-  - 0.793868678156873
-  - 0.43361186193178913
-  - -0.4021309203493717
-  - -2.2919933968744184
-  - 0.4040916941025861
-  - 2.7213993197062702
-  - -2.242867455877782
+  - 0.5182812716865103
+  - 0.4435976145560723
+  - -0.18954777847173723
+  - -2.315438278059576
+  - 0.08054654846850062
+  - 2.7491899913358795
+  - -2.0073420958939794
 insert:
   pose:
-  - 0.5021642363693928
-  - 0.18493170055063776
-  - 0.025945862350009202
-  - 0.03703467298417344
-  - 0.9990214405637642
-  - -0.023448636511955834
-  - 0.005895399612154127
+  - 0.4977208359877092
+  - 0.15338349769552737
+  - 0.027148958653066824
+  - 0.0036238208975668033
+  - -0.9993648007097787
+  - 0.035100159936021724
+  - 0.0049841545794879595
   joints:
-  - 0.7981065009488251
-  - 0.5799657608079216
-  - -0.38623049173387863
-  - -2.220131238555052
-  - 0.405322595285676
-  - 2.722419941876007
-  - -2.187715059698592
+  - 0.5204929499336249
+  - 0.516401936254395
+  - -0.19073320814802766
+  - -2.2701458141558173
+  - 0.08062917972846706
+  - 2.744095154752464
+  - -2.085770595120282
 push:
   pose:
-  - 0.5013547709761427
-  - 0.1852945687153377
-  - 0.019663997174158274
-  - 0.03386252087059072
-  - 0.9990237703114663
-  - -0.02812015403999394
-  - 0.003754060430846745
+  - 0.4980202107190786
+  - 0.1533976634253666
+  - 0.01803492100331841
+  - 0.018056243425659542
+  - -0.9991220119330607
+  - 0.036399973786148336
+  - 0.010208783108377755
   joints:
-  - 0.8019034296719748
-  - 0.5978566783946364
-  - -0.38618090190516735
-  - -2.2122215767819005
-  - 0.40530622425727153
-  - 2.7232870630171937
-  - -2.188345344529137
+  - 0.519169196367411
+  - 0.5427766405437066
+  - -0.18821371965541292
+  - -2.2551019047621064
+  - 0.0903627840442835
+  - 2.7450764061240123
+  - -2.12108451506054
   gripper: 0.007
-retract_from_push:
-  pose:
-  - 0.4989785893822214
-  - 0.18580438264442287
-  - 0.1962621341922801
-  - 0.02739044930762703
-  - 0.9992503532602677
-  - -0.022328776472840685
-  - 0.015808875266201387
-  joints:
-  - 0.8195480770575153
-  - 0.14708330806624798
-  - -0.45307220286235117
-  - -2.27542286520298
-  - 0.016285074296123053
-  - 2.419610465832382
-  - -1.9424311684979763
-unplug_cable:
+up:
   pose:
-  - 0.4989795825312483
-  - 0.18580622320881063
-  - 0.1962621525028913
-  - 0.027387833851976018
-  - 0.9992503749963118
-  - -0.022329889332139356
-  - 0.015810460747280566
+  - 0.49785197434972334
+  - 0.14577086834013042
+  - 0.08665672666075616
+  - 0.0031382251923061937
+  - 0.998566478886717
+  - -0.044666015614166193
+  - 0.029327220045228912
   joints:
-  - 0.8195498694019174
-  - 0.1470867623851681
-  - -0.4530719673858672
-  - -2.275421783860349
-  - 0.016281149566510472
-  - 2.419610268455857
-  - -1.9424262078010301
+  - 0.672089091365226
+  - 0.3621044548052303
+  - -0.35064736127663026
+  - -2.3504683522148
+  - 0.04364192547046914
+  - 2.716781466988086
+  - -2.045407101554594
+
diff --git a/experiments/01_faston_converging/exp_cable_insert.py b/experiments/01_faston_converging/exp_cable_insert.py
index 07eb0107fd9f350a4c103264513814f7c358a2e1..dfb0070a46dd2b871e4417a97e5a3c6dfcd8ccc0 100644
--- a/experiments/01_faston_converging/exp_cable_insert.py
+++ b/experiments/01_faston_converging/exp_cable_insert.py
@@ -1,3 +1,4 @@
+import subprocess, shlex, psutil
 import json
 import yaml
 import sys
@@ -263,13 +264,8 @@ def cable_push(rob, waypoints_cable, STEPWISE= False):
     if STEPWISE:
         uinput = input("Press [Enter] to start moving: ")
         if uinput == "q": return
-    print("retract_from_push")
-    move_fb = rob.move_cart(waypoints_cable["retract_from_push"]["pose"])
-    if STEPWISE:
-        uinput = input("Press [Enter] to start moving: ")
-        if uinput == "q": return
-    print("unplug_cable")
-    move_fb = rob.move_cart(waypoints_cable["unplug_cable"]["pose"])
+    print("up")
+    move_fb = rob.move_cart(waypoints_cable["up"]["pose"])
     if STEPWISE:
         uinput = input("Press [Enter] to start moving: ")
         if uinput == "q": return
@@ -296,6 +292,17 @@ def cable_reset(rob, waypoints_cable, STEPWISE= False):
     move_fb = rob.move_joint(waypoints_cable["start_joint"]["joints"])
     print("cable reset [ DONE ]")
     return
+
+def start_record(self):
+    print("Spacemouse moved, Start recording!")
+
+
+
+
+def stop_record(self):
+
+    print("Space mouse stopped, Stop recording!")
+


 if __name__ == "__main__":
@@ -304,10 +311,24 @@ if __name__ == "__main__":
     waypoints_cable_insert = import_waypoints_yaml('cable_insert.yaml')
     waypoints_cable_push = import_waypoints_yaml('cable_insert.yaml')
-    uinput = input("Press [Enter] to insert: ")
-    cable_insert(rob, waypoints_cable_insert, STEPWISE= True)
-    uinput = input("Press [Enter] to push: ")
-    cable_push(rob, waypoints_cable_push, STEPWISE= True)
+
+    command = "rosbag record /franka_state_controller/joint_states /joint_states /franka_state_controller/F_ext /tf -O" + "/media/ipk410/Samsung_980_1TB/faston_converging_bags/SUCCESS_11.bag"
+    command = shlex.split(command)
+    rosbag_proc = subprocess.Popen(command)
+
+    cable_insert(rob, waypoints_cable_insert, STEPWISE= False)
+    cable_push(rob, waypoints_cable_push, STEPWISE= False)
+
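+    # stop recording: SIGINT any running "record" process that matches our topic list
+    # (found via psutil) so the bag file is closed cleanly, then stop the rosbag parent below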
+    for proc in psutil.process_iter():
+        if "record" in proc.name() and set(command[2:]).issubset(proc.cmdline()):
+            proc.send_signal(subprocess.signal.SIGINT)
+    rosbag_proc.send_signal(subprocess.signal.SIGINT)
+    print("Recording stopped")
+    try:
+        rosbag_proc.wait(timeout=1)
+
+    except subprocess.TimeoutExpired:
+        rosbag_proc.kill()



     #cable_reset(rob, waypoints_cable_reset)
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/model.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/model.dill
new file mode 100644
index 0000000000000000000000000000000000000000..784b804d56e2185e92f8c6ad2ea57735ab55088c
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/model.dill differ
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/threshold.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/threshold.dill
new file mode 100644
index 0000000000000000000000000000000000000000..ea031ee2cfd9df7ccd198dce7eb635a99895c081
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l20/threshold.dill differ
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/model.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/model.dill
new file mode 100644
index 0000000000000000000000000000000000000000..f6cedfbafb8c129e54b6ab66c9c395fc2f740c61
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/model.dill differ
diff --git a/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/threshold.dill b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/threshold.dill
new file mode 100644
index 0000000000000000000000000000000000000000..d0c391708bffd898ab4817971379cfbfc37ebad4
Binary files /dev/null and b/experiments/01_faston_converging/trained_models_val/with_search_strategy/test_w100_b50_e40_l30/threshold.dill differ
diff --git a/ros_adetector.py b/ros_adetector.py
index e51d79f7d7c3f32657ec13fb7797385a054783d3..fa8aa0563251b417aea9fbf6c7a3b92bfaffaccc 100644
--- a/ros_adetector.py
+++ b/ros_adetector.py
@@ -61,13 +61,14 @@ class AnomalyDetector(RosDatahandler):
     - makes a classification based on anomaly_bound
     """

-    def __init__(self, data_streams):
-        self.ae_jit, logger_params = load_trained_model_jit(path_to_models=RepoPaths.trained_models_val, test_name="test_w100_b50_e40")
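+    # the trained-model folder (e.g. "test_w100_b20_e50_l15") is now chosen by the caller via
+    # model_name instead of being hard-coded here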
+    def __init__(self, data_streams, model_name: str):
+        self.ae_jit, logger_params = load_trained_model_jit(path_to_models=RepoPaths.trained_models_val, test_name=model_name)
         super().__init__(data_streams,
-                         sync_stream="jt", # when a message arrives on this topic, all the last msgs are added to the data buffer
+                         sync_stream="force", # when a message arrives on this topic, all the last msgs are added to the data buffer
                          window_size=logger_params.window_size)
-        self.threshold = from_dill(path=os.path.join(RepoPaths.trained_models_val, Path("test_w100_b50_e40")),
+        self.threshold = from_dill(path=os.path.join(RepoPaths.trained_models_val, Path(model_name)),
                                    file_name="threshold.dill")
+        print("Model: ", model_name)
         self.start = True # True means get_data() and False nothing, currently for testing
         self.start_datastreams(ros_version = "ros1") # This sets up subscribers for the data_streams we have in the super().__init__

@@ -86,6 +87,7 @@
                                           Bool,
                                           self.start_callback)
         self.est_force_pub = rospy.Publisher('/F_ext_recon', WrenchStamped, queue_size=1)
+        self.force_pub = rospy.Publisher('/F_comp', WrenchStamped, queue_size=1)
         self.anomaly_pub = rospy.Publisher('/anomaly_score', WrenchStamped, queue_size=1)

         while not self.start:
@@ -102,6 +104,7 @@
         if not self.start: return
         if not self.all_streams_recieved(): return
         data = self.get_data()
+        self.force_pub.publish(build_wrench_msg(data.force[-1]))
         recon_data = self.ae_jit(data)
         recon_loss = jax.tree_util.tree_map(lambda w, rw: jnp.mean(jnp.abs(w - rw), axis=0), data, recon_data)
         self.est_force_pub.publish(build_wrench_msg(recon_data.force[-1]))
@@ -116,48 +119,46 @@
         # -> So no jax.tree_map at this point
         # anomaly_counter = jax.tree_map(lambda val, cut_off: val > cut_off, loss, self.threshold)

-        loss = jax.tree_util.tree_leaves(loss)
+        leaves_with_paths = jax.tree_util.tree_leaves_with_path(loss)
+        leaves_loss = [l[1] for l in leaves_with_paths]
+        paths = [l[0] for l in leaves_with_paths]
         cut_off = jax.tree_util.tree_leaves(self.threshold)
-        for loss_leaf, cut_leaf in zip(loss, cut_off):
-            if jnp.any(loss_leaf > cut_leaf):
-                print("Anomaly Detected!")
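+        # compare every loss leaf (data stream) against its own threshold leaf;
+        # jax.tree_util.keystr(name) renders the pytree key path so the offending feature is named in the log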
+        for name, leaf_loss, cut_leaf in zip(paths, leaves_loss, cut_off):
+            anomalies = jnp.where(leaf_loss > cut_leaf)
+            if len(anomalies[0]) != 0:
+                print("Anomaly Detected in feature: ", anomalies[0], jax.tree_util.keystr(name))


 def start_node():
     # we init with a initialzied FrankaData b/c the default values of the attributes in FrankaData
     # have the right topic names, etc.
-    anomaly_detector = AnomalyDetector(FrankaData(jt = JointPosData('/joint_states', JointState),
-                                       force = ForceData('/franka_state_controller/F_ext', WrenchStamped)))
+    anomaly_detector = AnomalyDetector(data_streams=FrankaData(jt = JointPosData('/joint_states', JointState),
+                                       force = ForceData('/franka_state_controller/F_ext', WrenchStamped)),
+                                       model_name="test_w100_b20_e50_l15")
     rospy.spin()


 def test(len_test: int=15, window_size: int=10):
-    #roscore = subprocess.Popen(["roscore"])
     try:
-        anomaly_detector = AnomalyDetector(FrankaData(jt = JointPosData('/joint_states', JointState),
-                                           force = ForceData('/franka_state_controller/F_ext', WrenchStamped)))
+        anomaly_detector = AnomalyDetector(data_streams=FrankaData(jt = JointPosData('/joint_states', JointState),
+                                           force = ForceData('/franka_state_controller/F_ext', WrenchStamped)),
+                                           model_name="test_w100_b20_e50_l15")
+
         rospub_process = subprocess.Popen( ['rostopic', 'pub', '/start_anomaly', 'std_msgs/Bool', 'True', '--once'] )
-        rosbag_process = subprocess.Popen( ['rosbag', 'play', f'-u {len_test}', 'experiments/01_faston_converging/data/with_search_strategy/test/214_FAILURE_+0,0033_not_inserted.bag'] )
-        #echo_process = subprocess.Popen( ['rostopic', 'echo', 'joint_states/position[0]'] )
-
+        #rosbag_process = subprocess.Popen( ['rosbag', 'play', f'-u {len_test}', 'experiments/01_faston_converging/data/with_search_strategy/test/214_FAILURE_+0,0033_not_inserted.bag'] )
+
         assert type(anomaly_detector.get_data()) == FrankaData
         sleep(len_test)
     finally:
         print("Terminating ROS bag process...")
         rospub_process.send_signal(signal.SIGINT)
-        rosbag_process.send_signal(signal.SIGINT)
-        #echo_process.send_signal(signal.SIGINT)
-        #roscore.send_signal(signal.SIGINT)
+        #rosbag_process.send_signal(signal.SIGINT)
         try:
-            rosbag_process.wait(timeout=1)
+            #rosbag_process.wait(timeout=1)
             rospub_process.wait(timeout=1)
-            #echo_process.wait(timeout=1)
-            #roscore.wait(timeout=1)
         except subprocess.TimeoutExpired:
             rospub_process.kill()
-            rosbag_process.kill()
-            #echo_process.kill()
-            #roscore.kill()
+            #rosbag_process.kill()


 if __name__ == "__main__":
diff --git a/run/construct_thresholds.py b/run/construct_thresholds.py
index 0976bcf4bf06081db5b0b7fa7b1cc0b6a8208734..9c5951beb8ac09c0ec5b04880b258247fc927d7b 100644
--- a/run/construct_thresholds.py
+++ b/run/construct_thresholds.py
@@ -15,9 +15,9 @@ def construct_threshold(name):
     from dataclass_creator.dyn_slice_dataloader import DynSliceDataloader
     from run.parameter import RepoPaths

-    print(f"Validation of model {name}")
+    print(f"Threshold calculation of model {name}")
     # load test bags based on bag_name
-    dataset = load_dataset(root=RepoPaths.threshold, name="010_SUCCESS")
+    dataset = load_dataset(root=RepoPaths.threshold, name="SUCCESS")

     ae_params, model_params, logger_params = load_trained_model(path_to_models=RepoPaths.trained_models_val,
                                                                 test_name=name)
@@ -37,16 +37,20 @@ def construct_threshold(name):
     recon_windows = jax.vmap(ae_params.apply, in_axes=(None, 0))(model_params, windows)
     loss = tree_map(lambda w, rw: jnp.mean(jnp.abs(w - rw), axis=1),windows, recon_windows)
     threshold = tree_map(lambda l: jnp.max(l, axis=0), loss)
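+    # add a 60% margin on top of the maximum reconstruction loss observed on the nominal SUCCESS data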
"thresholds_roc.dill") - to_dill(tuple(test_thresholds), logger_params.path_to_model, "thresholds_testing.dill") to_dill(threshold, logger_params.path_to_model, "threshold.dill") if __name__ == "__main__": - models = ["test_w100_b50_e40"] + models = ["test_w100_b20_e30_l15", + "test_w100_b20_e30_l20", + "test_w100_b20_e50_l15", + "test_w100_b20_e50_l20", + "test_w100_b50_e30_l15", + "test_w100_b50_e30_l20", + "test_w100_b50_e50_l15", + "test_w100_b50_e50_l20"] for name in models: construct_threshold(name) diff --git a/run/parameter.py b/run/parameter.py index 3caf9448a6288defdaace90216afb432afddc765..51964c49b29b94a03ec5f7414b23490564deefc8 100644 --- a/run/parameter.py +++ b/run/parameter.py @@ -49,12 +49,13 @@ class TerminalColour: class LoggerParams: - def __init__(self, window_size, batch_size, epochs, path_to_models): + def __init__(self, window_size, batch_size, epochs, latent_dim, path_to_models): self.window_size = window_size self.batch_size = batch_size self.epochs = epochs + self.latent_dim = latent_dim self.path_to_models = path_to_models - self.path_to_model = os.path.join(self.path_to_models, Path(f"test_w{window_size}_b{batch_size}_e{epochs}")) + self.path_to_model = os.path.join(self.path_to_models, Path(f"test_w{window_size}_b{batch_size}_e{epochs}_l{latent_dim}")) self.time_stamp = datetime.now().strftime("%d_%m_%Y-%H_%M_%S") logging.basicConfig(level=logging.INFO) @@ -68,7 +69,7 @@ class LoggerParams: @dataclass class AEParams: c_hid: int = 50 # parameters in hidden layer - bottleneck_size: int = 15 # latent_dim + bottleneck_size: int = 20 # latent_dim #@kevin 10.8.24: a major advantage of dataclass is the automatic __init__, # by defining __init__ by hand, this is overwritten so we cant @@ -85,7 +86,7 @@ class RepoPaths: """Path names for 01_faston_converging anomaly detection.""" example_data_train = Path('experiments/example_data/train') example_data_test = Path('experiments/example_data/test') - data_train = Path('experiments/01_faston_converging/data/with_search_strategy/train') - unseen_data = Path('experiments/01_faston_converging/data/with_search_strategy/test') - trained_models_val = Path('experiments/01_faston_converging/trained_models_val/with_search_strategy') - threshold = Path('experiments/01_faston_converging/data/with_search_strategy/threshold') + data_train = Path('experiments/01_faston_converging/data/cable_insert/train') + unseen_data = Path('experiments/01_faston_converging/data/cable_insert/test') + trained_models_val = Path('experiments/01_faston_converging/trained_models_val/cable_insert_model') + threshold = Path('experiments/01_faston_converging/data/cable_insert/test') diff --git a/run/train_model.py b/run/train_model.py index 94a87854d524ca7c5e10b26c784dcd401379ba4f..18baac1737cfe06936a9f3f743c5f8a6f1209fbe 100644 --- a/run/train_model.py +++ b/run/train_model.py @@ -23,7 +23,7 @@ def train_model(train_loader: DynSliceDataloader, # --------- initialize autoencoder --------------------------------------------------------------------------------- - ae = AutoEncoder(AEParams.c_hid, AEParams.bottleneck_size, train_loader.out_size) + ae = AutoEncoder(AEParams.c_hid, logger_params.latent_dim, train_loader.out_size) # --------- initialize trainer ------------------------------------------------------------------------------------- config = TrainerConfig( @@ -43,14 +43,16 @@ def train_model(train_loader: DynSliceDataloader, return ae, optimized_params, random_threshold -def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, 
+    ae = AutoEncoder(AEParams.c_hid, logger_params.latent_dim, train_loader.out_size)

     # --------- initialize trainer -------------------------------------------------------------------------------------
     config = TrainerConfig(
@@ -43,14 +43,16 @@ def train_model(train_loader: DynSliceDataloader,
     return ae, optimized_params, random_threshold


-def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, purpose: str):
-    train_data = load_dataset(root=path_to_data, name="SUCCESS")
+def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, latent_dim: int, purpose: str, train_data=None):
+    if not train_data:
+        train_data = load_dataset(root=path_to_data, name="SUCCESS")
     train_loader = DynSliceDataloader(train_data, window_size=window_size, batch_size=batch_size)
     val_loader = DynSliceDataloader(train_data, window_size=window_size, batch_size=batch_size)

     logger_params = LoggerParams(window_size=window_size,
                                  batch_size=batch_size,
                                  epochs=epochs,
+                                 latent_dim=latent_dim,
                                  path_to_models=RepoPaths.trained_models_val)

     ae, optimized_params, threshold = train_model(train_loader,
@@ -64,7 +66,7 @@ def train(path_to_data: Path, window_size: int, batch_size: int, epochs: int, pu
     assert isinstance(optimized_params, dict)


-def train_loop(path_to_data: Path, window_size: list, batch_size: list, epochs: list, purpose: str):
+def train_loop(path_to_data: Path, window_size: list, batch_size: list, epochs: list, latent_dim:list, purpose: str):
     """
     Train multiple models by setting different window, batch sizes and number of epochs.

@@ -74,20 +76,21 @@ def train_loop(path_to_data: Path, window_size: list, batch_size: list, epochs:
         batch_size: List of different batch sizes.
         epochs: List of numbers of epochs the model will be trained for.
     """
-    model_params = product(window_size, batch_size, epochs)
-
+    model_params = product(window_size, batch_size, epochs, latent_dim)
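+    # load the training bags once and hand them to train(); avoids re-reading the dataset for
+    # every hyper-parameter combination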
+    train_data = load_dataset(root=path_to_data, name="SUCCESS")
     for hparams in model_params:
-        window_size, batch_size, epochs = hparams
-        train(path_to_data=path_to_data, window_size=window_size, batch_size=batch_size, epochs=epochs, purpose=purpose)
+        window_size, batch_size, epochs, latent_dim = hparams
+        train(path_to_data=path_to_data, window_size=window_size, batch_size=batch_size, epochs=epochs, latent_dim=latent_dim, purpose=purpose, train_data=train_data)

 if __name__ == "__main__":
     train_config = {
-        "path_to_data": RepoPaths.example_data_train,
+        "path_to_data": RepoPaths.data_train,
         "window_size": [100],
-        "batch_size": [50],
-        "epochs": [20],
+        "batch_size": [20,50],
+        "epochs": [30,50],
+        "latent_dim": [15,20],
         "purpose": "Reconstruction"
     }
-
+
     train_loop(**train_config)
     os._exit(1)
diff --git a/run/validate_model.py b/run/validate_model.py
index ab0373bb065ee12ba9b96d57680ad962abc1fbe9..97ecba62e2717370517de8559003adf1d734dae2 100644
--- a/run/validate_model.py
+++ b/run/validate_model.py
@@ -6,10 +6,11 @@ import matplotlib.pyplot as plt
 from dataclass_creator.utils_dataclass.real_data_gen import load_dataset
 from run.load_save_model import from_dill
 from run.valmodel import ValModelRecon, save_multi_image
+from run.parameter import RepoPaths


 def val_model(model_name: str, bag_key: str, bag_names: list, anomalies: list[dict]):
-    from run.parameter import RepoPaths
+
     logging.basicConfig(level=logging.INFO)

     dataset = load_dataset(root=RepoPaths.unseen_data, name=bag_key)
@@ -34,19 +35,28 @@ def val_model(model_name: str, bag_key: str, bag_names: list, anomalies: list[di
             val.visualize_scores_over_batches(threshold_name=str(i), threshold= threshold, anomalies=anomaly)
         val.get_roc_over_batches(thresholds=thresholds_roc, anomalies=anomaly)

-        save_multi_image(f"val_{bag_name}_{name}.pdf")
+        save_multi_image(f"val_{bag_name}_{model_name}.pdf")
         plt.close('all')



+def show_some_plots(model_name: str, bag_name: str):
+    dataset = load_dataset(root=RepoPaths.threshold, name=bag_name)
+    val = ValModelRecon(path_to_models=RepoPaths.trained_models_val,
+                        name_model=model_name,
+                        test_traj=dataset[0],
+                        test_name=bag_name)
+
+    val.visualize_windows()
+    val.visualize_loss()
+    plt.show()
+
+
 if __name__ == "__main__":
-    models = ["test_w10_b50_e20",
-              "test_w50_b50_e20",
-              "test_w100_b50_e20",
-              "test_w500_b50_e20",
-              "test_w1000_b50_e20"]
+    model = ["test_w80_b50_e30_l15"]
+    show_some_plots(model[0], bag_name="SUCCESS")

-    for name in models:
+    """for name in model:
         val_model(name, "FAILURE",
                   ["040_FAILURE", "050_FAILURE", "214_FAILURE"],
-                  anomalies=[list(range(236, 280)), list(range(249,288)), list(range(258, 296))])
+                  anomalies=[list(range(236, 280)), list(range(249,288)), list(range(258, 296))])"""