Skip to content

Module wtracker.sim.sim_controllers.csv_controller

View Source
from collections import deque

from typing import Collection

import pandas as pd

import numpy as np

from wtracker.sim.config import TimingConfig

from wtracker.sim.simulator import SimController, Simulator

from wtracker.utils.bbox_utils import BoxUtils, BoxConverter, BoxFormat

class CsvController(SimController):
    """
    Controller which serves worm bounding boxes that are read from a csv file.

    The csv file must contain the columns "wrm_x", "wrm_y", "wrm_w" and "wrm_h".
    Row ``i`` of the file is used as the bounding box of frame number ``i``.

    Args:
        timing_config (TimingConfig): The timing configuration, passed on to ``SimController``.
        csv_path (str): Path to the csv file which holds the bounding boxes.
    """

    def __init__(self, timing_config: TimingConfig, csv_path: str):
        super().__init__(timing_config)

        self.csv_path = csv_path

        # `usecols` ignores the order of the requested names and keeps the order in which
        # the columns appear in the file. Re-selecting the columns guarantees that the
        # array is laid out as (x, y, w, h) regardless of how the file was written.
        bbox_columns = ["wrm_x", "wrm_y", "wrm_w", "wrm_h"]
        bbox_table = pd.read_csv(self.csv_path, usecols=bbox_columns)[bbox_columns]
        self._csv_data = bbox_table.to_numpy(dtype=float)

        # Camera positions of the most recent frames; at most one cycle worth is kept.
        self._camera_bboxes = deque(maxlen=timing_config.cycle_frame_num)

    def on_sim_start(self, sim: Simulator):
        """
        Called when the simulation starts. Discards all stored camera positions.
        """
        self._camera_bboxes.clear()

    def on_camera_frame(self, sim: Simulator):
        """
        Called when a camera frame is captured. Stores the current camera position.
        """
        self._camera_bboxes.append(sim.view.camera_position)

    def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
        """
        Look up the bounding boxes stored in the csv file for the requested frames.

        Args:
            frame_nums (Collection[int]): The frame numbers to look up. Must not be empty.
            relative (bool): If True, the x and y coordinates are shifted by the stored camera
                position. If False, the values are returned as stored in the csv file.

        Returns:
            np.ndarray: Array of shape (len(frame_nums), 4), one (x, y, w, h) row per frame.
                Rows of frames outside the range of the csv data are filled with NaN.
        """
        assert len(frame_nums) > 0

        frame_nums = np.asanyarray(frame_nums, dtype=int)
        valid_mask = (frame_nums >= 0) & (frame_nums < self._csv_data.shape[0])

        # Frames without csv data keep the NaN fill value.
        worm_bboxes = np.full((frame_nums.shape[0], 4), np.nan)
        worm_bboxes[valid_mask] = self._csv_data[frame_nums[valid_mask], :]

        if not relative:
            return worm_bboxes

        # TODO: if relative == True then it works only if the frame number is within the
        # last cycle, since only one cycle worth of camera positions is stored. Maybe fix that.
        cycle_length = self.timing_config.cycle_frame_num
        cam_bboxes = [self._camera_bboxes[n % cycle_length] for n in frame_nums]
        cam_bboxes = np.asanyarray(cam_bboxes, dtype=float)

        # make bbox relative to camera view
        worm_bboxes[:, 0] -= cam_bboxes[:, 0]
        worm_bboxes[:, 1] -= cam_bboxes[:, 1]

        return worm_bboxes

    def begin_movement_prediction(self, sim: Simulator) -> None:
        """
        Called when the movement prediction begins. Nothing needs to be done here.
        """
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Provide the movement vector which moves the bounding box center to the camera center.

        The bounding box of frame ``sim.frame_number - pred_frame_num`` is used.

        Returns:
            tuple[int, int]: The movement vector in format (dx, dy).
                (0, 0) is returned if the bounding box contains non-finite values.
        """
        bbox = self.predict([sim.frame_number - self.timing_config.pred_frame_num])
        bbox = bbox[0, :]

        # No usable bounding box for this frame - do not move.
        if not np.isfinite(bbox).all():
            return 0, 0

        center = BoxUtils.center(bbox)
        cam_center = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2

        dx = round(center[0] - cam_center[0])
        dy = round(center[1] - cam_center[1])

        return dx, dy

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        """
        Predict the bounding boxes of all the frames of cycle number ``sim.cycle_number - 1``.

        The frame range is clipped to the number of rows of the csv data.
        Note that an empty frame range trips the assertion in ``predict``.

        Returns:
            np.ndarray: The result of ``predict`` for the frames of that cycle.
        """
        start = (sim.cycle_number - 1) * self.timing_config.cycle_frame_num
        end = start + self.timing_config.cycle_frame_num
        end = min(end, len(self._csv_data))

        return self.predict(np.arange(start, end))

Classes

CsvController

class CsvController(
    timing_config: wtracker.sim.config.TimingConfig,
    csv_path: str
)

Abstract base class for simulator controllers.

Attributes

Name Type Description Default
timing_config TimingConfig The timing configuration for the simulator. None
View Source
class CsvController(SimController):
    """
    Controller which serves worm bounding boxes that are read from a csv file.

    The csv file must contain the columns "wrm_x", "wrm_y", "wrm_w" and "wrm_h".
    Row ``i`` of the file is used as the bounding box of frame number ``i``.

    Args:
        timing_config (TimingConfig): The timing configuration, passed on to ``SimController``.
        csv_path (str): Path to the csv file which holds the bounding boxes.
    """

    def __init__(self, timing_config: TimingConfig, csv_path: str):
        super().__init__(timing_config)

        self.csv_path = csv_path

        # `usecols` ignores the order of the requested names and keeps the order in which
        # the columns appear in the file. Re-selecting the columns guarantees that the
        # array is laid out as (x, y, w, h) regardless of how the file was written.
        bbox_columns = ["wrm_x", "wrm_y", "wrm_w", "wrm_h"]
        bbox_table = pd.read_csv(self.csv_path, usecols=bbox_columns)[bbox_columns]
        self._csv_data = bbox_table.to_numpy(dtype=float)

        # Camera positions of the most recent frames; at most one cycle worth is kept.
        self._camera_bboxes = deque(maxlen=timing_config.cycle_frame_num)

    def on_sim_start(self, sim: Simulator):
        """
        Called when the simulation starts. Discards all stored camera positions.
        """
        self._camera_bboxes.clear()

    def on_camera_frame(self, sim: Simulator):
        """
        Called when a camera frame is captured. Stores the current camera position.
        """
        self._camera_bboxes.append(sim.view.camera_position)

    def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
        """
        Look up the bounding boxes stored in the csv file for the requested frames.

        Args:
            frame_nums (Collection[int]): The frame numbers to look up. Must not be empty.
            relative (bool): If True, the x and y coordinates are shifted by the stored camera
                position. If False, the values are returned as stored in the csv file.

        Returns:
            np.ndarray: Array of shape (len(frame_nums), 4), one (x, y, w, h) row per frame.
                Rows of frames outside the range of the csv data are filled with NaN.
        """
        assert len(frame_nums) > 0

        frame_nums = np.asanyarray(frame_nums, dtype=int)
        valid_mask = (frame_nums >= 0) & (frame_nums < self._csv_data.shape[0])

        # Frames without csv data keep the NaN fill value.
        worm_bboxes = np.full((frame_nums.shape[0], 4), np.nan)
        worm_bboxes[valid_mask] = self._csv_data[frame_nums[valid_mask], :]

        if not relative:
            return worm_bboxes

        # TODO: if relative == True then it works only if the frame number is within the
        # last cycle, since only one cycle worth of camera positions is stored. Maybe fix that.
        cycle_length = self.timing_config.cycle_frame_num
        cam_bboxes = [self._camera_bboxes[n % cycle_length] for n in frame_nums]
        cam_bboxes = np.asanyarray(cam_bboxes, dtype=float)

        # make bbox relative to camera view
        worm_bboxes[:, 0] -= cam_bboxes[:, 0]
        worm_bboxes[:, 1] -= cam_bboxes[:, 1]

        return worm_bboxes

    def begin_movement_prediction(self, sim: Simulator) -> None:
        """
        Called when the movement prediction begins. Nothing needs to be done here.
        """
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Provide the movement vector which moves the bounding box center to the camera center.

        The bounding box of frame ``sim.frame_number - pred_frame_num`` is used.

        Returns:
            tuple[int, int]: The movement vector in format (dx, dy).
                (0, 0) is returned if the bounding box contains non-finite values.
        """
        bbox = self.predict([sim.frame_number - self.timing_config.pred_frame_num])
        bbox = bbox[0, :]

        # No usable bounding box for this frame - do not move.
        if not np.isfinite(bbox).all():
            return 0, 0

        center = BoxUtils.center(bbox)
        cam_center = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2

        dx = round(center[0] - cam_center[0])
        dy = round(center[1] - cam_center[1])

        return dx, dy

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        """
        Predict the bounding boxes of all the frames of cycle number ``sim.cycle_number - 1``.

        The frame range is clipped to the number of rows of the csv data.
        Note that an empty frame range trips the assertion in ``predict``.

        Returns:
            np.ndarray: The result of ``predict`` for the frames of that cycle.
        """
        start = (sim.cycle_number - 1) * self.timing_config.cycle_frame_num
        end = start + self.timing_config.cycle_frame_num
        end = min(end, len(self._csv_data))

        return self.predict(np.arange(start, end))

Ancestors (in MRO)

  • wtracker.sim.simulator.SimController
  • abc.ABC

Descendants

  • wtracker.sim.sim_controllers.mlp_controllers.MLPController
  • wtracker.sim.sim_controllers.optimal_controller.OptimalController
  • wtracker.sim.sim_controllers.polyfit_controller.PolyfitController

Methods

begin_movement_prediction

def begin_movement_prediction(
    self,
    sim: wtracker.sim.simulator.Simulator
) -> None

Called when the movement prediction begins.

View Source
    def begin_movement_prediction(self, sim: Simulator) -> None:
        """
        Called when the movement prediction begins.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_camera_frame

def on_camera_frame(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when a camera frame is captured. Happens every frame.

View Source
    def on_camera_frame(self, sim: Simulator):
        """
        Called when a camera frame is captured. Happens every frame.

        Appends the current camera position of the simulator view to the stored positions.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        position_history = self._camera_bboxes
        position_history.append(sim.view.camera_position)

on_cycle_end

def on_cycle_end(
    self,
    sim: 'Simulator'
)

Called when a cycle ends.

View Source
    def on_cycle_end(self, sim: Simulator):
        """
        Called when a cycle ends.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_cycle_start

def on_cycle_start(
    self,
    sim: 'Simulator'
)

Called when a new cycle starts.

View Source
    def on_cycle_start(self, sim: Simulator):
        """
        Called when a new cycle starts.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_imaging_end

def on_imaging_end(
    self,
    sim: 'Simulator'
)

Called when imaging phase ends.

View Source
    def on_imaging_end(self, sim: Simulator):
        """
        Called when imaging phase ends.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_imaging_start

def on_imaging_start(
    self,
    sim: 'Simulator'
)

Called when imaging phase starts.

View Source
    def on_imaging_start(self, sim: Simulator):
        """
        Called when imaging phase starts.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_micro_frame

def on_micro_frame(
    self,
    sim: 'Simulator'
)

Called when a micro frame is captured. Happens for every frame during the imaging phase.

View Source
    def on_micro_frame(self, sim: Simulator):
        """
        Called when a micro frame is captured. Happens for every frame during the imaging phase.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_movement_end

def on_movement_end(
    self,
    sim: 'Simulator'
)

Called when movement phase ends.

View Source
    def on_movement_end(self, sim: Simulator):
        """
        Called when movement phase ends.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_movement_start

def on_movement_start(
    self,
    sim: 'Simulator'
)

Called when movement phase starts.

View Source
    def on_movement_start(self, sim: Simulator):
        """
        Called when movement phase starts.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_sim_end

def on_sim_end(
    self,
    sim: 'Simulator'
)

Called when the simulation ends.

View Source
    def on_sim_end(self, sim: Simulator):
        """
        Called when the simulation ends.

        This implementation performs no work.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        return None

on_sim_start

def on_sim_start(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when the simulation starts.

View Source
    def on_sim_start(self, sim: Simulator):
        """
        Called when the simulation starts.

        Removes every stored camera position, so that nothing is carried over into the new run.

        Args:
            sim (Simulator): The simulator instance passed to the callback.
        """
        position_history = self._camera_bboxes
        position_history.clear()

predict

def predict(
    self,
    frame_nums: Collection[int],
    relative: bool = True
) -> numpy.ndarray
View Source
    def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
        """
        Look up the bounding boxes stored in the csv data for the requested frames.

        Args:
            frame_nums (Collection[int]): The frame numbers to look up. Must not be empty.
            relative (bool): If True, the first two columns of the result are shifted by the
                stored camera position. If False, the values are returned as stored.

        Returns:
            np.ndarray: Array of shape (len(frame_nums), 4), one row per requested frame.
                Rows of frames outside the range of the csv data are filled with NaN.
        """
        assert len(frame_nums) > 0

        frames = np.asanyarray(frame_nums, dtype=int)
        row_count = self._csv_data.shape[0]
        in_range = (frames >= 0) & (frames < row_count)

        # Start from an all-NaN table so that frames without data stay marked as missing.
        bboxes = np.full((frames.shape[0], 4), np.nan)
        bboxes[in_range] = self._csv_data[frames[in_range], :]

        if relative:
            # TODO: the relative lookup works only if the frame number is within the last cycle.
            # maybe fix that.
            cycle_length = self.timing_config.cycle_frame_num
            camera_positions = np.asanyarray(
                [self._camera_bboxes[frame % cycle_length] for frame in frames],
                dtype=float,
            )

            # Express the bbox position relative to the camera view.
            bboxes[:, 0] -= camera_positions[:, 0]
            bboxes[:, 1] -= camera_positions[:, 1]

        return bboxes

provide_movement_vector

def provide_movement_vector(
    self,
    sim: wtracker.sim.simulator.Simulator
) -> tuple[int, int]

Provides the movement vector for the simulator. The platform is moved by the provided vector.

Returns:

Type Description
tuple[int, int] The movement vector in format (dx, dy).
The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction.
View Source
    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Provides the movement vector for the simulator. The platform is moved by the provided vector.

        The vector is the rounded offset between the center of the bounding box predicted for
        frame ``sim.frame_number - pred_frame_num`` and the center of the camera view.

        Args:
            sim (Simulator): The simulator instance passed to the callback.

        Returns:
            tuple[int, int]: The movement vector in format (dx, dy).
                (0, 0) is returned if the predicted bounding box contains non-finite values.
        """
        target_frame = sim.frame_number - self.timing_config.pred_frame_num
        predicted = self.predict([target_frame])
        bbox = predicted[0, :]

        # A bbox with NaN or infinite entries gives no usable target - stay in place.
        if not np.isfinite(bbox).all():
            return 0, 0

        bbox_center = BoxUtils.center(bbox)
        half_width = sim.view.camera_size[0] / 2
        half_height = sim.view.camera_size[1] / 2

        dx = round(bbox_center[0] - half_width)
        dy = round(bbox_center[1] - half_height)
        return dx, dy