Module wtracker.neural.training

View Source
import os

import abc

import sys

import torch

import torch.nn as nn

import torch.nn.functional

import tqdm.auto

from torch import Tensor

from typing import Any, Tuple, Callable, Optional

from torch.optim import Optimizer

from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter

from wtracker.neural.train_results import FitResult, BatchResult, EpochResult

class Trainer(abc.ABC):

    """

    A class abstracting the various tasks of training models.

    Provides methods at multiple levels of granularity:

    - Multiple epochs (fit)

    - Single epoch (train_epoch/test_epoch)

    - Single batch (train_batch/test_batch)

    Args:

        model (nn.Module): The model to train.

        device (Optional[torch.device], optional): The device to run training on (CPU or GPU).

        log (bool, optional): Whether to log training progress with tensorboard.

    """

    def __init__(

        self,

        model: nn.Module,

        device: Optional[torch.device] = None,

        log: bool = False,

    ):

        self.model = model

        self.device = device

        self.logger = None if not log else SummaryWriter()

        if self.logger is not None:

            self.logger.add_hparams({"model": model.__class__.__name__}, {}, run_name="hparams")

            self.logger.add_hparams({"device": str(device)}, {}, run_name="hparams")

        if self.device:

            model.to(self.device)

    def _make_batch_result(self, loss, num_correct) -> BatchResult:

        loss = loss.item() if isinstance(loss, Tensor) else loss

        num_correct = num_correct.item() if isinstance(num_correct, Tensor) else num_correct

        return BatchResult(float(loss), int(num_correct))

    def _make_fit_result(self, num_epochs, train_losses, train_acc, test_losses, test_acc) -> FitResult:

        num_epochs = num_epochs.item() if isinstance(num_epochs, Tensor) else num_epochs

        train_losses = [x.item() if isinstance(x, Tensor) else x for x in train_losses]

        train_acc = [x.item() if isinstance(x, Tensor) else x for x in train_acc]

        test_losses = [x.item() if isinstance(x, Tensor) else x for x in test_losses]

        test_acc = [x.item() if isinstance(x, Tensor) else x for x in test_acc]

        return FitResult(int(num_epochs), train_losses, train_acc, test_losses, test_acc)

    def fit(

        self,

        dl_train: DataLoader,

        dl_test: DataLoader,

        num_epochs: int,

        checkpoints: str = None,

        early_stopping: int = None,

        print_every: int = 1,

        **kw,

    ) -> FitResult:

        """

        Trains the model for multiple epochs with a given training set,

        and calculates validation loss over a given validation set.

        Args:

            dl_train (DataLoader): Dataloader for the training set.

            dl_test (DataLoader): Dataloader for the test set.

            num_epochs (int): Number of epochs to train for.

            checkpoints (str, optional): If provided, the model is saved to this file (a filename without extension) every time the test loss improves.

            early_stopping (int, optional): If provided, stop training early when there is no test loss improvement for this number of epochs.

            print_every (int, optional): Print progress every this many epochs.

        Returns:

            FitResult: A FitResult object containing train and test losses per epoch.

        """

        actual_epoch_num = 0

        epochs_without_improvement = 0

        train_loss, train_acc, test_loss, test_acc = [], [], [], []

        best_val_loss = None

        # add graph to tensorboard

        if self.logger is not None:

            self.logger.add_graph(self.model, next(iter(dl_train))[0])

        for epoch in range(num_epochs):

            actual_epoch_num += 1

            verbose = False  # pass this to train/test_epoch.

            if print_every > 0 and (epoch % print_every == 0 or epoch == num_epochs - 1):

                verbose = True

            self._print(f"--- EPOCH {epoch+1}/{num_epochs} ---", verbose)

            train_result = self.train_epoch(dl_train, verbose=verbose, **kw)

            test_result = self.test_epoch(dl_test, verbose=verbose, **kw)

            train_loss.extend(train_result.losses)

            train_acc.append(train_result.accuracy)

            test_loss.extend(test_result.losses)

            test_acc.append(test_result.accuracy)

            # log results to tensorboard

            if self.logger is not None:

                self.logger.add_scalar("loss/train", Tensor(train_result.losses).mean(), epoch)

                self.logger.add_scalar("loss/test", Tensor(test_result.losses).mean(), epoch)

                self.logger.add_scalar("accuracy/train", train_result.accuracy, epoch)

                self.logger.add_scalar("accuracy/test", test_result.accuracy, epoch)

                self.logger.add_scalar("learning_rate", self.optimizer.param_groups[0]["lr"], epoch)

            curr_val_loss = Tensor(test_result.losses).mean().item()

            if best_val_loss is None or curr_val_loss < best_val_loss:

                best_val_loss = curr_val_loss

                epochs_without_improvement = 0

                if checkpoints is not None:

                    self.save_checkpoint(checkpoints, curr_val_loss)

            else:

                epochs_without_improvement += 1

                if early_stopping is not None and epochs_without_improvement >= early_stopping:

                    break

        return self._make_fit_result(actual_epoch_num, train_loss, train_acc, test_loss, test_acc)

    def save_checkpoint(self, checkpoint_filename: str, loss: Optional[float] = None) -> None:

        """

        Saves the model in its current state to a file with the given name (treated

        as a relative path).

        Args:

            checkpoint_filename (str): File name or relative path to save to.

        """

        if self.logger is not None:

            checkpoint_filename = f"{self.logger.log_dir}/{checkpoint_filename}"

        torch.save(self.model, checkpoint_filename)

        print(f"\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}")

    def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult:

        """

        Train once over a training set (single epoch).

        Args:

            dl_train (DataLoader): DataLoader for the training set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(True)  # set train mode

        return self._foreach_batch(dl_train, self.train_batch, **kw)

    def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult:

        """

        Evaluate model once over a test set (single epoch).

        Args:

            dl_test (DataLoader): DataLoader for the test set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(False)  # set evaluation (test) mode

        return self._foreach_batch(dl_test, self.test_batch, **kw)

    @abc.abstractmethod

    def train_batch(self, batch) -> BatchResult:

        """

        Runs a single batch forward through the model, calculates loss,

        performs back-propagation and updates weights.

        Args:

            batch: A single batch of data from a data loader (might

                be a tuple of data and labels or anything else depending on

                the underlying dataset).

        Returns:

            BatchResult: A BatchResult containing the value of the loss function and

                the number of correctly classified samples in the batch.

        """

        raise NotImplementedError()

    @abc.abstractmethod

    def test_batch(self, batch) -> BatchResult:

        """

        Runs a single batch forward through the model and calculates loss.

        Args:

            batch: A single batch of data from a data loader (might

                be a tuple of data and labels or anything else depending on

                the underlying dataset).

        Returns:

            BatchResult: A BatchResult containing the value of the loss function and

                the number of correctly classified samples in the batch.

        """

        raise NotImplementedError()

    @staticmethod

    def _print(message, verbose=True):

        """Simple wrapper around print to make it conditional"""

        if verbose:

            print(message)

    @staticmethod

    def _foreach_batch(

        dl: DataLoader,

        forward_fn: Callable[[Any], BatchResult],

        verbose=True,

        max_batches=None,

    ) -> EpochResult:

        """

        Evaluates the given forward-function on batches from the given

        dataloader, and prints progress along the way.

        """

        losses = []

        num_correct = 0

        num_samples = len(dl.sampler)

        num_batches = len(dl.batch_sampler)

        if max_batches is not None:

            if max_batches < num_batches:

                num_batches = max_batches

                num_samples = num_batches * dl.batch_size

        if verbose:

            pbar_fn = tqdm.auto.tqdm

            pbar_file = sys.stdout

        else:

            pbar_fn = tqdm.tqdm

            pbar_file = open(os.devnull, "w")

        pbar_name = forward_fn.__name__

        with pbar_fn(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:

            dl_iter = iter(dl)

            for batch_idx in range(num_batches):

                data = next(dl_iter)

                batch_res = forward_fn(data)

                pbar.set_description(f"{pbar_name} ({batch_res.loss:.3f})")

                pbar.update()

                losses.append(batch_res.loss)

                num_correct += batch_res.num_correct

            avg_loss = sum(losses) / num_batches

            accuracy = 100.0 * num_correct / num_samples

            pbar.set_description(f"{pbar_name} " f"(Avg. Loss {avg_loss:.3f}, " f"Accuracy {accuracy:.2f}%)")

        if not verbose:

            pbar_file.close()

        return EpochResult(losses=losses, accuracy=accuracy)

    def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = "hparams"):

        if self.logger is not None:

            self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)

class MLPTrainer(Trainer):

    """

    The `MLPTrainer` class is responsible for training and testing a multi-layer perceptron (MLP) model.

    Args:

        model (nn.Module): The MLP model to be trained.

        loss_fn (nn.Module): The loss function used for training.

        optimizer (Optimizer): The optimizer used for updating the model's parameters.

        device (Optional[torch.device], optional): The device on which the model and data should be loaded.

        log (bool, optional): Whether to log training progress with tensorboard.

    Attributes:

        loss_fn (nn.Module): The loss function used for training.

        optimizer (Optimizer): The optimizer used for updating the model's parameters.

    """

    def __init__(

        self,

        model: nn.Module,

        loss_fn: nn.Module,

        optimizer: Optimizer,

        device: Optional[torch.device] = None,

        log: bool = False,

    ):

        super().__init__(model, device, log=log)

        self.loss_fn = loss_fn

        self.optimizer = optimizer

        if self.logger is not None:

            self.logger.add_hparams({"loss_fn": loss_fn.__class__.__name__}, {}, run_name="hparams")

            self.logger.add_hparams({"optimizer": optimizer.__class__.__name__}, {}, run_name="hparams")

            optimizer_params = {}

            for key, val in optimizer.param_groups[0].items():

                optimizer_params[key] = str(val)

            optimizer_params.update({"params": ""})

            self.logger.add_hparams(optimizer_params, {}, run_name="hparams")

    def train_batch(self, batch) -> BatchResult:

        X, y = batch

        if self.device:

            X = X.to(self.device)

            y = y.to(self.device)

        self.model: nn.Module

        self.optimizer.zero_grad()

        preds = self.model.forward(X)

        loss = self.loss_fn(preds, y)

        loss.backward()

        self.optimizer.step()

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)

        return self._make_batch_result(loss, num_correct)

    @torch.no_grad()

    def test_batch(self, batch) -> BatchResult:

        X, y = batch

        if self.device:

            X = X.to(self.device)

            y = y.to(self.device)

        preds = self.model.forward(X)

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)

        loss = self.loss_fn(preds, y)

        return self._make_batch_result(loss, num_correct)

Classes

MLPTrainer

class MLPTrainer(
    model: torch.nn.modules.module.Module,
    loss_fn: torch.nn.modules.module.Module,
    optimizer: torch.optim.optimizer.Optimizer,
    device: Optional[torch.device] = None,
    log: bool = False
)

The MLPTrainer class is responsible for training and testing a multi-layer perceptron (MLP) model.

Attributes

Name Type Description Default
model nn.Module The MLP model to be trained. required
loss_fn nn.Module The loss function used for training. required
optimizer Optimizer The optimizer used for updating the model's parameters. required
device Optional[torch.device] The device on which the model and data should be loaded. None
log bool Whether to log training progress with tensorboard. False
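A minimal usage sketch follows. The toy data, model architecture, and hyperparameter values are illustrative only and are not part of this module:

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from wtracker.neural.training import MLPTrainer

# toy regression data, chosen only for illustration
X, y = torch.randn(256, 10), torch.randn(256, 2)
dataset = TensorDataset(X, y)
dl_train = DataLoader(dataset, batch_size=32, shuffle=True)
dl_test = DataLoader(dataset, batch_size=32)

model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
trainer = MLPTrainer(
    model=model,
    loss_fn=nn.MSELoss(),
    optimizer=Adam(model.parameters(), lr=1e-3),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    log=False,  # set to True to write TensorBoard logs via SummaryWriter
)
result = trainer.fit(dl_train, dl_test, num_epochs=5)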
View Source
class MLPTrainer(Trainer):

    """

    The `MLPTrainer` class is responsible for training and testing a multi-layer perceptron (MLP) model.

    Args:

        model (nn.Module): The MLP model to be trained.

        loss_fn (nn.Module): The loss function used for training.

        optimizer (Optimizer): The optimizer used for updating the model's parameters.

        device (Optional[torch.device], optional): The device on which the model and data should be loaded.

        log (bool, optional): Whether to log training progress with tensorboard.

    Attributes:

        loss_fn (nn.Module): The loss function used for training.

        optimizer (Optimizer): The optimizer used for updating the model's parameters.

    """

    def __init__(

        self,

        model: nn.Module,

        loss_fn: nn.Module,

        optimizer: Optimizer,

        device: Optional[torch.device] = None,

        log: bool = False,

    ):

        super().__init__(model, device, log=log)

        self.loss_fn = loss_fn

        self.optimizer = optimizer

        if self.logger is not None:

            self.logger.add_hparams({"loss_fn": loss_fn.__class__.__name__}, {}, run_name="hparams")

            self.logger.add_hparams({"optimizer": optimizer.__class__.__name__}, {}, run_name="hparams")

            optimizer_params = {}

            for key, val in optimizer.param_groups[0].items():

                optimizer_params[key] = str(val)

            optimizer_params.update({"params": ""})

            self.logger.add_hparams(optimizer_params, {}, run_name="hparams")

    def train_batch(self, batch) -> BatchResult:

        X, y = batch

        if self.device:

            X = X.to(self.device)

            y = y.to(self.device)

        self.model: nn.Module

        self.optimizer.zero_grad()

        preds = self.model.forward(X)

        loss = self.loss_fn(preds, y)

        loss.backward()

        self.optimizer.step()

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)

        return self._make_batch_result(loss, num_correct)

    @torch.no_grad()

    def test_batch(self, batch) -> BatchResult:

        X, y = batch

        if self.device:

            X = X.to(self.device)

            y = y.to(self.device)

        preds = self.model.forward(X)

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)

        loss = self.loss_fn(preds, y)

        return self._make_batch_result(loss, num_correct)

Ancestors (in MRO)

  • wtracker.neural.training.Trainer
  • abc.ABC

Methods

fit

def fit(
    self,
    dl_train: torch.utils.data.dataloader.DataLoader,
    dl_test: torch.utils.data.dataloader.DataLoader,
    num_epochs: int,
    checkpoints: str = None,
    early_stopping: int = None,
    print_every: int = 1,
    **kw
) -> wtracker.neural.train_results.FitResult

Trains the model for multiple epochs with a given training set,

and calculates validation loss over a given validation set.

Parameters:

Name Type Description Default
dl_train DataLoader Dataloader for the training set. required
dl_test DataLoader Dataloader for the test set. required
num_epochs int Number of epochs to train for. required
checkpoints str If provided, the model is saved to this file (a filename without extension) every time the test loss improves. None
early_stopping int If provided, stop training early when there is no test loss improvement for this number of epochs. None
print_every int Print progress every this many epochs. 1

Returns:

Type Description
FitResult A FitResult object containing train and test losses per epoch.
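A hedged sketch of a longer run with checkpointing and early stopping, assuming a trainer and dataloaders built as in the example above (the checkpoint name is illustrative):

# save the model whenever the test loss improves, stop after 10 epochs
# without improvement, and print progress every 5 epochs
result = trainer.fit(
    dl_train,
    dl_test,
    num_epochs=100,
    checkpoints="best_model",
    early_stopping=10,
    print_every=5,
)
# result is a FitResult holding per-epoch train/test losses and accuracies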
View Source
    def fit(

        self,

        dl_train: DataLoader,

        dl_test: DataLoader,

        num_epochs: int,

        checkpoints: str = None,

        early_stopping: int = None,

        print_every: int = 1,

        **kw,

    ) -> FitResult:

        """

        Trains the model for multiple epochs with a given training set,

        and calculates validation loss over a given validation set.

        Args:

            dl_train (DataLoader): Dataloader for the training set.

            dl_test (DataLoader): Dataloader for the test set.

            num_epochs (int): Number of epochs to train for.

            checkpoints (str, optional): If provided, the model is saved to this file (a filename without extension) every time the test loss improves.

            early_stopping (int, optional): If provided, stop training early when there is no test loss improvement for this number of epochs.

            print_every (int, optional): Print progress every this many epochs.

        Returns:

            FitResult: A FitResult object containing train and test losses per epoch.

        """

        actual_epoch_num = 0

        epochs_without_improvement = 0

        train_loss, train_acc, test_loss, test_acc = [], [], [], []

        best_val_loss = None

        # add graph to tensorboard

        if self.logger is not None:

            self.logger.add_graph(self.model, next(iter(dl_train))[0])

        for epoch in range(num_epochs):

            actual_epoch_num += 1

            verbose = False  # pass this to train/test_epoch.

            if print_every > 0 and (epoch % print_every == 0 or epoch == num_epochs - 1):

                verbose = True

            self._print(f"--- EPOCH {epoch+1}/{num_epochs} ---", verbose)

            train_result = self.train_epoch(dl_train, verbose=verbose, **kw)

            test_result = self.test_epoch(dl_test, verbose=verbose, **kw)

            train_loss.extend(train_result.losses)

            train_acc.append(train_result.accuracy)

            test_loss.extend(test_result.losses)

            test_acc.append(test_result.accuracy)

            # log results to tensorboard

            if self.logger is not None:

                self.logger.add_scalar("loss/train", Tensor(train_result.losses).mean(), epoch)

                self.logger.add_scalar("loss/test", Tensor(test_result.losses).mean(), epoch)

                self.logger.add_scalar("accuracy/train", train_result.accuracy, epoch)

                self.logger.add_scalar("accuracy/test", test_result.accuracy, epoch)

                self.logger.add_scalar("learning_rate", self.optimizer.param_groups[0]["lr"], epoch)

            curr_val_loss = Tensor(test_result.losses).mean().item()

            if best_val_loss is None or curr_val_loss < best_val_loss:

                best_val_loss = curr_val_loss

                epochs_without_improvement = 0

                if checkpoints is not None:

                    self.save_checkpoint(checkpoints, curr_val_loss)

            else:

                epochs_without_improvement += 1

                if early_stopping is not None and epochs_without_improvement >= early_stopping:

                    break

        return self._make_fit_result(actual_epoch_num, train_loss, train_acc, test_loss, test_acc)

log_hparam

def log_hparam(
    self,
    hparam_dict: dict[str, typing.Any],
    metric_dict: dict[str, typing.Any] = {},
    run_name: str = 'hparams'
)
View Source
    def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = "hparams"):

        if self.logger is not None:

            self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)
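This is a no-op unless the trainer was created with log=True. A one-line illustrative call (the keys and values are arbitrary):

trainer.log_hparam({"hidden_size": 64, "lr": 1e-3}, {"final_test_loss": 0.12})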

save_checkpoint

def save_checkpoint(
    self,
    checkpoint_filename: str,
    loss: Optional[float] = None
) -> None

Saves the model in its current state to a file with the given name (treated

as a relative path).

Parameters:

Name Type Description Default
checkpoint_filename str File name or relative path to save to. required
loss Optional[float] Validation loss value to report alongside the saved checkpoint. None
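Because the checkpoint stores the entire nn.Module with torch.save, it can be restored directly with torch.load. A sketch, assuming a checkpoint saved under the illustrative name "mlp_checkpoint" (when logging is enabled the file is placed under the TensorBoard log directory):

trainer.save_checkpoint("mlp_checkpoint", loss=0.07)
# on newer PyTorch versions loading a full module may require weights_only=False
restored_model = torch.load("mlp_checkpoint")
restored_model.eval()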
View Source
    def save_checkpoint(self, checkpoint_filename: str, loss: Optional[float] = None) -> None:

        """

        Saves the model in its current state to a file with the given name (treated

        as a relative path).

        Args:

            checkpoint_filename (str): File name or relative path to save to.

        """

        if self.logger is not None:

            checkpoint_filename = f"{self.logger.log_dir}/{checkpoint_filename}"

        torch.save(self.model, checkpoint_filename)

        print(f"\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}")

test_batch

def test_batch(
    self,
    batch
) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model and calculates loss.

Parameters:

Name Type Description Default
batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). required

Returns:

Type Description
BatchResult A BatchResult containing the value of the loss function and
the number of correctly classified samples in the batch.
View Source
    @torch.no_grad()

    def test_batch(self, batch) -> BatchResult:

        X, y = batch

        if self.device:

            X = X.to(self.device)

            y = y.to(self.device)

        preds = self.model.forward(X)

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)

        loss = self.loss_fn(preds, y)

        return self._make_batch_result(loss, num_correct)

test_epoch

def test_epoch(
    self,
    dl_test: torch.utils.data.dataloader.DataLoader,
    **kw
) -> wtracker.neural.train_results.EpochResult

Evaluate model once over a test set (single epoch).

Parameters:

Name Type Description Default
dl_test DataLoader DataLoader for the test set. required
kw None Keyword args supported by _foreach_batch. None

Returns:

Type Description
EpochResult An EpochResult for the epoch.
View Source
    def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult:

        """

        Evaluate model once over a test set (single epoch).

        Args:

            dl_test (DataLoader): DataLoader for the test set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(False)  # set evaluation (test) mode

        return self._foreach_batch(dl_test, self.test_batch, **kw)
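A single evaluation pass can also be run on its own. Keyword arguments are forwarded to _foreach_batch, so for example max_batches can cap the pass (values below are illustrative):

eval_result = trainer.test_epoch(dl_test, verbose=True, max_batches=10)
print(f"avg loss: {sum(eval_result.losses) / len(eval_result.losses):.3f}, "
      f"accuracy: {eval_result.accuracy:.2f}%")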

train_batch

def train_batch(
    self,
    batch
) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model, calculates loss,

performs back-propagation and updates weights.

Parameters:

Name Type Description Default
batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). required

Returns:

Type Description
BatchResult A BatchResult containing the value of the loss function and
the number of correctly classified samples in the batch.
View Source
    def train_batch(self, batch) -> BatchResult:

        X, y = batch

        if self.device:

            X = X.to(self.device)

            y = y.to(self.device)

        self.model: nn.Module

        self.optimizer.zero_grad()

        preds = self.model.forward(X)

        loss = self.loss_fn(preds, y)

        loss.backward()

        self.optimizer.step()

        num_correct = torch.sum((preds - y).norm(dim=1) < 1.0)

        return self._make_batch_result(loss, num_correct)
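Note that num_correct here is a regression hit-rate rather than a classification count: a prediction is counted as correct when the Euclidean distance between the predicted vector and the target vector is below 1.0. A standalone sketch of that criterion:

import torch

preds = torch.tensor([[0.2, 0.1], [3.0, 4.0]])
targets = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
dists = (preds - targets).norm(dim=1)   # tensor([0.2236, 5.0000])
num_correct = torch.sum(dists < 1.0)    # tensor(1): only the first prediction is within the threshold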

train_epoch

def train_epoch(
    self,
    dl_train: torch.utils.data.dataloader.DataLoader,
    **kw
) -> wtracker.neural.train_results.EpochResult

Train once over a training set (single epoch).

Parameters:

Name Type Description Default
dl_train DataLoader DataLoader for the training set. required
kw None Keyword args supported by _foreach_batch. None

Returns:

Type Description
EpochResult An EpochResult for the epoch.
View Source
    def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult:

        """

        Train once over a training set (single epoch).

        Args:

            dl_train (DataLoader): DataLoader for the training set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(True)  # set train mode

        return self._foreach_batch(dl_train, self.train_batch, **kw)

Trainer

class Trainer(
    model: torch.nn.modules.module.Module,
    device: Optional[torch.device] = None,
    log: bool = False
)

A class abstracting the various tasks of training models.

Provides methods at multiple levels of granularity:

  • Multiple epochs (fit)
  • Single epoch (train_epoch/test_epoch)
  • Single batch (train_batch/test_batch)

Attributes

Name Type Description Default
model nn.Module The model to train. required
device Optional[torch.device] The device to run training on (CPU or GPU). None
log bool Whether to log training progress with tensorboard. False
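Concrete trainers subclass Trainer and implement train_batch and test_batch. A minimal classification-style sketch (the subclass name, loss function, and accuracy criterion are illustrative, not part of this module):

import torch
from wtracker.neural.train_results import BatchResult
from wtracker.neural.training import Trainer

class ClassifierTrainer(Trainer):
    def __init__(self, model, loss_fn, optimizer, device=None, log=False):
        super().__init__(model, device, log=log)
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def train_batch(self, batch) -> BatchResult:
        X, y = batch
        if self.device:
            X, y = X.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()
        scores = self.model(X)
        loss = self.loss_fn(scores, y)
        loss.backward()
        self.optimizer.step()
        # count predictions whose arg-max class matches the label
        num_correct = (scores.argmax(dim=1) == y).sum()
        return self._make_batch_result(loss, num_correct)

    @torch.no_grad()
    def test_batch(self, batch) -> BatchResult:
        X, y = batch
        if self.device:
            X, y = X.to(self.device), y.to(self.device)
        scores = self.model(X)
        loss = self.loss_fn(scores, y)
        num_correct = (scores.argmax(dim=1) == y).sum()
        return self._make_batch_result(loss, num_correct)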
View Source
class Trainer(abc.ABC):

    """

    A class abstracting the various tasks of training models.

    Provides methods at multiple levels of granularity:

    - Multiple epochs (fit)

    - Single epoch (train_epoch/test_epoch)

    - Single batch (train_batch/test_batch)

    Args:

        model (nn.Module): The model to train.

        device (Optional[torch.device], optional): The device to run training on (CPU or GPU).

        log (bool, optional): Whether to log training progress with tensorboard.

    """

    def __init__(

        self,

        model: nn.Module,

        device: Optional[torch.device] = None,

        log: bool = False,

    ):

        self.model = model

        self.device = device

        self.logger = None if not log else SummaryWriter()

        if self.logger is not None:

            self.logger.add_hparams({"model": model.__class__.__name__}, {}, run_name="hparams")

            self.logger.add_hparams({"device": str(device)}, {}, run_name="hparams")

        if self.device:

            model.to(self.device)

    def _make_batch_result(self, loss, num_correct) -> BatchResult:

        loss = loss.item() if isinstance(loss, Tensor) else loss

        num_correct = num_correct.item() if isinstance(num_correct, Tensor) else num_correct

        return BatchResult(float(loss), int(num_correct))

    def _make_fit_result(self, num_epochs, train_losses, train_acc, test_losses, test_acc) -> FitResult:

        num_epochs = num_epochs.item() if isinstance(num_epochs, Tensor) else num_epochs

        train_losses = [x.item() if isinstance(x, Tensor) else x for x in train_losses]

        train_acc = [x.item() if isinstance(x, Tensor) else x for x in train_acc]

        test_losses = [x.item() if isinstance(x, Tensor) else x for x in test_losses]

        test_acc = [x.item() if isinstance(x, Tensor) else x for x in test_acc]

        return FitResult(int(num_epochs), train_losses, train_acc, test_losses, test_acc)

    def fit(

        self,

        dl_train: DataLoader,

        dl_test: DataLoader,

        num_epochs: int,

        checkpoints: str = None,

        early_stopping: int = None,

        print_every: int = 1,

        **kw,

    ) -> FitResult:

        """

        Trains the model for multiple epochs with a given training set,

        and calculates validation loss over a given validation set.

        Args:

            dl_train (DataLoader): Dataloader for the training set.

            dl_test (DataLoader): Dataloader for the test set.

            num_epochs (int): Number of epochs to train for.

            checkpoints (str, optional): If provided, the model is saved to this file (a filename without extension) every time the test loss improves.

            early_stopping (int, optional): If provided, stop training early when there is no test loss improvement for this number of epochs.

            print_every (int, optional): Print progress every this many epochs.

        Returns:

            FitResult: A FitResult object containing train and test losses per epoch.

        """

        actual_epoch_num = 0

        epochs_without_improvement = 0

        train_loss, train_acc, test_loss, test_acc = [], [], [], []

        best_val_loss = None

        # add graph to tensorboard

        if self.logger is not None:

            self.logger.add_graph(self.model, next(iter(dl_train))[0])

        for epoch in range(num_epochs):

            actual_epoch_num += 1

            verbose = False  # pass this to train/test_epoch.

            if print_every > 0 and (epoch % print_every == 0 or epoch == num_epochs - 1):

                verbose = True

            self._print(f"--- EPOCH {epoch+1}/{num_epochs} ---", verbose)

            train_result = self.train_epoch(dl_train, verbose=verbose, **kw)

            test_result = self.test_epoch(dl_test, verbose=verbose, **kw)

            train_loss.extend(train_result.losses)

            train_acc.append(train_result.accuracy)

            test_loss.extend(test_result.losses)

            test_acc.append(test_result.accuracy)

            # log results to tensorboard

            if self.logger is not None:

                self.logger.add_scalar("loss/train", Tensor(train_result.losses).mean(), epoch)

                self.logger.add_scalar("loss/test", Tensor(test_result.losses).mean(), epoch)

                self.logger.add_scalar("accuracy/train", train_result.accuracy, epoch)

                self.logger.add_scalar("accuracy/test", test_result.accuracy, epoch)

                self.logger.add_scalar("learning_rate", self.optimizer.param_groups[0]["lr"], epoch)

            curr_val_loss = Tensor(test_result.losses).mean().item()

            if best_val_loss is None or curr_val_loss < best_val_loss:

                best_val_loss = curr_val_loss

                epochs_without_improvement = 0

                if checkpoints is not None:

                    self.save_checkpoint(checkpoints, curr_val_loss)

            else:

                epochs_without_improvement += 1

                if early_stopping is not None and epochs_without_improvement >= early_stopping:

                    break

        return self._make_fit_result(actual_epoch_num, train_loss, train_acc, test_loss, test_acc)

    def save_checkpoint(self, checkpoint_filename: str, loss: Optional[float] = None) -> None:

        """

        Saves the model in its current state to a file with the given name (treated

        as a relative path).

        Args:

            checkpoint_filename (str): File name or relative path to save to.

        """

        if self.logger is not None:

            checkpoint_filename = f"{self.logger.log_dir}/{checkpoint_filename}"

        torch.save(self.model, checkpoint_filename)

        print(f"\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}")

    def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult:

        """

        Train once over a training set (single epoch).

        Args:

            dl_train (DataLoader): DataLoader for the training set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(True)  # set train mode

        return self._foreach_batch(dl_train, self.train_batch, **kw)

    def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult:

        """

        Evaluate model once over a test set (single epoch).

        Args:

            dl_test (DataLoader): DataLoader for the test set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(False)  # set evaluation (test) mode

        return self._foreach_batch(dl_test, self.test_batch, **kw)

    @abc.abstractmethod

    def train_batch(self, batch) -> BatchResult:

        """

        Runs a single batch forward through the model, calculates loss,

        performs back-propagation and updates weights.

        Args:

            batch: A single batch of data from a data loader (might

                be a tuple of data and labels or anything else depending on

                the underlying dataset).

        Returns:

            BatchResult: A BatchResult containing the value of the loss function and

                the number of correctly classified samples in the batch.

        """

        raise NotImplementedError()

    @abc.abstractmethod

    def test_batch(self, batch) -> BatchResult:

        """

        Runs a single batch forward through the model and calculates loss.

        Args:

            batch: A single batch of data from a data loader (might

                be a tuple of data and labels or anything else depending on

                the underlying dataset).

        Returns:

            BatchResult: A BatchResult containing the value of the loss function and

                the number of correctly classified samples in the batch.

        """

        raise NotImplementedError()

    @staticmethod

    def _print(message, verbose=True):

        """Simple wrapper around print to make it conditional"""

        if verbose:

            print(message)

    @staticmethod

    def _foreach_batch(

        dl: DataLoader,

        forward_fn: Callable[[Any], BatchResult],

        verbose=True,

        max_batches=None,

    ) -> EpochResult:

        """

        Evaluates the given forward-function on batches from the given

        dataloader, and prints progress along the way.

        """

        losses = []

        num_correct = 0

        num_samples = len(dl.sampler)

        num_batches = len(dl.batch_sampler)

        if max_batches is not None:

            if max_batches < num_batches:

                num_batches = max_batches

                num_samples = num_batches * dl.batch_size

        if verbose:

            pbar_fn = tqdm.auto.tqdm

            pbar_file = sys.stdout

        else:

            pbar_fn = tqdm.tqdm

            pbar_file = open(os.devnull, "w")

        pbar_name = forward_fn.__name__

        with pbar_fn(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:

            dl_iter = iter(dl)

            for batch_idx in range(num_batches):

                data = next(dl_iter)

                batch_res = forward_fn(data)

                pbar.set_description(f"{pbar_name} ({batch_res.loss:.3f})")

                pbar.update()

                losses.append(batch_res.loss)

                num_correct += batch_res.num_correct

            avg_loss = sum(losses) / num_batches

            accuracy = 100.0 * num_correct / num_samples

            pbar.set_description(f"{pbar_name} " f"(Avg. Loss {avg_loss:.3f}, " f"Accuracy {accuracy:.2f}%)")

        if not verbose:

            pbar_file.close()

        return EpochResult(losses=losses, accuracy=accuracy)

    def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = "hparams"):

        if self.logger is not None:

            self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)

Ancestors (in MRO)

  • abc.ABC

Descendants

  • wtracker.neural.training.MLPTrainer

Methods

fit

def fit(
    self,
    dl_train: torch.utils.data.dataloader.DataLoader,
    dl_test: torch.utils.data.dataloader.DataLoader,
    num_epochs: int,
    checkpoints: str = None,
    early_stopping: int = None,
    print_every: int = 1,
    **kw
) -> wtracker.neural.train_results.FitResult

Trains the model for multiple epochs with a given training set,

and calculates validation loss over a given validation set.

Parameters:

Name Type Description Default
dl_train DataLoader Dataloader for the training set. required
dl_test DataLoader Dataloader for the test set. required
num_epochs int Number of epochs to train for. required
checkpoints str If provided, the model is saved to this file (a filename without extension) every time the test loss improves. None
early_stopping int If provided, stop training early when there is no test loss improvement for this number of epochs. None
print_every int Print progress every this many epochs. 1

Returns:

Type Description
FitResult A FitResult object containing train and test losses per epoch.
View Source
    def fit(

        self,

        dl_train: DataLoader,

        dl_test: DataLoader,

        num_epochs: int,

        checkpoints: str = None,

        early_stopping: int = None,

        print_every: int = 1,

        **kw,

    ) -> FitResult:

        """

        Trains the model for multiple epochs with a given training set,

        and calculates validation loss over a given validation set.

        Args:

            dl_train (DataLoader): Dataloader for the training set.

            dl_test (DataLoader): Dataloader for the test set.

            num_epochs (int): Number of epochs to train for.

            checkpoints (str, optional): If provided, the model is saved to this file (a filename without extension) every time the test loss improves.

            early_stopping (int, optional): If provided, stop training early when there is no test loss improvement for this number of epochs.

            print_every (int, optional): Print progress every this many epochs.

        Returns:

            FitResult: A FitResult object containing train and test losses per epoch.

        """

        actual_epoch_num = 0

        epochs_without_improvement = 0

        train_loss, train_acc, test_loss, test_acc = [], [], [], []

        best_val_loss = None

        # add graph to tensorboard

        if self.logger is not None:

            self.logger.add_graph(self.model, next(iter(dl_train))[0])

        for epoch in range(num_epochs):

            actual_epoch_num += 1

            verbose = False  # pass this to train/test_epoch.

            if print_every > 0 and (epoch % print_every == 0 or epoch == num_epochs - 1):

                verbose = True

            self._print(f"--- EPOCH {epoch+1}/{num_epochs} ---", verbose)

            train_result = self.train_epoch(dl_train, verbose=verbose, **kw)

            test_result = self.test_epoch(dl_test, verbose=verbose, **kw)

            train_loss.extend(train_result.losses)

            train_acc.append(train_result.accuracy)

            test_loss.extend(test_result.losses)

            test_acc.append(test_result.accuracy)

            # log results to tensorboard

            if self.logger is not None:

                self.logger.add_scalar("loss/train", Tensor(train_result.losses).mean(), epoch)

                self.logger.add_scalar("loss/test", Tensor(test_result.losses).mean(), epoch)

                self.logger.add_scalar("accuracy/train", train_result.accuracy, epoch)

                self.logger.add_scalar("accuracy/test", test_result.accuracy, epoch)

                self.logger.add_scalar("learning_rate", self.optimizer.param_groups[0]["lr"], epoch)

            curr_val_loss = Tensor(test_result.losses).mean().item()

            if best_val_loss is None or curr_val_loss < best_val_loss:

                best_val_loss = curr_val_loss

                epochs_without_improvement = 0

                if checkpoints is not None:

                    self.save_checkpoint(checkpoints, curr_val_loss)

            else:

                epochs_without_improvement += 1

                if early_stopping is not None and epochs_without_improvement >= early_stopping:

                    break

        return self._make_fit_result(actual_epoch_num, train_loss, train_acc, test_loss, test_acc)

log_hparam

def log_hparam(
    self,
    hparam_dict: dict[str, typing.Any],
    metric_dict: dict[str, typing.Any] = {},
    run_name: str = 'hparams'
)
View Source
    def log_hparam(self, hparam_dict: dict[str, Any], metric_dict: dict[str, Any] = {}, run_name: str = "hparams"):

        if self.logger is not None:

            self.logger.add_hparams(hparam_dict, metric_dict, run_name=run_name)

save_checkpoint

def save_checkpoint(
    self,
    checkpoint_filename: str,
    loss: Optional[float] = None
) -> None

Saves the model in its current state to a file with the given name (treated

as a relative path).

Parameters:

Name Type Description Default
checkpoint_filename str File name or relative path to save to. required
loss Optional[float] Validation loss value to report alongside the saved checkpoint. None
View Source
    def save_checkpoint(self, checkpoint_filename: str, loss: Optional[float] = None) -> None:

        """

        Saves the model in its current state to a file with the given name (treated

        as a relative path).

        Args:

            checkpoint_filename (str): File name or relative path to save to.

        """

        if self.logger is not None:

            checkpoint_filename = f"{self.logger.log_dir}/{checkpoint_filename}"

        torch.save(self.model, checkpoint_filename)

        print(f"\n*** Saved checkpoint {checkpoint_filename} :: val_loss={loss:.3f}")

test_batch

def test_batch(
    self,
    batch
) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model and calculates loss.

Parameters:

Name Type Description Default
batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). required

Returns:

Type Description
BatchResult A BatchResult containing the value of the loss function and
the number of correctly classified samples in the batch.
View Source
    @abc.abstractmethod

    def test_batch(self, batch) -> BatchResult:

        """

        Runs a single batch forward through the model and calculates loss.

        Args:

            batch: A single batch of data from a data loader (might

                be a tuple of data and labels or anything else depending on

                the underlying dataset).

        Returns:

            BatchResult: A BatchResult containing the value of the loss function and

                the number of correctly classified samples in the batch.

        """

        raise NotImplementedError()

test_epoch

def test_epoch(
    self,
    dl_test: torch.utils.data.dataloader.DataLoader,
    **kw
) -> wtracker.neural.train_results.EpochResult

Evaluate model once over a test set (single epoch).

Parameters:

Name Type Description Default
dl_test DataLoader DataLoader for the test set. required
kw None Keyword args supported by _foreach_batch. None

Returns:

Type Description
EpochResult An EpochResult for the epoch.
View Source
    def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult:

        """

        Evaluate model once over a test set (single epoch).

        Args:

            dl_test (DataLoader): DataLoader for the test set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(False)  # set evaluation (test) mode

        return self._foreach_batch(dl_test, self.test_batch, **kw)

train_batch

def train_batch(
    self,
    batch
) -> wtracker.neural.train_results.BatchResult

Runs a single batch forward through the model, calculates loss,

performs back-propagation and updates weights.

Parameters:

Name Type Description Default
batch None A single batch of data from a data loader (might be a tuple of data and labels or anything else depending on the underlying dataset). required

Returns:

Type Description
BatchResult A BatchResult containing the value of the loss function and
the number of correctly classified samples in the batch.
View Source
    @abc.abstractmethod

    def train_batch(self, batch) -> BatchResult:

        """

        Runs a single batch forward through the model, calculates loss,

        performs back-propagation and updates weights.

        Args:

            batch: A single batch of data from a data loader (might

                be a tuple of data and labels or anything else depending on

                the underlying dataset).

        Returns:

            BatchResult: A BatchResult containing the value of the loss function and

                the number of correctly classified samples in the batch.

        """

        raise NotImplementedError()

train_epoch

def train_epoch(
    self,
    dl_train: torch.utils.data.dataloader.DataLoader,
    **kw
) -> wtracker.neural.train_results.EpochResult

Train once over a training set (single epoch).

Parameters:

Name Type Description Default
dl_train DataLoader DataLoader for the training set. required
kw None Keyword args supported by _foreach_batch. None

Returns:

Type Description
EpochResult An EpochResult for the epoch.
View Source
    def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult:

        """

        Train once over a training set (single epoch).

        Args:

            dl_train (DataLoader): DataLoader for the training set.

            kw: Keyword args supported by _foreach_batch.

        Returns:

            EpochResult: An EpochResult for the epoch.

        """

        self.model.train(True)  # set train mode

        return self._foreach_batch(dl_train, self.train_batch, **kw)