
Structure save and load

Merged: Lisa-Marie Fenner requested to merge 44-structure-save-and-load into main

4 files changed: +27 −69
Showing 1 of 4 changed files: +1 −32
@@ -76,17 +76,6 @@ class TrainerConfig:
     check_val_every_n_epoch: int = 4
-
-class ModelParams:
-    model_rng: float
-    window_size: int
-    batch_size: int
-    len_dataloader: int
-    feature_size: list[int]
-    purpose: str
-    lr: float
-    loss_array: list[float]
-
 class Trainer:
     """
     Trainer class that includes the magic of training data-driven approaches.
@@ -118,7 +107,6 @@ class Trainer:
         self.purpose = purpose
         self.optimizer_hparams = optimizer_hparams
         self.model = model
-        self.model_params = ModelParams()
         self.seed = seed
         self.check_val_every_n_epoch = check_val_every_n_epoch
@@ -130,8 +118,6 @@ class Trainer:
         self.init_model(init_vec)
         self.logger = None
         self.init_logger(logger_params)
-        self.file_name = logger_params.file_name
-        self.ckpt_dir_name = logger_params.ckpt_dir_name
 
     def init_model(self, init_vec: SysVarSet):
         """
@@ -151,10 +137,6 @@
                                        tx=None,
                                        opt_state=None)
-        self.model_params.model_rng = model_rng
-        self.model_params.init_vec = init_vec
-        self.model_params.window_size = jax.tree_util.tree_leaves(init_vec)[0].shape[0]
 
     def init_logger(self, logger_params: dataclass):
         """
         Initializes the Summary Writer and defines the name of the log_dir.
@@ -162,9 +144,7 @@
 
         Args:
             logger_params: dictionary containing log name
         """
-        curr_dir = Path(os.path.dirname(os.path.abspath(__file__)))
-        root_dir = os.fspath(Path(curr_dir.parent, 'log_files/validation').resolve())
-        log_dir = os.path.join(root_dir, logger_params.time_stamp, logger_params.logdir_name)
+        log_dir = os.path.join(logger_params.path_to_model + "/log_file")
         self.logger = tf.summary.create_file_writer(log_dir)
 
     def init_optimizer(self,
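The replaced block derived the log directory from the package location; the new one-liner derives it from logger_params.path_to_model instead. Note that the added line calls os.path.join with a single pre-concatenated argument; passing two arguments is the more idiomatic form. A minimal, illustrative sketch of the resulting logging setup (the path_to_model value is a made-up placeholder, not a path from this repository):

import os
import tensorflow as tf

path_to_model = "saved_models/2024-01-01_run"  # hypothetical; supplied via logger_params

# Equivalent to the added line, written with two join arguments:
log_dir = os.path.join(path_to_model, "log_file")

writer = tf.summary.create_file_writer(log_dir)
with writer.as_default():
    tf.summary.scalar("loss", 0.42, step=1)  # dummy scalar, readable in TensorBoard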
@@ -196,14 +176,12 @@ class Trainer:
         self.state = TrainState.create(apply_fn=self.state.apply_fn,
                                        params=self.state.params,
                                        tx=optimizer)
-        self.model_params.learning_rate = lr
 
     def init_calc_functions(self):
         """
         Initializes the step functions and loss function matching the purpose.
         """
-        self.model_params.purpose = self.purpose
         if self.purpose.lower() == 'reconstruction':
             self.train_step = jit(TrainReconstructor.train_step, static_argnums=(0,))
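For readers unfamiliar with the dispatch above: the unbound train_step functions are jitted with static_argnums=(0,), so their first argument is treated as a compile-time constant rather than a traced array. A self-contained sketch of that pattern, with a dummy loss and placeholder shapes (this is not the project's own TrainReconstructor):

import jax.numpy as jnp
from jax import jit

class TrainReconstructor:
    @staticmethod
    def train_step(reconstructor, params, batch):
        # static_argnums=(0,) lets `reconstructor` be any hashable Python
        # object; only `params` and `batch` are traced by jit.
        return jnp.mean((batch - params * batch) ** 2)

train_step = jit(TrainReconstructor.train_step, static_argnums=(0,))
loss = train_step("reconstruction", jnp.float32(0.9), jnp.ones((4, 3)))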
@@ -237,21 +215,13 @@ class Trainer:
             loss_array: containing the loss_values from every iteration
         """
-        # idx_list = train_dataloader.get_idx_batch()
-        self.model_params.batch_size = train_dataloader.batch_size
-        self.model_params.len_dataloader = len(train_dataloader)
-        self.model_params.feature_size = [9, 3]
         self.init_optimizer(n_epochs, len(train_dataloader))
         self.init_calc_functions()
         loss_array = []
         loss_grad_fn = jax.value_and_grad(self.loss_fn)
-        #batches = train_dataloader.prep_batches()
         batches = train_dataloader
-        # batch_idx = train_dataloader.get_idx_batch()
         for i in tqdm(range(1, n_epochs + 1)):
             batch_loss_array = []
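Only part of the training loop is visible in this hunk. For orientation, a self-contained sketch of the jax.value_and_grad pattern it builds on, using a stand-in Flax model and dummy data (none of these names come from the repository):

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training.train_state import TrainState

model = nn.Dense(features=3)                      # stand-in model
batch = jnp.ones((8, 9))                          # dummy batch, feature_size [9, 3]
params = model.init(jax.random.PRNGKey(0), batch)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optax.adam(1e-3))

def loss_fn(params):
    pred = model.apply(params, batch)
    return jnp.mean((pred - 1.0) ** 2)            # dummy reconstruction-style loss

loss_grad_fn = jax.value_and_grad(loss_fn)
loss_array = []
for i in range(1, 11):                            # n_epochs = 10
    loss, grads = loss_grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    loss_array.append(loss)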
@@ -267,7 +237,6 @@
             #self.save_model(epoch=i)
             # at this point evaluation and early break could be added
-        self.model_params.loss_array = loss_array
         #self.save_model_params()
         return self.state.params, loss_array
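The save-and-load half that gives this MR its name lives in the other three changed files, which are not shown on this page. Purely for orientation, one common Flax way to persist and restore trained parameters, reusing the state from the sketch above (hypothetical file name; not necessarily the approach this MR takes):

from flax import serialization

# Save: serialize the trained params pytree to a byte string.
with open("model_params.bin", "wb") as f:   # hypothetical file name
    f.write(serialization.to_bytes(state.params))

# Load: restore into a pytree of the same structure (here the live params).
with open("model_params.bin", "rb") as f:
    restored_params = serialization.from_bytes(state.params, f.read())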