from dataclasses import dataclass
from typing import List
import numpy as np
from torch.utils.data import TensorDataset
@dataclass
class ZtfData:
    """Holds raw ZTF object data.

    Iterating over an instance returns ``names``, ``labels`` and
    ``redshifts`` in that order, so it can be unpacked directly::

        names, labels, redshifts = ztf_data

    Field declaration order matches iteration order, so
    ``ZtfData(*ztf_data)`` rebuilds an equal instance.
    """

    # These fields must be declared because ``__iter__`` reads them; without
    # them the dataclass takes no constructor arguments and iteration fails.
    # NOTE(review): the element types below are assumptions — confirm against
    # the code that builds ZtfData.
    names: List[str]
    labels: np.ndarray
    redshifts: np.ndarray

    def __iter__(self):
        """Return an iterator over ``(names, labels, redshifts)``."""
        return iter((self.names, self.labels, self.redshifts))
@dataclass
class TrainData:
    """Container pairing a training dataset with its validation dataset.

    Instances support tuple unpacking::

        train_dataset, val_dataset = train_data
    """

    train_dataset: TensorDataset  # training split
    val_dataset: TensorDataset  # validation split

    def __iter__(self):
        """Return an iterator over ``(train_dataset, val_dataset)``."""
        pair = (self.train_dataset, self.val_dataset)
        return iter(pair)
@dataclass
class TestData:
    """Holds information about testing data.

    Field declaration order matches iteration order, so an instance unpacks
    as::

        features, classes, names, group_idxs = test_data

    and ``TestData(*test_data)`` rebuilds an equal instance.
    """

    test_features: np.ndarray
    test_classes: np.ndarray
    # Declared here (between classes and group indices) because ``__iter__``
    # returns it in this position; previously it was never declared and
    # iteration raised AttributeError.
    # NOTE(review): the element type of the names is an assumption — confirm.
    test_names: np.ndarray
    test_group_idxs: List[int]

    def __iter__(self):
        """Return an iterator over the four fields in declaration order."""
        return iter(
            (
                self.test_features,
                self.test_classes,
                self.test_names,
                self.test_group_idxs,
            )
        )