Module wtracker.sim.sim_controllers.csv_controller
View Source
from collections import deque
from typing import Collection
import pandas as pd
import numpy as np
from wtracker.sim.config import TimingConfig
from wtracker.sim.simulator import SimController, Simulator
from wtracker.utils.bbox_utils import BoxUtils, BoxConverter, BoxFormat
class CsvController(SimController):
    """
    Simulator controller that uses worm bounding boxes read from a CSV file.

    The CSV must contain the columns ``wrm_x``, ``wrm_y``, ``wrm_w`` and ``wrm_h``.
    Row ``i`` of those columns is used as the worm bounding box of frame ``i``.

    Args:
        timing_config: The timing configuration for the simulator.
        csv_path: Path to the CSV file holding the worm bounding boxes.
    """

    def __init__(self, timing_config: TimingConfig, csv_path: str):
        super().__init__(timing_config)
        self.csv_path = csv_path

        columns = ["wrm_x", "wrm_y", "wrm_w", "wrm_h"]
        # `usecols` ignores the order of its elements and keeps the file's column
        # order, so re-select the columns to guarantee (x, y, w, h) row layout.
        self._csv_data = pd.read_csv(self.csv_path, usecols=columns)[columns].to_numpy(dtype=float)

        # Camera positions (sim.view.camera_position) of the most recent
        # `cycle_frame_num` camera frames.
        self._camera_bboxes = deque(maxlen=timing_config.cycle_frame_num)

    def on_sim_start(self, sim: Simulator):
        """Forget any camera positions recorded by a previous run."""
        self._camera_bboxes.clear()

    def on_camera_frame(self, sim: Simulator):
        """Record the camera position of the current frame."""
        self._camera_bboxes.append(sim.view.camera_position)

    def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
        """
        Look up the worm bounding boxes stored in the CSV for the given frames.

        Args:
            frame_nums: Frame indices to look up. Must not be empty.
            relative: If True, the x and y coordinates are shifted by the camera
                position recorded for each frame. Otherwise the raw CSV values are returned.

        Returns:
            Array of shape (len(frame_nums), 4) with columns (x, y, w, h).
            Rows for frames outside the CSV data are NaN.
        """
        assert len(frame_nums) > 0
        frame_nums = np.asanyarray(frame_nums, dtype=int)

        valid_mask = (frame_nums >= 0) & (frame_nums < self._csv_data.shape[0])
        worm_bboxes = np.full((frame_nums.shape[0], 4), np.nan)
        worm_bboxes[valid_mask] = self._csv_data[frame_nums[valid_mask], :]

        if not relative:
            return worm_bboxes

        # Rows of invalid frames are NaN and would stay NaN after the shift, so only
        # valid frames need a camera position. This avoids an IndexError when the
        # camera history is shorter than the looked-up index.
        valid_frames = frame_nums[valid_mask]
        if valid_frames.size == 0:
            return worm_bboxes

        # TODO: camera positions are looked up as history[n % cycle_frame_num], which
        # is only correct while the history is aligned with cycle boundaries, i.e.
        # when the frame number is within the last cycle. Maybe fix that.
        cycle_len = self.timing_config.cycle_frame_num
        cam_bboxes = np.asanyarray([self._camera_bboxes[n % cycle_len] for n in valid_frames], dtype=float)

        # make bbox relative to camera view
        worm_bboxes[valid_mask, 0] -= cam_bboxes[:, 0]
        worm_bboxes[valid_mask, 1] -= cam_bboxes[:, 1]
        return worm_bboxes

    def begin_movement_prediction(self, sim: Simulator) -> None:
        """Nothing needs to be prepared before a movement prediction."""
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Return the platform movement that centers the worm in the camera view.

        Uses the camera-relative worm box of frame
        ``sim.frame_number - pred_frame_num``. Returns (0, 0) when that box has
        non-finite values.

        Returns:
            (dx, dy): offset in pixels between the box center and the view center.
        """
        bbox = self.predict([sim.frame_number - self.timing_config.pred_frame_num])
        bbox = bbox[0, :]
        if not np.isfinite(bbox).all():
            return 0, 0

        center = BoxUtils.center(bbox)
        cam_center = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2
        dx = round(center[0] - cam_center[0])
        dy = round(center[1] - cam_center[1])
        return dx, dy

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        """
        Predict worm boxes for frames [(cycle_number - 1) * cycle_frame_num,
        cycle_number * cycle_frame_num), clipped to the length of the CSV data.

        Returns an empty (0, 4) array when that range lies entirely past the end of
        the CSV data.
        """
        cycle_len = self.timing_config.cycle_frame_num
        start = (sim.cycle_number - 1) * cycle_len
        end = min(start + cycle_len, len(self._csv_data))
        if end <= start:
            # `predict` rejects empty input, so handle the empty range here.
            return np.empty((0, 4), dtype=float)
        return self.predict(np.arange(start, end))
Classes
CsvController
class CsvController(
timing_config: wtracker.sim.config.TimingConfig,
csv_path: str
)
Simulator controller that uses the worm bounding boxes stored in a CSV file (columns wrm_x, wrm_y, wrm_w, wrm_h) as its predictions.
Attributes
Name | Type | Description | Default |
---|---|---|---|
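timing_config | TimingConfig | The timing configuration for the simulator. | None |
csv_path | str | Path to the CSV file containing the worm bounding box columns wrm_x, wrm_y, wrm_w and wrm_h. | None |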
View Source
class CsvController(SimController):
    """
    Simulator controller that uses worm bounding boxes read from a CSV file.

    The CSV must contain the columns ``wrm_x``, ``wrm_y``, ``wrm_w`` and ``wrm_h``.
    Row ``i`` of those columns is used as the worm bounding box of frame ``i``.

    Args:
        timing_config: The timing configuration for the simulator.
        csv_path: Path to the CSV file holding the worm bounding boxes.
    """

    def __init__(self, timing_config: TimingConfig, csv_path: str):
        super().__init__(timing_config)
        self.csv_path = csv_path

        columns = ["wrm_x", "wrm_y", "wrm_w", "wrm_h"]
        # `usecols` ignores the order of its elements and keeps the file's column
        # order, so re-select the columns to guarantee (x, y, w, h) row layout.
        self._csv_data = pd.read_csv(self.csv_path, usecols=columns)[columns].to_numpy(dtype=float)

        # Camera positions (sim.view.camera_position) of the most recent
        # `cycle_frame_num` camera frames.
        self._camera_bboxes = deque(maxlen=timing_config.cycle_frame_num)

    def on_sim_start(self, sim: Simulator):
        """Forget any camera positions recorded by a previous run."""
        self._camera_bboxes.clear()

    def on_camera_frame(self, sim: Simulator):
        """Record the camera position of the current frame."""
        self._camera_bboxes.append(sim.view.camera_position)

    def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
        """
        Look up the worm bounding boxes stored in the CSV for the given frames.

        Args:
            frame_nums: Frame indices to look up. Must not be empty.
            relative: If True, the x and y coordinates are shifted by the camera
                position recorded for each frame. Otherwise the raw CSV values are returned.

        Returns:
            Array of shape (len(frame_nums), 4) with columns (x, y, w, h).
            Rows for frames outside the CSV data are NaN.
        """
        assert len(frame_nums) > 0
        frame_nums = np.asanyarray(frame_nums, dtype=int)

        valid_mask = (frame_nums >= 0) & (frame_nums < self._csv_data.shape[0])
        worm_bboxes = np.full((frame_nums.shape[0], 4), np.nan)
        worm_bboxes[valid_mask] = self._csv_data[frame_nums[valid_mask], :]

        if not relative:
            return worm_bboxes

        # Rows of invalid frames are NaN and would stay NaN after the shift, so only
        # valid frames need a camera position. This avoids an IndexError when the
        # camera history is shorter than the looked-up index.
        valid_frames = frame_nums[valid_mask]
        if valid_frames.size == 0:
            return worm_bboxes

        # TODO: camera positions are looked up as history[n % cycle_frame_num], which
        # is only correct while the history is aligned with cycle boundaries, i.e.
        # when the frame number is within the last cycle. Maybe fix that.
        cycle_len = self.timing_config.cycle_frame_num
        cam_bboxes = np.asanyarray([self._camera_bboxes[n % cycle_len] for n in valid_frames], dtype=float)

        # make bbox relative to camera view
        worm_bboxes[valid_mask, 0] -= cam_bboxes[:, 0]
        worm_bboxes[valid_mask, 1] -= cam_bboxes[:, 1]
        return worm_bboxes

    def begin_movement_prediction(self, sim: Simulator) -> None:
        """Nothing needs to be prepared before a movement prediction."""
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Return the platform movement that centers the worm in the camera view.

        Uses the camera-relative worm box of frame
        ``sim.frame_number - pred_frame_num``. Returns (0, 0) when that box has
        non-finite values.

        Returns:
            (dx, dy): offset in pixels between the box center and the view center.
        """
        bbox = self.predict([sim.frame_number - self.timing_config.pred_frame_num])
        bbox = bbox[0, :]
        if not np.isfinite(bbox).all():
            return 0, 0

        center = BoxUtils.center(bbox)
        cam_center = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2
        dx = round(center[0] - cam_center[0])
        dy = round(center[1] - cam_center[1])
        return dx, dy

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        """
        Predict worm boxes for frames [(cycle_number - 1) * cycle_frame_num,
        cycle_number * cycle_frame_num), clipped to the length of the CSV data.

        Returns an empty (0, 4) array when that range lies entirely past the end of
        the CSV data.
        """
        cycle_len = self.timing_config.cycle_frame_num
        start = (sim.cycle_number - 1) * cycle_len
        end = min(start + cycle_len, len(self._csv_data))
        if end <= start:
            # `predict` rejects empty input, so handle the empty range here.
            return np.empty((0, 4), dtype=float)
        return self.predict(np.arange(start, end))
Ancestors (in MRO)
- wtracker.sim.simulator.SimController
- abc.ABC
Descendants
- wtracker.sim.sim_controllers.mlp_controllers.MLPController
- wtracker.sim.sim_controllers.optimal_controller.OptimalController
- wtracker.sim.sim_controllers.polyfit_controller.PolyfitController
Methods
begin_movement_prediction
def begin_movement_prediction(
self,
sim: wtracker.sim.simulator.Simulator
) -> None
Called when the movement prediction begins.
View Source
def begin_movement_prediction(self, sim: Simulator) -> None:
    """Nothing needs to be prepared before a movement prediction for this controller."""
on_camera_frame
def on_camera_frame(
self,
sim: wtracker.sim.simulator.Simulator
)
Called when a camera frame is captured. Happens every frame.
View Source
def on_camera_frame(self, sim: Simulator):
    """Append the camera position of the current frame to the camera history."""
    current_position = sim.view.camera_position
    self._camera_bboxes.append(current_position)
on_cycle_end
def on_cycle_end(
self,
sim: 'Simulator'
)
Called when a cycle ends.
View Source
def on_cycle_end(self, sim: Simulator):
    """
    Hook invoked after a cycle has finished. The default implementation does nothing.
    """
on_cycle_start
def on_cycle_start(
self,
sim: 'Simulator'
)
Called when a new cycle starts.
View Source
def on_cycle_start(self, sim: Simulator):
    """
    Hook invoked when a new cycle begins. The default implementation does nothing.
    """
on_imaging_end
def on_imaging_end(
self,
sim: 'Simulator'
)
Called when imaging phase ends.
View Source
def on_imaging_end(self, sim: Simulator):
    """
    Hook invoked once the imaging phase is over. The default implementation does nothing.
    """
on_imaging_start
def on_imaging_start(
self,
sim: 'Simulator'
)
Called when imaging phase starts.
View Source
def on_imaging_start(self, sim: Simulator):
    """
    Hook invoked as the imaging phase begins. The default implementation does nothing.
    """
on_micro_frame
def on_micro_frame(
self,
sim: 'Simulator'
)
Called when a micro frame is captured. Happens for every micro frame captured during the imaging phase.
View Source
def on_micro_frame(self, sim: Simulator):
    """
    Hook invoked for every micro frame captured during the imaging phase.
    The default implementation does nothing.
    """
on_movement_end
def on_movement_end(
self,
sim: 'Simulator'
)
Called when movement phase ends.
View Source
def on_movement_end(self, sim: Simulator):
    """
    Hook invoked once the movement phase is over. The default implementation does nothing.
    """
on_movement_start
def on_movement_start(
self,
sim: 'Simulator'
)
Called when movement phase starts.
View Source
def on_movement_start(self, sim: Simulator):
    """
    Hook invoked as the movement phase begins. The default implementation does nothing.
    """
on_sim_end
def on_sim_end(
self,
sim: 'Simulator'
)
Called when the simulation ends.
View Source
def on_sim_end(self, sim: Simulator):
    """
    Hook invoked when the simulation has finished. The default implementation does nothing.
    """
on_sim_start
def on_sim_start(
self,
sim: wtracker.sim.simulator.Simulator
)
Called when the simulation starts.
View Source
def on_sim_start(self, sim: Simulator):
    """Discard every camera position recorded so far before a new simulation run."""
    history = self._camera_bboxes
    history.clear()
predict
def predict(
self,
frame_nums: Collection[int],
relative: bool = True
) -> numpy.ndarray
View Source
def predict(self, frame_nums: Collection[int], relative: bool = True) -> np.ndarray:
    """
    Look up the worm bounding boxes stored in the CSV for the given frames.

    Args:
        frame_nums: Frame indices to look up. Must not be empty.
        relative: If True, the x and y coordinates are shifted by the camera
            position recorded for each frame. Otherwise the raw CSV values are returned.

    Returns:
        Array of shape (len(frame_nums), 4), one row per requested frame in the
        column layout of ``self._csv_data``. Rows for frames outside the CSV data are NaN.
    """
    assert len(frame_nums) > 0
    frame_nums = np.asanyarray(frame_nums, dtype=int)

    valid_mask = (frame_nums >= 0) & (frame_nums < self._csv_data.shape[0])
    worm_bboxes = np.full((frame_nums.shape[0], 4), np.nan)
    worm_bboxes[valid_mask] = self._csv_data[frame_nums[valid_mask], :]

    if not relative:
        return worm_bboxes

    # Rows of invalid frames are NaN and would stay NaN after the shift, so only
    # valid frames need a camera position. This avoids an IndexError when the
    # camera history is shorter than the looked-up index.
    valid_frames = frame_nums[valid_mask]
    if valid_frames.size == 0:
        return worm_bboxes

    # TODO: camera positions are looked up as history[n % cycle_frame_num], which
    # is only correct while the history is aligned with cycle boundaries, i.e.
    # when the frame number is within the last cycle. Maybe fix that.
    cycle_len = self.timing_config.cycle_frame_num
    cam_bboxes = np.asanyarray([self._camera_bboxes[n % cycle_len] for n in valid_frames], dtype=float)

    # make bbox relative to camera view
    worm_bboxes[valid_mask, 0] -= cam_bboxes[:, 0]
    worm_bboxes[valid_mask, 1] -= cam_bboxes[:, 1]
    return worm_bboxes
provide_movement_vector
def provide_movement_vector(
self,
sim: wtracker.sim.simulator.Simulator
) -> tuple[int, int]
Provides the movement vector for the simulator. The platform is moved by the provided vector.
Returns:
Type | Description |
---|---|
tuple[int, int] | The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. |
View Source
def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
    """
    Compute the platform movement that brings the worm to the center of the view.

    The camera-relative worm box of frame ``sim.frame_number - pred_frame_num``
    is used. When that box contains a non-finite value, no movement is requested.

    Returns:
        tuple[int, int]: (dx, dy), the rounded pixel offset from the center of
        the camera view to the center of the worm box.
    """
    target_frame = sim.frame_number - self.timing_config.pred_frame_num
    worm_bbox = self.predict([target_frame])[0, :]

    if np.isfinite(worm_bbox).all():
        worm_center = BoxUtils.center(worm_bbox)
        half_width = sim.view.camera_size[0] / 2
        half_height = sim.view.camera_size[1] / 2
        dx = round(worm_center[0] - half_width)
        dy = round(worm_center[1] - half_height)
        return dx, dy

    return 0, 0