Module wtracker.neural.config
View Source
from __future__ import annotations
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import Dataset, DataLoader, random_split
from dataclasses import dataclass, field
from wtracker.utils.config_base import ConfigBase
@dataclass
class DatasetConfig(ConfigBase):
    input_frames: list[int]  # Frames to use as input to the network, given as offsets relative to the prediction frame (0): negative offsets are frames before it, positive offsets are frames after it.
    pred_frames: list[int]  # Frames to predict, given as offsets relative to the prediction frame (0): negative offsets are frames before it, positive offsets are frames after it.
    log_path: list[str]  # Path(s) to the log file(s) containing the worm head predictions (made by YOLO).

    def __post_init__(self) -> None:
        if self.input_frames[0] != 0:
            print(
                "WARNING::DatasetConfig::input_frames should contain 0 as its first element. Please verify your parameters."
            )

    @staticmethod
    def from_io_config(io: IOConfig, log_path: str) -> DatasetConfig:
        return DatasetConfig(io.input_frames, io.pred_frames, log_path)
OPTIMIZERS = {
    "adam": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "rmsprop": torch.optim.RMSprop,
    "adamw": torch.optim.AdamW,
}

LOSSES = {
    "mse": nn.MSELoss,
    "l1": nn.L1Loss,
}
@dataclass
class TrainConfig(ConfigBase):
    # general parameters
    seed: int = field(default=42, kw_only=True)  # Random seed for reproducibility
    dataset: DatasetConfig  # The dataset to use for training; may be a DatasetConfig (used to construct the dataset) or a Dataset instance (used as-is)
    # trainer parameters
    model: nn.Module | str  # The model to train; if a str is given, it is treated as a path and a pretrained model is loaded from disk
    loss_fn: str  # The loss function to use; can be any of the keys in the LOSSES dict
    optimizer: str  # The optimizer to use; can be any of the keys in the OPTIMIZERS dict
    device: str = "cuda"  # 'cuda' for training on GPU, 'cpu' otherwise
    log: bool = False  # Whether to log and save the training process with TensorBoard
    # training parameters
    num_epochs: int = 100  # Number of times to iterate over the dataset
    checkpoints: str = None  # Path to save model checkpoints, including the checkpoint name
    early_stopping: int = None  # Number of epochs to wait before stopping training if no improvement was made
    print_every: int = 5  # How often (in epochs) to print training progress
    # optimizer parameters
    learning_rate: float = 0.001  # Learning rate for the optimizer
    weight_decay: float = 1e-5  # Weight decay (regularization) for the optimizer; values are typically in [0.0, 1e-4] but can be larger
    # dataloader parameters
    batch_size: int = 256  # Number of samples in each batch
    shuffle: bool = True  # Whether to shuffle the dataset at the beginning of each epoch
    num_workers: int = 0  # Number of subprocesses to use for data loading
    train_test_split: float = 0.8  # Fraction of the dataset to use for training; the rest is used for testing
    dl_train: DataLoader = field(init=False)
    dl_test: DataLoader = field(init=False)
@dataclass
class IOConfig(ConfigBase):
    """
    Configuration for the basic input/output of the network.

    The input_frames and pred_frames are lists of integers that represent the frames
    used as input and output of the network. The frames are given as offsets relative
    to the prediction frame (0): negative offsets are frames before it, positive
    offsets are frames after it.

    To calculate in_dim and out_dim, we assume that each input frame has 4 features
    (x, y, w, h), representing the worm bbox in that frame, and each prediction frame
    has 2 features (x, y), representing the worm center in that frame.
    """

    input_frames: list[int]
    pred_frames: list[int]
    in_dim: int = field(init=False)
    out_dim: int = field(init=False)

    def __post_init__(self):
        if 0 not in self.input_frames:
            print(
                "WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters."
            )
        self.in_dim = len(self.input_frames) * 4
        self.out_dim = len(self.pred_frames) * 2

    @staticmethod
    def from_datasetConfig(config: DatasetConfig) -> IOConfig:
        return IOConfig(config.input_frames, config.pred_frames)
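The OPTIMIZERS and LOSSES dicts map the string keys accepted by TrainConfig to the corresponding torch classes. A minimal sketch of how a trainer might resolve these keys; the model below is a placeholder for illustration, not part of this module:

import torch
from torch import nn

from wtracker.neural.config import LOSSES, OPTIMIZERS

model = nn.Linear(12, 4)  # placeholder model for illustration

loss_fn = LOSSES["mse"]()         # instantiates nn.MSELoss
optimizer = OPTIMIZERS["adamw"](  # instantiates torch.optim.AdamW
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)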
Variables
LOSSES
OPTIMIZERS
Classes
DatasetConfig
class DatasetConfig(
    input_frames: 'list[int]',
    pred_frames: 'list[int]',
    log_path: 'list[str]'
)
DatasetConfig(input_frames: 'list[int]', pred_frames: 'list[int]', log_path: 'list[str]')
View Source
@dataclass
class DatasetConfig(ConfigBase):
    input_frames: list[int]  # Frames to use as input to the network, given as offsets relative to the prediction frame (0): negative offsets are frames before it, positive offsets are frames after it.
    pred_frames: list[int]  # Frames to predict, given as offsets relative to the prediction frame (0): negative offsets are frames before it, positive offsets are frames after it.
    log_path: list[str]  # Path(s) to the log file(s) containing the worm head predictions (made by YOLO).

    def __post_init__(self) -> None:
        if self.input_frames[0] != 0:
            print(
                "WARNING::DatasetConfig::input_frames should contain 0 as its first element. Please verify your parameters."
            )

    @staticmethod
    def from_io_config(io: IOConfig, log_path: str) -> DatasetConfig:
        return DatasetConfig(io.input_frames, io.pred_frames, log_path)
Ancestors (in MRO)
- wtracker.utils.config_base.ConfigBase
Static methods
from_io_config
def from_io_config(
    io: 'IOConfig',
    log_path: 'str'
) -> 'DatasetConfig'
View Source
@staticmethod
def from_io_config(io: IOConfig, log_path: str) -> DatasetConfig:
    return DatasetConfig(io.input_frames, io.pred_frames, log_path)
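For example, deriving a DatasetConfig from an existing IOConfig keeps the frame offsets consistent between the dataset and the network; the log path below is hypothetical:

from wtracker.neural.config import DatasetConfig, IOConfig

io = IOConfig(input_frames=[0, -5, -10], pred_frames=[5, 10])
ds_config = DatasetConfig.from_io_config(io, log_path="logs/worm_head_preds.csv")  # hypothetical path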
load_json
def load_json(
    path: 'str' = None
) -> 'T'
Load the class from a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the JSON file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the JSON file. |
View Source
@classmethod
def load_json(cls: type[T], path: str = None) -> T:
    """
    Load the class from a JSON file.

    Args:
        path (str): The path to the JSON file.

    Returns:
        ConfigBase: The class loaded from the JSON file.
    """
    if path is None:
        path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("json", ".json")],
        )

    with open(path, "r") as f:
        data = json.load(f)

    obj = cls.__new__(cls)
    obj.__dict__.update(data)
    return obj
load_pickle
def load_pickle(
    path: 'str' = None
) -> 'T'
Load the class from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the pickle file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the pickle file. |
View Source
@classmethod
def load_pickle(cls: type[T], path: str = None) -> T:
    """
    Load the class from a pickle file.

    Args:
        path (str): The path to the pickle file.

    Returns:
        The class loaded from the pickle file.
    """
    if path is None:
        path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("pickle", ".pkl")],
        )

    return pickle_load_object(path)
Methods
save_json
def save_json(
    self,
    path: 'str' = None
)
Saves the class as a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output JSON file. | None |
View Source
def save_json(self, path: str = None):
    """
    Saves the class as a JSON file.

    Args:
        path (str): The path to the output JSON file.
    """
    if path is None:
        path = UserPrompt.save_file(
            title=f"Save {type(self).__name__} As",
            file_types=[("json", ".json")],
            defaultextension=".json",
        )

    with open(path, "w") as f:
        json.dump(self.__dict__, f, indent=4)
save_pickle
def save_pickle(
    self,
    path: 'str' = None
) -> 'None'
Saves the class as a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output pickle file. | None |
View Source
def save_pickle(self, path: str = None) -> None:
    """
    Saves the class as a pickle file.

    Args:
        path (str): The path to the output pickle file.
    """
    if path is None:
        path = UserPrompt.save_file(
            title=f"Save {type(self).__name__} As",
            file_types=[("pickle", ".pkl")],
            defaultextension=".pkl",
        )

    pickle_save_object(self, path)
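All config classes inherit these persistence helpers from ConfigBase, so any of them can be round-tripped through JSON or pickle. A minimal sketch, using hypothetical paths; note that load_json restores the fields directly into a new instance via __new__, bypassing __init__ and __post_init__:

from wtracker.neural.config import DatasetConfig

config = DatasetConfig(
    input_frames=[0, -5, -10],
    pred_frames=[5, 10],
    log_path=["logs/worm_head_preds.csv"],  # hypothetical path
)

config.save_json("dataset_config.json")  # hypothetical output path
restored = DatasetConfig.load_json("dataset_config.json")

config.save_pickle("dataset_config.pkl")  # hypothetical output path
restored = DatasetConfig.load_pickle("dataset_config.pkl")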
IOConfig
class IOConfig(
    input_frames: 'list[int]',
    pred_frames: 'list[int]'
)
Configuration for the basic input/output of the network.
The input_frames and pred_frames are lists of integers that represent the frames used as input and output of the network. The frames are given as offsets relative to the prediction frame (0): negative offsets are frames before it, positive offsets are frames after it. To calculate in_dim and out_dim, we assume that each input frame has 4 features (x, y, w, h), representing the worm bbox in that frame, and each prediction frame has 2 features (x, y), representing the worm center in that frame.
View Source
@dataclass
class IOConfig(ConfigBase):
    """
    Configuration for the basic input/output of the network.

    The input_frames and pred_frames are lists of integers that represent the frames
    used as input and output of the network. The frames are given as offsets relative
    to the prediction frame (0): negative offsets are frames before it, positive
    offsets are frames after it.

    To calculate in_dim and out_dim, we assume that each input frame has 4 features
    (x, y, w, h), representing the worm bbox in that frame, and each prediction frame
    has 2 features (x, y), representing the worm center in that frame.
    """

    input_frames: list[int]
    pred_frames: list[int]
    in_dim: int = field(init=False)
    out_dim: int = field(init=False)

    def __post_init__(self):
        if 0 not in self.input_frames:
            print(
                "WARNING::IOConfig::__post_init__::input_frames doesn't contain 0 (the prediction frame). Please verify your parameters."
            )
        self.in_dim = len(self.input_frames) * 4
        self.out_dim = len(self.pred_frames) * 2

    @staticmethod
    def from_datasetConfig(config: DatasetConfig) -> IOConfig:
        return IOConfig(config.input_frames, config.pred_frames)
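For example, with three input frames and two prediction frames, the derived dimensions are 3 * 4 = 12 and 2 * 2 = 4:

from wtracker.neural.config import IOConfig

io = IOConfig(input_frames=[-10, -5, 0], pred_frames=[5, 10])
print(io.in_dim)   # 12 -- 3 input frames * 4 bbox features (x, y, w, h)
print(io.out_dim)  # 4  -- 2 prediction frames * 2 center coordinates (x, y)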
Ancestors (in MRO)
- wtracker.utils.config_base.ConfigBase
Static methods
from_datasetConfig
def from_datasetConfig(
    config: 'DatasetConfig'
) -> 'IOConfig'
View Source
@staticmethod
def from_datasetConfig(config: DatasetConfig) -> IOConfig:
    return IOConfig(config.input_frames, config.pred_frames)
load_json
def load_json(
    path: 'str' = None
) -> 'T'
Load the class from a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the JSON file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the JSON file. |
View Source
@classmethod
def load_json(cls: type[T], path: str = None) -> T:
    """
    Load the class from a JSON file.

    Args:
        path (str): The path to the JSON file.

    Returns:
        ConfigBase: The class loaded from the JSON file.
    """
    if path is None:
        path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("json", ".json")],
        )

    with open(path, "r") as f:
        data = json.load(f)

    obj = cls.__new__(cls)
    obj.__dict__.update(data)
    return obj
load_pickle
def load_pickle(
    path: 'str' = None
) -> 'T'
Load the class from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the pickle file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the pickle file. |
View Source
@classmethod
def load_pickle(cls: type[T], path: str = None) -> T:
    """
    Load the class from a pickle file.

    Args:
        path (str): The path to the pickle file.

    Returns:
        The class loaded from the pickle file.
    """
    if path is None:
        path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("pickle", ".pkl")],
        )

    return pickle_load_object(path)
Methods
save_json
def save_json(
    self,
    path: 'str' = None
)
Saves the class as a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output JSON file. | None |
View Source
def save_json(self, path: str = None):
    """
    Saves the class as a JSON file.

    Args:
        path (str): The path to the output JSON file.
    """
    if path is None:
        path = UserPrompt.save_file(
            title=f"Save {type(self).__name__} As",
            file_types=[("json", ".json")],
            defaultextension=".json",
        )

    with open(path, "w") as f:
        json.dump(self.__dict__, f, indent=4)
save_pickle
def save_pickle(
    self,
    path: 'str' = None
) -> 'None'
Saves the class as a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output pickle file. | None |
View Source
def save_pickle(self, path: str = None) -> None:
    """
    Saves the class as a pickle file.

    Args:
        path (str): The path to the output pickle file.
    """
    if path is None:
        path = UserPrompt.save_file(
            title=f"Save {type(self).__name__} As",
            file_types=[("pickle", ".pkl")],
            defaultextension=".pkl",
        )

    pickle_save_object(self, path)
TrainConfig
class TrainConfig(
    dataset: 'DatasetConfig',
    model: 'nn.Module | str',
    loss_fn: 'str',
    optimizer: 'str',
    device: 'str' = 'cuda',
    log: 'bool' = False,
    num_epochs: 'int' = 100,
    checkpoints: 'str' = None,
    early_stopping: 'int' = None,
    print_every: 'int' = 5,
    learning_rate: 'float' = 0.001,
    weight_decay: 'float' = 1e-05,
    batch_size: 'int' = 256,
    shuffle: 'bool' = True,
    num_workers: 'int' = 0,
    train_test_split: 'float' = 0.8,
    *,
    seed: 'int' = 42
)
TrainConfig(dataset: 'DatasetConfig', model: 'nn.Module | str', loss_fn: 'str', optimizer: 'str', device: 'str' = 'cuda', log: 'bool' = False, num_epochs: 'int' = 100, checkpoints: 'str' = None, early_stopping: 'int' = None, print_every: 'int' = 5, learning_rate: 'float' = 0.001, weight_decay: 'float' = 1e-05, batch_size: 'int' = 256, shuffle: 'bool' = True, num_workers: 'int' = 0, train_test_split: 'float' = 0.8, *, seed: 'int' = 42)
View Source
@dataclass
class TrainConfig(ConfigBase):
    # general parameters
    seed: int = field(default=42, kw_only=True)  # Random seed for reproducibility
    dataset: DatasetConfig  # The dataset to use for training; may be a DatasetConfig (used to construct the dataset) or a Dataset instance (used as-is)
    # trainer parameters
    model: nn.Module | str  # The model to train; if a str is given, it is treated as a path and a pretrained model is loaded from disk
    loss_fn: str  # The loss function to use; can be any of the keys in the LOSSES dict
    optimizer: str  # The optimizer to use; can be any of the keys in the OPTIMIZERS dict
    device: str = "cuda"  # 'cuda' for training on GPU, 'cpu' otherwise
    log: bool = False  # Whether to log and save the training process with TensorBoard
    # training parameters
    num_epochs: int = 100  # Number of times to iterate over the dataset
    checkpoints: str = None  # Path to save model checkpoints, including the checkpoint name
    early_stopping: int = None  # Number of epochs to wait before stopping training if no improvement was made
    print_every: int = 5  # How often (in epochs) to print training progress
    # optimizer parameters
    learning_rate: float = 0.001  # Learning rate for the optimizer
    weight_decay: float = 1e-5  # Weight decay (regularization) for the optimizer; values are typically in [0.0, 1e-4] but can be larger
    # dataloader parameters
    batch_size: int = 256  # Number of samples in each batch
    shuffle: bool = True  # Whether to shuffle the dataset at the beginning of each epoch
    num_workers: int = 0  # Number of subprocesses to use for data loading
    train_test_split: float = 0.8  # Fraction of the dataset to use for training; the rest is used for testing
    dl_train: DataLoader = field(init=False)
    dl_test: DataLoader = field(init=False)
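A minimal sketch of constructing a TrainConfig; the model and log path below are placeholders, and dl_train/dl_test are init=False fields that are expected to be populated later (e.g. by the trainer):

import torch
from torch import nn

from wtracker.neural.config import DatasetConfig, IOConfig, TrainConfig

io = IOConfig(input_frames=[0, -5, -10], pred_frames=[5, 10])
dataset = DatasetConfig.from_io_config(io, log_path="logs/worm_head_preds.csv")  # hypothetical path

train_config = TrainConfig(
    dataset=dataset,
    model=nn.Linear(io.in_dim, io.out_dim),  # placeholder model for illustration
    loss_fn="mse",      # key into LOSSES
    optimizer="adamw",  # key into OPTIMIZERS
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=42,            # keyword-only field
)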
Ancestors (in MRO)
- wtracker.utils.config_base.ConfigBase
Class variables
batch_size
checkpoints
device
early_stopping
learning_rate
log
num_epochs
num_workers
print_every
seed
shuffle
train_test_split
weight_decay
Static methods
load_json
def load_json(
    path: 'str' = None
) -> 'T'
Load the class from a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the JSON file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the JSON file. |
View Source
@classmethod
def load_json(cls: type[T], path: str = None) -> T:
    """
    Load the class from a JSON file.

    Args:
        path (str): The path to the JSON file.

    Returns:
        ConfigBase: The class loaded from the JSON file.
    """
    if path is None:
        path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("json", ".json")],
        )

    with open(path, "r") as f:
        data = json.load(f)

    obj = cls.__new__(cls)
    obj.__dict__.update(data)
    return obj
load_pickle
def load_pickle(
    path: 'str' = None
) -> 'T'
Load the class from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the pickle file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the pickle file. |
View Source
@classmethod
def load_pickle(cls: type[T], path: str = None) -> T:
    """
    Load the class from a pickle file.

    Args:
        path (str): The path to the pickle file.

    Returns:
        The class loaded from the pickle file.
    """
    if path is None:
        path = UserPrompt.open_file(
            title=f"Open {cls.__name__} File",
            file_types=[("pickle", ".pkl")],
        )

    return pickle_load_object(path)
Methods
save_json
def save_json(
    self,
    path: 'str' = None
)
Saves the class as a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output JSON file. | None |
View Source
def save_json(self, path: str = None):
    """
    Saves the class as a JSON file.

    Args:
        path (str): The path to the output JSON file.
    """
    if path is None:
        path = UserPrompt.save_file(
            title=f"Save {type(self).__name__} As",
            file_types=[("json", ".json")],
            defaultextension=".json",
        )

    with open(path, "w") as f:
        json.dump(self.__dict__, f, indent=4)
save_pickle
def save_pickle(
    self,
    path: 'str' = None
) -> 'None'
Saves the class as a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output pickle file. | None |
View Source
def save_pickle(self, path: str = None) -> None:
    """
    Saves the class as a pickle file.

    Args:
        path (str): The path to the output pickle file.
    """
    if path is None:
        path = UserPrompt.save_file(
            title=f"Save {type(self).__name__} As",
            file_types=[("pickle", ".pkl")],
            defaultextension=".pkl",
        )

    pickle_save_object(self, path)