from dataclasses import dataclass
from typing import List, Optional

import numpy as np
from torch.utils.data import TensorDataset
@dataclass
class ZtfData:
    """Holds raw ZTF object data.

    Instances can be unpacked directly, e.g.
    ``names, labels, redshifts = ztf_data``.

    The three attributes are declared as dataclass fields so that the
    generated ``__init__`` accepts them and ``__iter__`` never reads an
    undefined attribute. Each defaults to ``None`` so that constructing
    the class without arguments keeps working.
    """

    # NOTE(review): element types are not visible from this module; the
    # annotations below are deliberately loose - tighten once confirmed.
    names: Optional[List] = None
    labels: Optional[List] = None
    redshifts: Optional[List] = None

    def __iter__(self):
        """Iterate over ``(names, labels, redshifts)`` in that order."""
        return iter((self.names, self.labels, self.redshifts))
@dataclass
class TrainData:
    """Container pairing the training dataset with the validation dataset.

    Supports tuple-style unpacking:
    ``train_dataset, val_dataset = train_data``.
    """

    # Dataset used for training.
    train_dataset: TensorDataset
    # Dataset used for validation.
    val_dataset: TensorDataset

    def __iter__(self):
        """Iterate over ``(train_dataset, val_dataset)`` in that order."""
        members = (self.train_dataset, self.val_dataset)
        return iter(members)
@dataclass
class TestData:
    """Holds information about testing data.

    Instances unpack as
    ``features, classes, names, group_idxs = test_data``.

    ``test_names`` is yielded by ``__iter__`` and is therefore declared
    as a field; without the declaration, iterating an instance raised
    ``AttributeError``. It is placed last with a ``None`` default so
    that existing three-argument construction remains valid. The
    iteration order is unchanged.
    """

    test_features: np.ndarray
    test_classes: np.ndarray
    test_group_idxs: List[int]
    # NOTE(review): type assumed to mirror the sibling array fields -
    # confirm against the code that builds TestData.
    test_names: Optional[np.ndarray] = None

    def __iter__(self):
        """Iterate over features, classes, names and group indices."""
        return iter(
            (
                self.test_features,
                self.test_classes,
                self.test_names,
                self.test_group_idxs,
            )
        )