Skip to content

Module wtracker.neural.config

View Source
from __future__ import annotations

import torch

from torch import nn

from torch.optim import Optimizer

from torch.utils.data import Dataset, DataLoader, random_split

from dataclasses import dataclass, field

from wtracker.utils.config_base import ConfigBase

@dataclass
class DatasetConfig(ConfigBase):
    """
    Configuration describing which frames form a dataset sample and where the data comes from.

    Frame lists are offsets relative to the prediction frame (0): negative values are
    frames before it, positive values are frames after it.
    """

    input_frames: list[int]  # The frames to use as input for the network (offsets relative to the prediction frame 0). Expected to start with 0.
    pred_frames: list[int]  # The frames to predict (offsets relative to the prediction frame 0).
    log_path: str  # The path to the log file containing the worm head predictions (by YOLO).

    def __post_init__(self) -> None:
        # Soft validation only: print a warning instead of raising so existing configs still load.
        # The emptiness check comes first because an empty list would otherwise raise IndexError.
        if not self.input_frames or self.input_frames[0] != 0:
            print(
                "WARNING::DatasetConfig::input_frames should contain 0 as first element. Please verify your parameters."
            )

    @staticmethod
    def from_io_config(io: IOConfig, log_path: str) -> DatasetConfig:
        """
        Build a DatasetConfig that reuses the input/prediction frames of an IOConfig.

        Args:
            io (IOConfig): The network I/O configuration whose frame lists are copied.
            log_path (str): The path to the log file containing the worm head predictions.

        Returns:
            DatasetConfig: The new dataset configuration.
        """
        return DatasetConfig(
            input_frames=io.input_frames,
            pred_frames=io.pred_frames,
            log_path=log_path,
        )

# Registry of supported optimizers: the `optimizer` string of a TrainConfig is a key here,
# and the value is the corresponding torch optimizer class.
OPTIMIZERS = dict(
    adam=torch.optim.Adam,
    sgd=torch.optim.SGD,
    rmsprop=torch.optim.RMSprop,
    adamw=torch.optim.AdamW,
)

# Registry of supported loss functions: the `loss_fn` string of a TrainConfig is a key here,
# and the value is the corresponding torch loss class.
LOSSES = dict(
    mse=nn.MSELoss,
    l1=nn.L1Loss,
)

@dataclass
class TrainConfig(ConfigBase):
    """
    Configuration of a training run.

    It covers the data source, the model with its loss/optimizer choice, the
    training-loop settings, the optimizer hyper-parameters and the dataloader settings.
    """

    # general parameters

    # kw_only lets this defaulted field precede the required fields below; it becomes a
    # keyword-only constructor argument placed after all the positional ones.
    seed: int = field(default=42, kw_only=True)  # Random seed for reproducibility
    dataset: DatasetConfig | Dataset  # The dataset to use for training: a DatasetConfig, or a Dataset object (a Dataset is used as is)

    # trainer parameters
    model: nn.Module | str  # The model to train, can also be a pretrained model (if str, it will be loaded from disk)
    loss_fn: str  # The loss function to use, one of the keys in the LOSSES dict
    optimizer: str  # The optimizer to use, one of the keys in the OPTIMIZERS dict
    device: str = "cuda"  # 'cuda' for training on GPU or 'cpu' otherwise
    log: bool = False  # Whether to log and save the training process with tensorboard

    # training parameters
    num_epochs: int = 100  # Number of times to iterate over the dataset
    checkpoints: str | None = None  # Path to save model checkpoints, including the checkpoint name (None presumably disables checkpointing - confirm in the trainer)
    early_stopping: int | None = None  # Number of epochs to wait before stopping training if no improvement was made (None presumably disables it - confirm in the trainer)
    print_every: int = 5  # How often (#epochs) to print training progress

    # optimizer parameters
    learning_rate: float = 0.001  # Learning rate for the optimizer
    # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger)
    weight_decay: float = 1e-5

    # dataloader parameters
    batch_size: int = 256  # Number of samples in each batch
    shuffle: bool = True  # Whether to shuffle the dataset at the beginning of each epoch
    num_workers: int = 0  # Number of subprocesses to use for data loading
    train_test_split: float = 0.8  # Fraction of the dataset to use for training, the rest will be used for testing

    # These are not constructor arguments (init=False) and have no default, so they are unset
    # after construction and must be assigned before use (presumably by the training code).
    dl_train: DataLoader = field(init=False)  # Training DataLoader
    dl_test: DataLoader = field(init=False)  # Test DataLoader

@dataclass
class IOConfig(ConfigBase):
    """
    Describes the shape of the network's input and output.

    `input_frames` and `pred_frames` hold frame offsets relative to the prediction
    frame (0): negative offsets are earlier frames, positive offsets are later ones.

    The derived sizes assume every input frame contributes 4 features (x, y, w, h of
    the worm bounding box), and every predicted frame contributes 2 features (x, y of
    the worm center).
    """

    input_frames: list[int]  # Offsets of the frames fed to the network
    pred_frames: list[int]  # Offsets of the frames whose worm center is predicted
    in_dim: int = field(init=False)  # Network input width, derived in __post_init__
    out_dim: int = field(init=False)  # Network output width, derived in __post_init__

    def __post_init__(self):
        contains_prediction_frame = 0 in self.input_frames
        if not contains_prediction_frame:
            print(
                "WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters."
            )

        bbox_features = 4  # x, y, w, h per input frame
        center_features = 2  # x, y per predicted frame
        self.in_dim = bbox_features * len(self.input_frames)
        self.out_dim = center_features * len(self.pred_frames)

    @staticmethod
    def from_datasetConfig(config: DatasetConfig) -> IOConfig:
        """
        Build an IOConfig from the frame lists of a DatasetConfig.

        Args:
            config (DatasetConfig): The dataset configuration to copy the frames from.

        Returns:
            IOConfig: The matching network I/O configuration.
        """
        return IOConfig(
            input_frames=config.input_frames,
            pred_frames=config.pred_frames,
        )

Variables

LOSSES
OPTIMIZERS

Classes

DatasetConfig

class DatasetConfig(
    input_frames: 'list[int]',
    pred_frames: 'list[int]',
    log_path: 'list[str]'
)

DatasetConfig(input_frames: 'list[int]', pred_frames: 'list[int]', log_path: 'list[str]')

View Source
@dataclass
class DatasetConfig(ConfigBase):
    """
    Configuration describing which frames form a dataset sample and where the data comes from.

    Frame lists are offsets relative to the prediction frame (0): negative values are
    frames before it, positive values are frames after it.
    """

    input_frames: list[int]  # The frames to use as input for the network (offsets relative to the prediction frame 0). Expected to start with 0.
    pred_frames: list[int]  # The frames to predict (offsets relative to the prediction frame 0).
    log_path: str  # The path to the log file containing the worm head predictions (by YOLO).

    def __post_init__(self) -> None:
        # Soft validation only: print a warning instead of raising so existing configs still load.
        # The emptiness check comes first because an empty list would otherwise raise IndexError.
        if not self.input_frames or self.input_frames[0] != 0:
            print(
                "WARNING::DatasetConfig::input_frames should contain 0 as first element. Please verify your parameters."
            )

    @staticmethod
    def from_io_config(io: IOConfig, log_path: str) -> DatasetConfig:
        """
        Build a DatasetConfig that reuses the input/prediction frames of an IOConfig.

        Args:
            io (IOConfig): The network I/O configuration whose frame lists are copied.
            log_path (str): The path to the log file containing the worm head predictions.

        Returns:
            DatasetConfig: The new dataset configuration.
        """
        return DatasetConfig(
            input_frames=io.input_frames,
            pred_frames=io.pred_frames,
            log_path=log_path,
        )

Ancestors (in MRO)

  • wtracker.utils.config_base.ConfigBase

Static methods

from_io_config

def from_io_config(
    io: 'IOConfig',
    log_path: 'str'
) -> 'DatasetConfig'
View Source
    @staticmethod
    def from_io_config(io: IOConfig, log_path: str) -> DatasetConfig:
        """
        Build a DatasetConfig that reuses the input/prediction frames of an IOConfig.

        Args:
            io (IOConfig): The network I/O configuration whose frame lists are copied.
            log_path (str): The path to the log file containing the worm head predictions.

        Returns:
            DatasetConfig: The new dataset configuration.
        """
        return DatasetConfig(
            input_frames=io.input_frames,
            pred_frames=io.pred_frames,
            log_path=log_path,
        )

load_json

def load_json(
    path: 'str' = None
) -> 'T'

Load the class from a JSON file.

Parameters:

Name Type Description Default
path str The path to the JSON file. None

Returns:

Type Description
ConfigBase The class loaded from the JSON file.
View Source
    @classmethod
    def load_json(cls: type[T], path: str | None = None) -> T:
        """
        Load the class from a JSON file.

        The instance is created with ``cls.__new__``. Its ``__dict__`` is then filled
        directly from the JSON object, so ``__init__`` and ``__post_init__`` are not run.

        Args:
            path (str | None): The path to the JSON file. If None, the user is asked
                to choose a file through ``UserPrompt.open_file``.

        Returns:
            ConfigBase: The class loaded from the JSON file.
        """
        if path is None:
            path = UserPrompt.open_file(
                title=f"Open {cls.__name__} File",
                file_types=[("json", ".json")],
            )

        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Bypass __init__ so the stored attribute values are restored verbatim.
        obj = cls.__new__(cls)
        obj.__dict__.update(data)
        return obj

load_pickle

def load_pickle(
    path: 'str' = None
) -> 'T'

Load the class from a pickle file.

Parameters:

Name Type Description Default
path str The path to the pickle file. None

Returns:

Type Description
ConfigBase The class loaded from the pickle file.
View Source
    @classmethod
    def load_pickle(cls: type[T], path: str = None) -> T:
        """
        Read an object back from a pickle file.

        If ``pickle_load_object`` uses :mod:`pickle` (as its name suggests), loading can
        execute arbitrary code, so only open files from trusted sources.

        Args:
            path (str): Location of the pickle file. When omitted, the user is asked to
                pick one through ``UserPrompt.open_file``.

        Returns:
            The object returned by ``pickle_load_object`` (expected to be an instance of ``cls``).
        """
        if path is not None:
            return pickle_load_object(path)

        chosen_path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("pickle", ".pkl")],
        )
        return pickle_load_object(chosen_path)

Methods

save_json

def save_json(
    self,
    path: 'str' = None
)

Saves the class as a JSON file.

Parameters:

Name Type Description Default
path str The path to the output JSON file. None
View Source
    def save_json(self, path: str | None = None):
        """
        Saves the class as a JSON file.

        The instance's attributes (``self.__dict__``) are written with an indent of 4.
        Every attribute value must be JSON-serializable; otherwise ``json.dumps`` raises
        ``TypeError``.

        Args:
            path (str | None): The path to the output JSON file. If None, the user is
                asked for a destination through ``UserPrompt.save_file``.
        """
        if path is None:
            path = UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=[("json", ".json")],
                defaultextension=".json",
            )

        # Serialize before opening the file for writing. If an attribute cannot be
        # serialized, the TypeError is raised here and any existing file at `path` stays
        # intact, instead of being truncated and left half-written.
        text = json.dumps(self.__dict__, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

save_pickle

def save_pickle(
    self,
    path: 'str' = None
) -> 'None'

Saves the class as a pickle file.

Parameters:

Name Type Description Default
path str The path to the output pickle file. None
View Source
    def save_pickle(self, path: str = None) -> None:
        """
        Serialize this object to a pickle file.

        Args:
            path (str): Destination of the pickle file. When omitted, the user is asked
                for one through ``UserPrompt.save_file`` (default extension ``.pkl``).
        """
        destination = (
            path
            if path is not None
            else UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=[("pickle", ".pkl")],
                defaultextension=".pkl",
            )
        )
        pickle_save_object(self, destination)

IOConfig

class IOConfig(
    input_frames: 'list[int]',
    pred_frames: 'list[int]'
)

Configuration for the basic input/output of the network

The input_frames and pred_frames are lists of integers that represent the frames that will be used as input and output of the network. The frames are in the format of the number of frames before (negative) or after (positive) the prediction frame(0). To calculate in_dim,out_dim we assume that each input frame has 4 features (x,y,w,h), representing the worm bbox in that frame and each prediction frame has 2 features (x,y), representing the worm center in that frame.

View Source
@dataclass
class IOConfig(ConfigBase):
    """
    Describes the shape of the network's input and output.

    `input_frames` and `pred_frames` hold frame offsets relative to the prediction
    frame (0): negative offsets are earlier frames, positive offsets are later ones.

    The derived sizes assume every input frame contributes 4 features (x, y, w, h of
    the worm bounding box), and every predicted frame contributes 2 features (x, y of
    the worm center).
    """

    input_frames: list[int]  # Offsets of the frames fed to the network
    pred_frames: list[int]  # Offsets of the frames whose worm center is predicted
    in_dim: int = field(init=False)  # Network input width, derived in __post_init__
    out_dim: int = field(init=False)  # Network output width, derived in __post_init__

    def __post_init__(self):
        contains_prediction_frame = 0 in self.input_frames
        if not contains_prediction_frame:
            print(
                "WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters."
            )

        bbox_features = 4  # x, y, w, h per input frame
        center_features = 2  # x, y per predicted frame
        self.in_dim = bbox_features * len(self.input_frames)
        self.out_dim = center_features * len(self.pred_frames)

    @staticmethod
    def from_datasetConfig(config: DatasetConfig) -> IOConfig:
        """
        Build an IOConfig from the frame lists of a DatasetConfig.

        Args:
            config (DatasetConfig): The dataset configuration to copy the frames from.

        Returns:
            IOConfig: The matching network I/O configuration.
        """
        return IOConfig(
            input_frames=config.input_frames,
            pred_frames=config.pred_frames,
        )

Ancestors (in MRO)

  • wtracker.utils.config_base.ConfigBase

Static methods

from_datasetConfig

def from_datasetConfig(
    config: 'DatasetConfig'
) -> 'IOConfig'
View Source
    @staticmethod
    def from_datasetConfig(config: DatasetConfig) -> IOConfig:
        """
        Build an IOConfig from the frame lists of a DatasetConfig.

        Args:
            config (DatasetConfig): The dataset configuration to copy the frames from.

        Returns:
            IOConfig: The matching network I/O configuration.
        """
        return IOConfig(
            input_frames=config.input_frames,
            pred_frames=config.pred_frames,
        )

load_json

def load_json(
    path: 'str' = None
) -> 'T'

Load the class from a JSON file.

Parameters:

Name Type Description Default
path str The path to the JSON file. None

Returns:

Type Description
ConfigBase The class loaded from the JSON file.
View Source
    @classmethod
    def load_json(cls: type[T], path: str | None = None) -> T:
        """
        Load the class from a JSON file.

        The instance is created with ``cls.__new__``. Its ``__dict__`` is then filled
        directly from the JSON object, so ``__init__`` and ``__post_init__`` are not run.

        Args:
            path (str | None): The path to the JSON file. If None, the user is asked
                to choose a file through ``UserPrompt.open_file``.

        Returns:
            ConfigBase: The class loaded from the JSON file.
        """
        if path is None:
            path = UserPrompt.open_file(
                title=f"Open {cls.__name__} File",
                file_types=[("json", ".json")],
            )

        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Bypass __init__ so the stored attribute values are restored verbatim.
        obj = cls.__new__(cls)
        obj.__dict__.update(data)
        return obj

load_pickle

def load_pickle(
    path: 'str' = None
) -> 'T'

Load the class from a pickle file.

Parameters:

Name Type Description Default
path str The path to the pickle file. None

Returns:

Type Description
ConfigBase The class loaded from the pickle file.
View Source
    @classmethod
    def load_pickle(cls: type[T], path: str = None) -> T:
        """
        Read an object back from a pickle file.

        If ``pickle_load_object`` uses :mod:`pickle` (as its name suggests), loading can
        execute arbitrary code, so only open files from trusted sources.

        Args:
            path (str): Location of the pickle file. When omitted, the user is asked to
                pick one through ``UserPrompt.open_file``.

        Returns:
            The object returned by ``pickle_load_object`` (expected to be an instance of ``cls``).
        """
        if path is not None:
            return pickle_load_object(path)

        chosen_path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("pickle", ".pkl")],
        )
        return pickle_load_object(chosen_path)

Methods

save_json

def save_json(
    self,
    path: 'str' = None
)

Saves the class as a JSON file.

Parameters:

Name Type Description Default
path str The path to the output JSON file. None
View Source
    def save_json(self, path: str | None = None):
        """
        Saves the class as a JSON file.

        The instance's attributes (``self.__dict__``) are written with an indent of 4.
        Every attribute value must be JSON-serializable; otherwise ``json.dumps`` raises
        ``TypeError``.

        Args:
            path (str | None): The path to the output JSON file. If None, the user is
                asked for a destination through ``UserPrompt.save_file``.
        """
        if path is None:
            path = UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=[("json", ".json")],
                defaultextension=".json",
            )

        # Serialize before opening the file for writing. If an attribute cannot be
        # serialized, the TypeError is raised here and any existing file at `path` stays
        # intact, instead of being truncated and left half-written.
        text = json.dumps(self.__dict__, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

save_pickle

def save_pickle(
    self,
    path: 'str' = None
) -> 'None'

Saves the class as a pickle file.

Parameters:

Name Type Description Default
path str The path to the output pickle file. None
View Source
    def save_pickle(self, path: str = None) -> None:
        """
        Serialize this object to a pickle file.

        Args:
            path (str): Destination of the pickle file. When omitted, the user is asked
                for one through ``UserPrompt.save_file`` (default extension ``.pkl``).
        """
        destination = (
            path
            if path is not None
            else UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=[("pickle", ".pkl")],
                defaultextension=".pkl",
            )
        )
        pickle_save_object(self, destination)

TrainConfig

class TrainConfig(
    dataset: 'DatasetConfig',
    model: 'nn.Module | str',
    loss_fn: 'str',
    optimizer: 'str',
    device: 'str' = 'cuda',
    log: 'bool' = False,
    num_epochs: 'int' = 100,
    checkpoints: 'str' = None,
    early_stopping: 'int' = None,
    print_every: 'int' = 5,
    learning_rate: 'float' = 0.001,
    weight_decay: 'float' = 1e-05,
    batch_size: 'int' = 256,
    shuffle: 'bool' = True,
    num_workers: 'int' = 0,
    train_test_split: 'float' = 0.8,
    *,
    seed: 'int' = 42
)

TrainConfig(dataset: 'DatasetConfig', model: 'nn.Module | str', loss_fn: 'str', optimizer: 'str', device: 'str' = 'cuda', log: 'bool' = False, num_epochs: 'int' = 100, checkpoints: 'str' = None, early_stopping: 'int' = None, print_every: 'int' = 5, learning_rate: 'float' = 0.001, weight_decay: 'float' = 1e-05, batch_size: 'int' = 256, shuffle: 'bool' = True, num_workers: 'int' = 0, train_test_split: 'float' = 0.8, *, seed: 'int' = 42)

View Source
@dataclass
class TrainConfig(ConfigBase):
    """
    Configuration of a training run.

    It covers the data source, the model with its loss/optimizer choice, the
    training-loop settings, the optimizer hyper-parameters and the dataloader settings.
    """

    # general parameters

    # kw_only lets this defaulted field precede the required fields below; it becomes a
    # keyword-only constructor argument placed after all the positional ones.
    seed: int = field(default=42, kw_only=True)  # Random seed for reproducibility
    dataset: DatasetConfig | Dataset  # The dataset to use for training: a DatasetConfig, or a Dataset object (a Dataset is used as is)

    # trainer parameters
    model: nn.Module | str  # The model to train, can also be a pretrained model (if str, it will be loaded from disk)
    loss_fn: str  # The loss function to use, one of the keys in the LOSSES dict
    optimizer: str  # The optimizer to use, one of the keys in the OPTIMIZERS dict
    device: str = "cuda"  # 'cuda' for training on GPU or 'cpu' otherwise
    log: bool = False  # Whether to log and save the training process with tensorboard

    # training parameters
    num_epochs: int = 100  # Number of times to iterate over the dataset
    checkpoints: str | None = None  # Path to save model checkpoints, including the checkpoint name (None presumably disables checkpointing - confirm in the trainer)
    early_stopping: int | None = None  # Number of epochs to wait before stopping training if no improvement was made (None presumably disables it - confirm in the trainer)
    print_every: int = 5  # How often (#epochs) to print training progress

    # optimizer parameters
    learning_rate: float = 0.001  # Learning rate for the optimizer
    # Weight decay for the optimizer (regularization, values typically in range [0.0, 1e-4] but can be bigger)
    weight_decay: float = 1e-5

    # dataloader parameters
    batch_size: int = 256  # Number of samples in each batch
    shuffle: bool = True  # Whether to shuffle the dataset at the beginning of each epoch
    num_workers: int = 0  # Number of subprocesses to use for data loading
    train_test_split: float = 0.8  # Fraction of the dataset to use for training, the rest will be used for testing

    # These are not constructor arguments (init=False) and have no default, so they are unset
    # after construction and must be assigned before use (presumably by the training code).
    dl_train: DataLoader = field(init=False)  # Training DataLoader
    dl_test: DataLoader = field(init=False)  # Test DataLoader

Ancestors (in MRO)

  • wtracker.utils.config_base.ConfigBase

Class variables

batch_size
checkpoints
device
early_stopping
learning_rate
log
num_epochs
num_workers
print_every
seed
shuffle
train_test_split
weight_decay

Static methods

load_json

def load_json(
    path: 'str' = None
) -> 'T'

Load the class from a JSON file.

Parameters:

Name Type Description Default
path str The path to the JSON file. None

Returns:

Type Description
ConfigBase The class loaded from the JSON file.
View Source
    @classmethod
    def load_json(cls: type[T], path: str | None = None) -> T:
        """
        Load the class from a JSON file.

        The instance is created with ``cls.__new__``. Its ``__dict__`` is then filled
        directly from the JSON object, so ``__init__`` and ``__post_init__`` are not run.

        Args:
            path (str | None): The path to the JSON file. If None, the user is asked
                to choose a file through ``UserPrompt.open_file``.

        Returns:
            ConfigBase: The class loaded from the JSON file.
        """
        if path is None:
            path = UserPrompt.open_file(
                title=f"Open {cls.__name__} File",
                file_types=[("json", ".json")],
            )

        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Bypass __init__ so the stored attribute values are restored verbatim.
        obj = cls.__new__(cls)
        obj.__dict__.update(data)
        return obj

load_pickle

def load_pickle(
    path: 'str' = None
) -> 'T'

Load the class from a pickle file.

Parameters:

Name Type Description Default
path str The path to the pickle file. None

Returns:

Type Description
ConfigBase The class loaded from the pickle file.
View Source
    @classmethod
    def load_pickle(cls: type[T], path: str = None) -> T:
        """
        Read an object back from a pickle file.

        If ``pickle_load_object`` uses :mod:`pickle` (as its name suggests), loading can
        execute arbitrary code, so only open files from trusted sources.

        Args:
            path (str): Location of the pickle file. When omitted, the user is asked to
                pick one through ``UserPrompt.open_file``.

        Returns:
            The object returned by ``pickle_load_object`` (expected to be an instance of ``cls``).
        """
        if path is not None:
            return pickle_load_object(path)

        chosen_path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("pickle", ".pkl")],
        )
        return pickle_load_object(chosen_path)

Methods

save_json

def save_json(
    self,
    path: 'str' = None
)

Saves the class as a JSON file.

Parameters:

Name Type Description Default
path str The path to the output JSON file. None
View Source
    def save_json(self, path: str | None = None):
        """
        Saves the class as a JSON file.

        The instance's attributes (``self.__dict__``) are written with an indent of 4.
        Every attribute value must be JSON-serializable; otherwise ``json.dumps`` raises
        ``TypeError``.

        Args:
            path (str | None): The path to the output JSON file. If None, the user is
                asked for a destination through ``UserPrompt.save_file``.
        """
        if path is None:
            path = UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=[("json", ".json")],
                defaultextension=".json",
            )

        # Serialize before opening the file for writing. If an attribute cannot be
        # serialized, the TypeError is raised here and any existing file at `path` stays
        # intact, instead of being truncated and left half-written.
        text = json.dumps(self.__dict__, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

save_pickle

def save_pickle(
    self,
    path: 'str' = None
) -> 'None'

Saves the class as a pickle file.

Parameters:

Name Type Description Default
path str The path to the output pickle file. None
View Source
    def save_pickle(self, path: str = None) -> None:
        """
        Serialize this object to a pickle file.

        Args:
            path (str): Destination of the pickle file. When omitted, the user is asked
                for one through ``UserPrompt.save_file`` (default extension ``.pkl``).
        """
        destination = (
            path
            if path is not None
            else UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=[("pickle", ".pkl")],
                defaultextension=".pkl",
            )
        )
        pickle_save_object(self, destination)