Module wtracker.sim.sim_controllers.yolo_controller
View Source
from typing import Collection, Any
from dataclasses import dataclass, field
import numpy as np
import cv2 as cv
from collections import deque
from ultralytics import YOLO
from wtracker.sim.simulator import Simulator, SimController
from wtracker.sim.config import TimingConfig
from wtracker.utils.config_base import ConfigBase
from wtracker.utils.bbox_utils import BoxUtils, BoxConverter, BoxFormat
@dataclass
class YoloConfig(ConfigBase):
    """Configuration for loading and running a YOLO worm-detection model."""

    model_path: str
    """The path to the pretrained YOLO weights file."""

    device: str = "cpu"
    """Inference device for YOLO. Can be either 'cpu' or 'cuda'."""

    verbose: bool = False
    """Whether to print verbose output during YOLO inference."""

    pred_kwargs: dict = field(
        default_factory=lambda: {
            "imgsz": 384,
            "conf": 0.1,
        }
    )
    """Additional keyword arguments for the YOLO prediction method."""

    model: YOLO = field(default=None, init=False, repr=False)
    """The YOLO model object."""

    def __getstate__(self) -> dict[str, Any]:
        """Return a picklable state dict with the model object excluded.

        The "model" slot is kept (as None) rather than deleted: the default
        ``__setstate__`` merely updates ``__dict__``, so deleting the key
        would leave a deserialized instance without a ``model`` attribute
        and make ``load_model`` raise AttributeError.
        """
        state = self.__dict__.copy()
        state["model"] = None  # we don't want to serialize the model itself
        return state

    def load_model(self) -> YOLO:
        """Lazily load the YOLO model from ``model_path`` and cache it."""
        if self.model is None:
            self.model = YOLO(self.model_path, task="detect", verbose=self.verbose)
        return self.model
class YoloController(SimController):
    """Simulator controller that localizes the worm with a YOLO detector.

    Camera frames are buffered during each cycle; the platform movement
    vector is computed from the bounding box YOLO predicts on a buffered
    frame, relative to the camera center.
    """

    def __init__(self, timing_config: TimingConfig, yolo_config: YoloConfig):
        """
        Args:
            timing_config (TimingConfig): The timing configuration of the simulator.
            yolo_config (YoloConfig): Configuration of the YOLO model to use.
        """
        super().__init__(timing_config)
        self.yolo_config = yolo_config
        # Bounded buffer: holds at most one cycle's worth of camera frames.
        self._camera_frames = deque(maxlen=timing_config.cycle_frame_num)
        self._model = yolo_config.load_model()

    def on_sim_start(self, sim: Simulator):
        self._camera_frames.clear()

    def on_camera_frame(self, sim: Simulator):
        self._camera_frames.append(sim.camera_view())

    def on_cycle_end(self, sim: Simulator):
        self._camera_frames.clear()

    def predict(self, frames: Collection[np.ndarray]) -> np.ndarray:
        """Detect the worm bounding box in each of the given frames.

        Args:
            frames (Collection[np.ndarray]): Non-empty collection of frames,
                grayscale (H, W) or 3-channel (H, W, C).

        Returns:
            np.ndarray: Array of shape (len(frames), 4) of bboxes in xywh
            format; a row is all-NaN when no detection was made in that frame.
        """
        assert len(frames) > 0
        # Materialize so that any Collection (not only indexable sequences)
        # is supported by the frames[0] probe below.
        frames = list(frames)

        # convert grayscale images to BGR because YOLO expects 3-channel images
        if frames[0].ndim == 2:
            frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames]

        # predict bounding boxes and format results
        results = self._model.predict(
            source=frames,
            device=self.yolo_config.device,
            max_det=1,  # at most one worm per frame
            verbose=self.yolo_config.verbose,
            **self.yolo_config.pred_kwargs,
        )
        results = [res.numpy() for res in results]

        bboxes = []
        for res in results:
            if len(res.boxes.xyxy) == 0:
                bboxes.append(np.full([4], np.nan))  # no detection in this frame
            else:
                bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY)
                bboxes.append(bbox)

        return np.stack(bboxes, axis=0)

    def begin_movement_prediction(self, sim: Simulator) -> None:
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """Compute the (dx, dy) platform movement from the latest prediction.

        Uses the frame captured ``pred_frame_num`` frames ago and moves the
        platform so the detected bbox center aligns with the camera center.
        Returns (0, 0) when no worm was detected (NaN bbox).
        """
        frame = self._camera_frames[-self.timing_config.pred_frame_num]
        bbox = self.predict([frame])
        bbox = bbox[0]

        if not np.isfinite(bbox).all():
            return 0, 0

        # bbox is xywh; midpoint is (x + w/2, y + h/2).
        bbox_mid = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
        camera_mid = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2

        return round(bbox_mid[0] - camera_mid[0]), round(bbox_mid[1] - camera_mid[1])

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        # Predict on every frame buffered during the current cycle.
        return self.predict(self._camera_frames)
Classes
YoloConfig
class YoloConfig(
model_path: str,
device: str = 'cpu',
verbose: bool = False,
pred_kwargs: dict = <factory>
)
YoloConfig(model_path: str, device: str = 'cpu', verbose: bool = False, pred_kwargs: dict = &lt;factory&gt;)
View Source
@dataclass
class YoloConfig(ConfigBase):
model_path: str
"""The path to the pretrained YOLO weights file."""
device: str = "cpu"
"""Inference device for YOLO. Can be either 'cpu' or 'cuda'."""
verbose: bool = False
"""Whether to print verbose output during YOLO inference."""
pred_kwargs: dict = field(
default_factory=lambda: {
"imgsz": 384,
"conf": 0.1,
}
)
"""Additional keyword arguments for the YOLO prediction method."""
model: YOLO = field(default=None, init=False, repr=False)
"""The YOLO model object."""
def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
del state["model"] # we dont want to serialize the model
return state
def load_model(self) -> YOLO:
if self.model is None:
self.model = YOLO(self.model_path, task="detect", verbose=self.verbose)
return self.model
Ancestors (in MRO)
- wtracker.utils.config_base.ConfigBase
Class variables
device
model
verbose
Static methods
load_json
def load_json(
path: 'str' = None
) -> 'T'
Load the class from a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the JSON file. | None |
Returns:
Type | Description |
---|---|
ConfigBase | The class loaded from the JSON file. |
View Source
@classmethod
def load_json(cls: type[T], path: str = None) -> T:
"""
Load the class from a JSON file.
Args:
path (str): The path to the JSON file.
Returns:
ConfigBase: The class loaded from the JSON file.
"""
if path is None:
path = UserPrompt.open_file(
title=f"Open {cls.__name__} File",
file_types=[("json", ".json")],
)
with open(path, "r") as f:
data = json.load(f)
obj = cls.__new__(cls)
obj.__dict__.update(data)
return obj
load_pickle
def load_pickle(
path: 'str' = None
) -> 'T'
Load the class from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the pickle file. | None |
Returns:
Type | Description |
---|---|
T | The class loaded from the pickle file. |
View Source
@classmethod
def load_pickle(cls: type[T], path: str = None) -> T:
"""
Load the class from a pickle file.
Args:
path (str): The path to the pickle file.
Returns:
The class loaded from the pickle file.
"""
if path is None:
path = UserPrompt.open_file(
title=f"Open {cls.__name__} File",
file_types=[("pickle", ".pkl")],
)
return pickle_load_object(path)
Methods
load_model
def load_model(
self
) -> ultralytics.models.yolo.model.YOLO
View Source
def load_model(self) -> YOLO:
if self.model is None:
self.model = YOLO(self.model_path, task="detect", verbose=self.verbose)
return self.model
save_json
def save_json(
self,
path: 'str' = None
)
Saves the class as JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output JSON file. | None |
View Source
def save_json(self, path: str = None):
"""
Saves the class as JSON file.
Args:
path (str): The path to the output JSON file.
"""
if path is None:
path = UserPrompt.save_file(
title=f"Save {type(self).__name__} As",
file_types=[("json", ".json")],
defaultextension=".json",
)
with open(path, "w") as f:
json.dump(self.__dict__, f, indent=4)
save_pickle
def save_pickle(
self,
path: 'str' = None
) -> 'None'
Saves the class as a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the output pickle file. | None |
View Source
def save_pickle(self, path: str = None) -> None:
"""
Saves the class as a pickle file.
Args:
path (str): The path to the output pickle file.
"""
if path is None:
path = UserPrompt.save_file(
title=f"Save {type(self).__name__} As",
file_types=[("pickle", ".pkl")],
defaultextension=".pkl",
)
pickle_save_object(self, path)
YoloController
class YoloController(
timing_config: wtracker.sim.config.TimingConfig,
yolo_config: wtracker.sim.sim_controllers.yolo_controller.YoloConfig
)
Abstract base class for simulator controllers.
Attributes
Name | Type | Description | Default |
---|---|---|---|
timing_config | TimingConfig | The timing configuration for the simulator. | None |
View Source
class YoloController(SimController):
def __init__(self, timing_config: TimingConfig, yolo_config: YoloConfig):
super().__init__(timing_config)
self.yolo_config = yolo_config
self._camera_frames = deque(maxlen=timing_config.cycle_frame_num)
self._model = yolo_config.load_model()
def on_sim_start(self, sim: Simulator):
self._camera_frames.clear()
def on_camera_frame(self, sim: Simulator):
self._camera_frames.append(sim.camera_view())
def on_cycle_end(self, sim: Simulator):
self._camera_frames.clear()
def predict(self, frames: Collection[np.ndarray]) -> np.ndarray:
assert len(frames) > 0
# convert grayscale images to BGR because YOLO expects 3-channel images
if frames[0].ndim == 2:
frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames]
# predict bounding boxes and format results
results = self._model.predict(
source=frames,
device=self.yolo_config.device,
max_det=1,
verbose=self.yolo_config.verbose,
**self.yolo_config.pred_kwargs,
)
results = [res.numpy() for res in results]
bboxes = []
for res in results:
if len(res.boxes.xyxy) == 0:
bboxes.append(np.full([4], np.nan))
else:
bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY)
bboxes.append(bbox)
return np.stack(bboxes, axis=0)
def begin_movement_prediction(self, sim: Simulator) -> None:
pass
def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
frame = self._camera_frames[-self.timing_config.pred_frame_num]
bbox = self.predict([frame])
bbox = bbox[0]
if not np.isfinite(bbox).all():
return 0, 0
bbox_mid = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
camera_mid = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2
return round(bbox_mid[0] - camera_mid[0]), round(bbox_mid[1] - camera_mid[1])
def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
return self.predict(self._camera_frames)
Ancestors (in MRO)
- wtracker.sim.simulator.SimController
- abc.ABC
Methods
begin_movement_prediction
def begin_movement_prediction(
self,
sim: wtracker.sim.simulator.Simulator
) -> None
Called when the movement prediction begins.
View Source
def begin_movement_prediction(self, sim: Simulator) -> None:
pass
on_camera_frame
def on_camera_frame(
self,
sim: wtracker.sim.simulator.Simulator
)
Called when a camera frame is captured. Happens every frame.
View Source
def on_camera_frame(self, sim: Simulator):
self._camera_frames.append(sim.camera_view())
on_cycle_end
def on_cycle_end(
self,
sim: wtracker.sim.simulator.Simulator
)
Called when a cycle ends.
View Source
def on_cycle_end(self, sim: Simulator):
self._camera_frames.clear()
on_cycle_start
def on_cycle_start(
self,
sim: 'Simulator'
)
Called when a new cycle starts.
View Source
def on_cycle_start(self, sim: Simulator):
"""
Called when a new cycle starts.
"""
pass
on_imaging_end
def on_imaging_end(
self,
sim: 'Simulator'
)
Called when imaging phase ends.
View Source
def on_imaging_end(self, sim: Simulator):
"""
Called when imaging phase ends.
"""
pass
on_imaging_start
def on_imaging_start(
self,
sim: 'Simulator'
)
Called when imaging phase starts.
View Source
def on_imaging_start(self, sim: Simulator):
"""
Called when imaging phase starts.
"""
pass
on_micro_frame
def on_micro_frame(
self,
sim: 'Simulator'
)
Called when a micro frame is captured. Happens for every frame during the imaging phase.
View Source
def on_micro_frame(self, sim: Simulator):
"""
Called when a micro frame is captured. Happens for every frame during the imaging phase.
"""
pass
on_movement_end
def on_movement_end(
self,
sim: 'Simulator'
)
Called when movement phase ends.
View Source
def on_movement_end(self, sim: Simulator):
"""
Called when movement phase ends.
"""
pass
on_movement_start
def on_movement_start(
self,
sim: 'Simulator'
)
Called when movement phase starts.
View Source
def on_movement_start(self, sim: Simulator):
"""
Called when movement phase starts.
"""
pass
on_sim_end
def on_sim_end(
self,
sim: 'Simulator'
)
Called when the simulation ends.
View Source
def on_sim_end(self, sim: Simulator):
"""
Called when the simulation ends.
"""
pass
on_sim_start
def on_sim_start(
self,
sim: wtracker.sim.simulator.Simulator
)
Called when the simulation starts.
View Source
def on_sim_start(self, sim: Simulator):
self._camera_frames.clear()
predict
def predict(
self,
frames: Collection[numpy.ndarray]
) -> numpy.ndarray
View Source
def predict(self, frames: Collection[np.ndarray]) -> np.ndarray:
assert len(frames) > 0
# convert grayscale images to BGR because YOLO expects 3-channel images
if frames[0].ndim == 2:
frames = [cv.cvtColor(frame, cv.COLOR_GRAY2BGR) for frame in frames]
# predict bounding boxes and format results
results = self._model.predict(
source=frames,
device=self.yolo_config.device,
max_det=1,
verbose=self.yolo_config.verbose,
**self.yolo_config.pred_kwargs,
)
results = [res.numpy() for res in results]
bboxes = []
for res in results:
if len(res.boxes.xyxy) == 0:
bboxes.append(np.full([4], np.nan))
else:
bbox = BoxConverter.to_xywh(res.boxes.xyxy[0], BoxFormat.XYXY)
bboxes.append(bbox)
return np.stack(bboxes, axis=0)
provide_movement_vector
def provide_movement_vector(
self,
sim: wtracker.sim.simulator.Simulator
) -> tuple[int, int]
Provides the movement vector for the simulator. The platform is moved by the provided vector.
Returns:
Type | Description |
---|---|
tuple[int, int] | The movement vector in format (dx, dy). The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction. |
View Source
def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
frame = self._camera_frames[-self.timing_config.pred_frame_num]
bbox = self.predict([frame])
bbox = bbox[0]
if not np.isfinite(bbox).all():
return 0, 0
bbox_mid = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
camera_mid = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2
return round(bbox_mid[0] - camera_mid[0]), round(bbox_mid[1] - camera_mid[1])