Skip to content

Module wtracker.sim.sim_controllers.mlp_controllers

View Source
from collections import deque
from typing import Collection

import numpy as np
import torch
from torch import Tensor

from wtracker.sim.config import TimingConfig
from wtracker.sim.simulator import Simulator
from wtracker.sim.sim_controllers.csv_controller import CsvController
from wtracker.utils.bbox_utils import BoxUtils, BoxConverter, BoxFormat
from wtracker.neural.mlp import WormPredictor
from wtracker.neural.config import IOConfig

class MLPController(CsvController):
    """
    MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation.

    The worm bounding boxes fed to the model are read from the CSV data (via ``predict``), and the
    model's predicted displacement is clipped to the distance the worm can cover at ``max_speed``.

    Args:
        timing_config (TimingConfig): The timing configuration for the simulation.
        csv_path (str): The path to the CSV file containing the simulation data.
        model (WormPredictor): The WormPredictor model used for predicting worm movement. It is switched to eval mode.
        max_speed (float): Maximum speed of the worm in mm/s. Predicted displacements larger than what
            this speed allows over ``io_config.pred_frames[0]`` frames are clipped.
    """

    def __init__(self, timing_config: TimingConfig, csv_path: str, model: WormPredictor, max_speed: float = 0.9):
        super().__init__(timing_config, csv_path)

        self.model: WormPredictor = model
        self.io_config: IOConfig = model.io_config
        # The model is only used for inference here.
        self.model.eval()

        # Convert the speed limit from mm/s to px/frame, then to the maximal distance (in px)
        # the worm can travel over pred_frames[0] frames.
        px_per_mm = self.timing_config.px_per_mm
        fps = self.timing_config.frames_per_sec
        max_speed_px_frame = max_speed * (px_per_mm / fps)
        self.max_dist_per_pred = max_speed_px_frame * self.io_config.pred_frames[0]

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Predicts the worm's upcoming position and returns the platform movement needed to follow it.

        The worm bounding boxes at ``io_config.input_frames`` (shifted by
        ``sim.frame_number - timing_config.pred_frame_num``) are read from the CSV data, made
        relative to the first box and passed to the model. The model's output is clipped to
        ``±max_dist_per_pred`` and its first two values (x and y offset) are added to the worm's
        position relative to the camera center.

        Args:
            sim (Simulator): The running simulation.

        Returns:
            tuple[int, int]: The movement vector (dx, dy) in pixels. (0, 0) is returned when any
            input box value is NaN or infinite, e.g. because a requested frame is outside the CSV data.
        """
        # Absolute frame numbers used as model input. A plain `+` (not `+=`) guarantees that
        # io_config.input_frames is never modified in place, even if it is already an int array
        # (np.asanyarray would then return the very same object).
        offset = sim.frame_number - self.timing_config.pred_frame_num
        frames_for_pred = np.asanyarray(self.io_config.input_frames, dtype=int) + offset

        cam_center = BoxUtils.center(np.asanyarray(sim.view.camera_position))

        # Flatten the per-frame boxes into a single model input row.
        worm_bboxes = self.predict(frames_for_pred, relative=False).reshape(1, -1)
        if not np.isfinite(worm_bboxes).all():
            return 0, 0

        # Position of the worm relative to the camera center. The worm's x,y is used instead of its
        # center because of how the model and dataset are built.
        worm_x, worm_y = worm_bboxes[0, 0], worm_bboxes[0, 1]
        rel_x, rel_y = worm_x - cam_center[0], worm_y - cam_center[1]

        # Make the x,y of every box relative to the first box
        # (the flattened row repeats 4 values per box, with x and y first).
        column = np.arange(worm_bboxes.shape[1])
        worm_bboxes[:, column % 4 == 0] -= worm_x
        worm_bboxes[:, column % 4 == 1] -= worm_y

        # Predict the worm's movement. Inference only, so skip building the autograd graph.
        with torch.no_grad():
            pred = self.model(Tensor(worm_bboxes)).flatten().detach().numpy()

        # Limit the prediction to the distance the worm can travel at max_speed.
        pred = np.clip(pred, -self.max_dist_per_pred, self.max_dist_per_pred)

        # Convert explicitly so the result is always a tuple of Python ints.
        dx = int(round(float(pred[0]) + float(rel_x)))
        dy = int(round(float(pred[1]) + float(rel_y)))
        return dx, dy

    def print_model(self):
        """Prints the underlying model."""
        print(self.model)

Classes

MLPController

class MLPController(
    timing_config: wtracker.sim.config.TimingConfig,
    csv_path: str,
    model: wtracker.neural.mlp.WormPredictor,
    max_speed: float = 0.9
)

MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation.

Parameters

Name Type Description Default
timing_config TimingConfig The timing configuration for the simulation. required
csv_path str The path to the CSV file containing the simulation data. required
model WormPredictor The WormPredictor model used for predicting worm movement. required
max_speed float Maximum speed of the worm in mm/s; predicted displacements larger than this speed allows are clipped. 0.9
View Source
class MLPController(CsvController):
    """
    MLPController class represents a controller that uses a WormPredictor model to provide movement vectors for a simulation.

    The worm bounding boxes fed to the model are read from the CSV data (via ``predict``), and the
    model's predicted displacement is clipped to the distance the worm can cover at ``max_speed``.

    Args:
        timing_config (TimingConfig): The timing configuration for the simulation.
        csv_path (str): The path to the CSV file containing the simulation data.
        model (WormPredictor): The WormPredictor model used for predicting worm movement. It is switched to eval mode.
        max_speed (float): Maximum speed of the worm in mm/s. Predicted displacements larger than what
            this speed allows over ``io_config.pred_frames[0]`` frames are clipped.
    """

    def __init__(self, timing_config: TimingConfig, csv_path: str, model: WormPredictor, max_speed: float = 0.9):
        super().__init__(timing_config, csv_path)

        self.model: WormPredictor = model
        self.io_config: IOConfig = model.io_config
        # The model is only used for inference here.
        self.model.eval()

        # Convert the speed limit from mm/s to px/frame, then to the maximal distance (in px)
        # the worm can travel over pred_frames[0] frames.
        px_per_mm = self.timing_config.px_per_mm
        fps = self.timing_config.frames_per_sec
        max_speed_px_frame = max_speed * (px_per_mm / fps)
        self.max_dist_per_pred = max_speed_px_frame * self.io_config.pred_frames[0]

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Predicts the worm's upcoming position and returns the platform movement needed to follow it.

        The worm bounding boxes at ``io_config.input_frames`` (shifted by
        ``sim.frame_number - timing_config.pred_frame_num``) are read from the CSV data, made
        relative to the first box and passed to the model. The model's output is clipped to
        ``±max_dist_per_pred`` and its first two values (x and y offset) are added to the worm's
        position relative to the camera center.

        Args:
            sim (Simulator): The running simulation.

        Returns:
            tuple[int, int]: The movement vector (dx, dy) in pixels. (0, 0) is returned when any
            input box value is NaN or infinite, e.g. because a requested frame is outside the CSV data.
        """
        # Absolute frame numbers used as model input. A plain `+` (not `+=`) guarantees that
        # io_config.input_frames is never modified in place, even if it is already an int array
        # (np.asanyarray would then return the very same object).
        offset = sim.frame_number - self.timing_config.pred_frame_num
        frames_for_pred = np.asanyarray(self.io_config.input_frames, dtype=int) + offset

        cam_center = BoxUtils.center(np.asanyarray(sim.view.camera_position))

        # Flatten the per-frame boxes into a single model input row.
        worm_bboxes = self.predict(frames_for_pred, relative=False).reshape(1, -1)
        if not np.isfinite(worm_bboxes).all():
            return 0, 0

        # Position of the worm relative to the camera center. The worm's x,y is used instead of its
        # center because of how the model and dataset are built.
        worm_x, worm_y = worm_bboxes[0, 0], worm_bboxes[0, 1]
        rel_x, rel_y = worm_x - cam_center[0], worm_y - cam_center[1]

        # Make the x,y of every box relative to the first box
        # (the flattened row repeats 4 values per box, with x and y first).
        column = np.arange(worm_bboxes.shape[1])
        worm_bboxes[:, column % 4 == 0] -= worm_x
        worm_bboxes[:, column % 4 == 1] -= worm_y

        # Predict the worm's movement. Inference only, so skip building the autograd graph.
        with torch.no_grad():
            pred = self.model(Tensor(worm_bboxes)).flatten().detach().numpy()

        # Limit the prediction to the distance the worm can travel at max_speed.
        pred = np.clip(pred, -self.max_dist_per_pred, self.max_dist_per_pred)

        # Convert explicitly so the result is always a tuple of Python ints.
        dx = int(round(float(pred[0]) + float(rel_x)))
        dy = int(round(float(pred[1]) + float(rel_y)))
        return dx, dy

    def print_model(self):
        """Prints the underlying model."""
        print(self.model)

Ancestors (in MRO)

  • wtracker.sim.sim_controllers.csv_controller.CsvController
  • wtracker.sim.simulator.SimController
  • abc.ABC

Methods

begin_movement_prediction

def begin_movement_prediction(
    self,
    sim: wtracker.sim.simulator.Simulator
) -> None

Called when the movement prediction begins.

View Source
    def begin_movement_prediction(self, sim: Simulator) -> None:
        """
        Hook invoked when movement prediction begins; intentionally does nothing here.
        """
        return None

on_camera_frame

def on_camera_frame(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when a camera frame is captured. Happens every frame.

View Source
    def on_camera_frame(self, sim: Simulator):
        """
        Hook invoked for every camera frame; records the current camera position.
        """
        camera_position = sim.view.camera_position
        self._camera_bboxes.append(camera_position)

on_cycle_end

def on_cycle_end(
    self,
    sim: 'Simulator'
)

Called when a cycle ends.

View Source
    def on_cycle_end(self, sim: Simulator):
        """
        Invoked once a cycle has finished; does nothing by default.
        """
        return None

on_cycle_start

def on_cycle_start(
    self,
    sim: 'Simulator'
)

Called when a new cycle starts.

View Source
    def on_cycle_start(self, sim: Simulator):
        """
        Invoked at the beginning of every new cycle; does nothing by default.
        """
        return None

on_imaging_end

def on_imaging_end(
    self,
    sim: 'Simulator'
)

Called when imaging phase ends.

View Source
    def on_imaging_end(self, sim: Simulator):
        """
        Invoked when the imaging phase is over; does nothing by default.
        """
        return None

on_imaging_start

def on_imaging_start(
    self,
    sim: 'Simulator'
)

Called when imaging phase starts.

View Source
    def on_imaging_start(self, sim: Simulator):
        """
        Invoked when the imaging phase begins; does nothing by default.
        """
        return None

on_micro_frame

def on_micro_frame(
    self,
    sim: 'Simulator'
)

Called when a micro frame is captured. Happens for every micro frame during the imaging phase.

View Source
    def on_micro_frame(self, sim: Simulator):
        """
        Invoked for every micro frame captured during the imaging phase; does nothing by default.
        """
        return None

on_movement_end

def on_movement_end(
    self,
    sim: 'Simulator'
)

Called when movement phase ends.

View Source
    def on_movement_end(self, sim: Simulator):
        """
        Invoked when the movement phase is over; does nothing by default.
        """
        return None

on_movement_start

def on_movement_start(
    self,
    sim: 'Simulator'
)

Called when movement phase starts.

View Source
    def on_movement_start(self, sim: Simulator):
        """
        Invoked when the movement phase begins; does nothing by default.
        """
        return None

on_sim_end

def on_sim_end(
    self,
    sim: 'Simulator'
)

Called when the simulation ends.

View Source
    def on_sim_end(self, sim: Simulator):
        """
        Invoked once the simulation has finished; does nothing by default.
        """
        return None

on_sim_start

def on_sim_start(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when the simulation starts.

View Source
    def on_sim_start(self, sim: Simulator):
        """
        Invoked when the simulation starts; discards all previously recorded camera positions.
        """
        recorded_positions = self._camera_bboxes
        recorded_positions.clear()

predict

def predict(
    self,
    frame_nums: Collection[int],
    relative: bool = True
) -> numpy.ndarray
View Source
    def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
        """
        Looks up the worm bounding boxes stored in the CSV data for the given frames.

        Args:
            frame_nums (Collection[int]): Frame numbers to look up. Must not be empty.
            relative (bool): If True, the x,y of every box is shifted by the recorded camera
                position for that frame (looked up modulo ``timing_config.cycle_frame_num``).

        Returns:
            np.ndarray: Array of shape (len(frame_nums), 4). Rows for frame numbers outside the
            CSV data are filled with NaN.
        """
        assert len(frame_nums) > 0

        frames = np.asanyarray(frame_nums, dtype=int)
        num_rows = self._csv_data.shape[0]
        in_range = (frames >= 0) & (frames < num_rows)

        boxes = np.full((frames.shape[0], 4), np.nan)
        boxes[in_range] = self._csv_data[frames[in_range], :]

        if relative:
            # TODO: relative coordinates are only correct when every frame number is within the
            # last cycle, because camera positions are looked up modulo the cycle length. Maybe fix that.
            cycle_len = self.timing_config.cycle_frame_num
            cam_boxes = np.asanyarray(
                [self._camera_bboxes[frame % cycle_len] for frame in frames],
                dtype=float,
            )
            # Express x,y relative to the camera view.
            boxes[:, 0] -= cam_boxes[:, 0]
            boxes[:, 1] -= cam_boxes[:, 1]

        return boxes
def print_model(
    self
)
View Source
    def print_model(self):
        """Writes a textual description of the model to standard output."""
        description = str(self.model)
        print(description)

provide_movement_vector

def provide_movement_vector(
    self,
    sim: wtracker.sim.simulator.Simulator
) -> tuple[int, int]

Provides the movement vector for the simulator. The platform is moved by the provided vector.

Returns:

Type Description
tuple[int, int] The movement vector in format (dx, dy).
The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction.
View Source
    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Predicts the worm's upcoming position and returns the platform movement needed to follow it.

        The worm bounding boxes at ``io_config.input_frames`` (shifted by
        ``sim.frame_number - timing_config.pred_frame_num``) are read from the CSV data, made
        relative to the first box and passed to the model. The model's output is clipped to
        ``±max_dist_per_pred`` and its first two values (x and y offset) are added to the worm's
        position relative to the camera center.

        Args:
            sim (Simulator): The running simulation.

        Returns:
            tuple[int, int]: The movement vector (dx, dy) in pixels. (0, 0) is returned when any
            input box value is NaN or infinite, e.g. because a requested frame is outside the CSV data.
        """
        # Absolute frame numbers used as model input. A plain `+` (not `+=`) guarantees that
        # io_config.input_frames is never modified in place, even if it is already an int array
        # (np.asanyarray would then return the very same object).
        offset = sim.frame_number - self.timing_config.pred_frame_num
        frames_for_pred = np.asanyarray(self.io_config.input_frames, dtype=int) + offset

        cam_center = BoxUtils.center(np.asanyarray(sim.view.camera_position))

        # Flatten the per-frame boxes into a single model input row.
        worm_bboxes = self.predict(frames_for_pred, relative=False).reshape(1, -1)
        if not np.isfinite(worm_bboxes).all():
            return 0, 0

        # Position of the worm relative to the camera center. The worm's x,y is used instead of its
        # center because of how the model and dataset are built.
        worm_x, worm_y = worm_bboxes[0, 0], worm_bboxes[0, 1]
        rel_x, rel_y = worm_x - cam_center[0], worm_y - cam_center[1]

        # Make the x,y of every box relative to the first box
        # (the flattened row repeats 4 values per box, with x and y first).
        column = np.arange(worm_bboxes.shape[1])
        worm_bboxes[:, column % 4 == 0] -= worm_x
        worm_bboxes[:, column % 4 == 1] -= worm_y

        # Predict the worm's movement. Inference only, so skip building the autograd graph.
        with torch.no_grad():
            pred = self.model(Tensor(worm_bboxes)).flatten().detach().numpy()

        # Limit the prediction to the distance the worm can travel at max_speed.
        pred = np.clip(pred, -self.max_dist_per_pred, self.max_dist_per_pred)

        # Convert explicitly so the result is always a tuple of Python ints.
        dx = int(round(float(pred[0]) + float(rel_x)))
        dy = int(round(float(pred[1]) + float(rel_y)))
        return dx, dy