# Source code for src.superphot_plus.model.data

from dataclasses import dataclass
from typing import List

import numpy as np
from torch.utils.data import TensorDataset


@dataclass
class ZtfData:
    """Holds raw ZTF object data.

    Instances can be unpacked positionally::

        names, labels, redshifts = ztf_data

    Attributes
    ----------
    names : List[str]
        Identifiers of the ZTF objects.
    labels : List[int]
        Per-object labels.
    redshifts : List[float]
        Per-object redshifts.
    """

    # ZTF object names are alphanumeric strings, so ``List[int]`` was incorrect.
    names: List[str]
    # NOTE(review): callers may store class-name strings here; confirm element type.
    labels: List[int]
    # Redshift is a continuous quantity; the previous ``List[int]`` hint was wrong.
    redshifts: List[float]

    def __iter__(self):
        """Iterate over ``(names, labels, redshifts)`` to support tuple unpacking."""
        return iter((self.names, self.labels, self.redshifts))
@dataclass
class TrainData:
    """Holds train and validation datasets.

    Unpacks as ``train_dataset, val_dataset = train_data``.
    """

    # Dataset used for fitting.
    train_dataset: TensorDataset
    # Held-out dataset used for validation.
    val_dataset: TensorDataset

    def __iter__(self):
        """Yield the training dataset, then the validation dataset."""
        yield self.train_dataset
        yield self.val_dataset
@dataclass
class TestData:
    """Holds information about testing data.

    Unpacks as ``features, classes, names, group_idxs = test_data``.
    """

    # Feature array for the test set.
    test_features: np.ndarray
    # Class array for the test set.
    test_classes: np.ndarray
    # Object names for the test set.
    test_names: np.ndarray
    # Group indices associated with the test set.
    test_group_idxs: List[int]

    def __iter__(self):
        """Yield features, classes, names and group indices, in that order."""
        yield self.test_features
        yield self.test_classes
        yield self.test_names
        yield self.test_group_idxs