Skip to content
Snippets Groups Projects

Structure save and load

Merged Lisa-Marie Fenner requested to merge 44-structure-save-and-load into main
12 files
+ 311
328
Compare changes
  • Side-by-side
  • Inline
Files
12
+ 72
0
"""A dataloader object with sampler for producing batches."""
from typing import Iterable, Tuple
import jax.numpy as jnp
from jax import Array
from opt_tools import Traj, Batch
from dataclass_creator.dyn_slice_dataset import DynSliceDataset, PredDynSliceDataset
class DynSliceDataloader:
    """Serve batches of fixed-length windows from a sliced-trajectory dataset.

    The loader wraps a ``DynSliceDataset`` (or a ``PredDynSliceDataset`` when a
    non-zero prediction offset is requested) and hands it rows of window
    indices, one row per batch, in sequential order.

    The object is its own iterator: calling ``iter()`` on it restarts the
    single internal sampler, so two simultaneous loops over the same loader
    share (and disturb) each other's position.
    """

    def __init__(self, data: "Iterable[Traj]",
                 window_size: int = 80,
                 batch_size: int = 15,
                 pred_offset: int = 0) -> None:
        """Initialize a dataloader.

        Args:
            data (Iterable[Traj]): A list of trajectories which can be ragged
            window_size (int): The length of window which samples should have
            batch_size (int): The number of samples in a batch
            pred_offset (int): If != 0, a ``PredDynSliceDataset`` is used
                instead of a ``DynSliceDataset``; presumably its batches then
                carry a second element shifted by ``pred_offset`` -- confirm
                against ``PredDynSliceDataset``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Fail early: a zero batch size would otherwise only surface later as
        # a ZeroDivisionError inside __len__.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if pred_offset == 0:
            self.dset = DynSliceDataset(data, window_size)
        else:
            self.dset = PredDynSliceDataset(data, window_size, pred_offset=pred_offset)
        self.batch_size = batch_size

    def __len__(self) -> int:
        """Number of complete batches."""
        return len(self.dset) // self.batch_size

    def __iter__(self) -> "DynSliceDataloader":
        """Restart the index sampler and return this loader as the iterator."""
        self._sampler = iter(self._get_idx_batch())
        return self

    def __next__(self) -> "Batch":
        """Return the next batch; requires a prior call to ``iter()``."""
        return self.dset.get_batch(next(self._sampler))

    def _get_idx_batch(self, shuffle: bool = False, drop_last: bool = True) -> Array:
        """Return the window indices to sample over, one row per batch.

        Args:
            shuffle (bool): Not supported; must be False.
            drop_last (bool): If True, windows that do not fill a complete
                batch are left out. If False and such left-over windows exist,
                one more batch is appended which starts with the left-over
                windows and is padded by wrapping around to the first windows.

        Returns:
            Array: Integer array of shape ``(n_batches, batch_size)``.

        Raises:
            NotImplementedError: If ``shuffle`` is True.
        """
        if shuffle:
            raise NotImplementedError  # better to indicate not currently supported

        num_eff_windows = len(self.dset)
        num_batches = len(self)
        num_full = num_batches * self.batch_size
        rest_elements = num_eff_windows - num_full

        idx_list = jnp.arange(num_full).reshape(num_batches, self.batch_size)

        # Without left-over windows there is nothing to append; this also
        # covers an empty dataset, so the modulo below never divides by zero.
        if drop_last or rest_elements == 0:
            return idx_list

        # Left-over windows first, then wrap around to the start. The modulo
        # keeps every index valid even when the dataset holds fewer windows
        # than one batch.
        last_idx = (num_full + jnp.arange(self.batch_size)) % num_eff_windows
        return jnp.vstack((idx_list, last_idx[None, :]))

    def prep_batches(self) -> "Iterable[Batch]":
        """Return iterable of training batches. Memory-intensive for large windows!"""
        return [self.dset.get_batch(idx) for idx in self._get_idx_batch()]

    def get_init_vec(self) -> "Batch":
        """Return an initial batch to be used to initialize models.

        This is the dataset's element at index 0.
        """
        return self.dset[0]
Loading