Skip to content

Module wtracker.sim.sim_controllers.yolo_controller

View Source
from typing import Collection, Any

from dataclasses import dataclass, field

import numpy as np

import cv2 as cv

from collections import deque

from ultralytics import YOLO

from wtracker.sim.simulator import Simulator, SimController

from wtracker.sim.config import TimingConfig

from wtracker.utils.config_base import ConfigBase

from wtracker.utils.bbox_utils import BoxUtils, BoxConverter, BoxFormat

@dataclass
class YoloConfig(ConfigBase):
    """Settings for YOLO inference, together with a lazily loaded model handle."""

    model_path: str
    """The path to the pretrained YOLO weights file."""

    device: str = "cpu"
    """Inference device for YOLO. Can be either 'cpu' or 'cuda'."""

    verbose: bool = False
    """Whether to print verbose output during YOLO inference."""

    pred_kwargs: dict = field(
        default_factory=lambda: {
            "imgsz": 384,
            "conf": 0.1,
        }
    )
    """Additional keyword arguments for the YOLO prediction method."""

    model: YOLO = field(default=None, init=False, repr=False)
    """The YOLO model object. Stays ``None`` until `load_model` is called."""

    def __getstate__(self) -> dict[str, Any]:
        """
        Return the instance state used for pickling, without the loaded model.

        Returns:
            dict[str, Any]: A shallow copy of the instance attributes minus ``model``.
        """
        state = self.__dict__.copy()
        # We don't want to serialize the model. pop() with a default is used instead
        # of `del` so that instances whose __dict__ has no "model" entry (for example
        # objects created through __new__ without running __init__) can still be pickled.
        state.pop("model", None)
        return state

    def load_model(self) -> YOLO:
        """
        Return the YOLO model, creating it from `model_path` on first use.

        Returns:
            YOLO: The cached model object; the same object is returned on later calls.
        """
        if self.model is None:
            self.model = YOLO(self.model_path, task="detect", verbose=self.verbose)
        return self.model

class YoloController(SimController):
    """
    Simulator controller that locates the target with a YOLO detector.

    Camera frames are buffered while a cycle runs. The movement vector is the offset
    between the center of the detected bounding box and the center of the camera view.
    """

    def __init__(self, timing_config: TimingConfig, yolo_config: YoloConfig):
        """
        Args:
            timing_config (TimingConfig): The timing configuration, forwarded to the base class.
            yolo_config (YoloConfig): The YOLO settings. Its model is loaded here.
        """
        super().__init__(timing_config)
        self.yolo_config = yolo_config
        # bounded buffer: never holds more than `cycle_frame_num` frames
        self._camera_frames = deque(maxlen=timing_config.cycle_frame_num)
        self._model = yolo_config.load_model()

    def on_sim_start(self, sim: Simulator):
        """Empty the frame buffer when the simulation starts."""
        self._camera_frames.clear()

    def on_camera_frame(self, sim: Simulator):
        """Store the current camera view in the frame buffer."""
        self._camera_frames.append(sim.camera_view())

    def on_cycle_end(self, sim: Simulator):
        """Empty the frame buffer when a cycle ends."""
        self._camera_frames.clear()

    def predict(self, frames: Collection[np.ndarray]) -> np.ndarray:
        """
        Run YOLO on the given frames and return one bounding box per frame.

        Args:
            frames (Collection[np.ndarray]): A non-empty collection of frames.
                2D frames are treated as grayscale and converted to BGR.

        Returns:
            np.ndarray: An array with one row per frame. Each row is the first detected
            box converted with ``BoxConverter.to_xywh``, or NaN values if nothing was detected.
        """
        assert len(frames) > 0

        # Always hand the model a plain list: callers may pass other collections
        # (`_cycle_predict_all` passes the deque buffer), and previously such a
        # collection was only turned into a list when the first frame was grayscale.
        # The grayscale check is done per frame because YOLO expects 3-channel images.
        frames = [
            cv.cvtColor(frame, cv.COLOR_GRAY2BGR) if frame.ndim == 2 else frame
            for frame in frames
        ]

        # predict bounding boxes and format results
        results = self._model.predict(
            source=frames,
            device=self.yolo_config.device,
            max_det=1,
            verbose=self.yolo_config.verbose,
            **self.yolo_config.pred_kwargs,
        )

        bboxes = []
        for res in results:
            boxes_xyxy = res.numpy().boxes.xyxy
            if len(boxes_xyxy) == 0:
                # no detection in this frame, mark it with NaN
                bboxes.append(np.full([4], np.nan))
            else:
                bboxes.append(BoxConverter.to_xywh(boxes_xyxy[0], BoxFormat.XYXY))

        return np.stack(bboxes, axis=0)

    def begin_movement_prediction(self, sim: Simulator) -> None:
        """Do nothing; the prediction is computed in `provide_movement_vector`."""
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Compute the movement vector from a buffered camera frame.

        Returns:
            tuple[int, int]: The rounded offset (dx, dy) from the camera center to the
            center of the detected box, or (0, 0) if the box contains non-finite values.
        """
        # NOTE(review): assumes the buffer holds at least `pred_frame_num` frames
        # and that `pred_frame_num` is >= 1 -- confirm against the simulator.
        frame = self._camera_frames[-self.timing_config.pred_frame_num]
        bbox = self.predict([frame])
        bbox = bbox[0]

        if not np.isfinite(bbox).all():
            return 0, 0

        # bbox is indexed as (x, y, w, h): the midpoint is the corner plus half the size
        bbox_mid = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
        camera_mid = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2

        return round(bbox_mid[0] - camera_mid[0]), round(bbox_mid[1] - camera_mid[1])

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        """Run `predict` on every frame currently held in the buffer."""
        return self.predict(self._camera_frames)

Classes

YoloConfig

class YoloConfig(
    model_path: str,
    device: str = 'cpu',
    verbose: bool = False,
    pred_kwargs: dict = <factory>
)

YoloConfig(model_path: str, device: str = 'cpu', verbose: bool = False, pred_kwargs: dict = <factory>)

View Source
@dataclass
class YoloConfig(ConfigBase):
    """Settings for YOLO inference, together with a lazily loaded model handle."""

    model_path: str
    """The path to the pretrained YOLO weights file."""

    device: str = "cpu"
    """Inference device for YOLO. Can be either 'cpu' or 'cuda'."""

    verbose: bool = False
    """Whether to print verbose output during YOLO inference."""

    pred_kwargs: dict = field(
        default_factory=lambda: {
            "imgsz": 384,
            "conf": 0.1,
        }
    )
    """Additional keyword arguments for the YOLO prediction method."""

    model: YOLO = field(default=None, init=False, repr=False)
    """The YOLO model object. Stays ``None`` until `load_model` is called."""

    def __getstate__(self) -> dict[str, Any]:
        """
        Return the instance state used for pickling, without the loaded model.

        Returns:
            dict[str, Any]: A shallow copy of the instance attributes minus ``model``.
        """
        state = self.__dict__.copy()
        # We don't want to serialize the model. pop() with a default is used instead
        # of `del` so that instances whose __dict__ has no "model" entry (for example
        # objects created through __new__ without running __init__) can still be pickled.
        state.pop("model", None)
        return state

    def load_model(self) -> YOLO:
        """
        Return the YOLO model, creating it from `model_path` on first use.

        Returns:
            YOLO: The cached model object; the same object is returned on later calls.
        """
        if self.model is None:
            self.model = YOLO(self.model_path, task="detect", verbose=self.verbose)
        return self.model

Ancestors (in MRO)

  • wtracker.utils.config_base.ConfigBase

Class variables

device
model
verbose

Static methods

load_json

def load_json(
    path: 'str' = None
) -> 'T'

Load the class from a JSON file.

Parameters:

Name Type Description Default
path str The path to the JSON file. None

Returns:

Type Description
ConfigBase The class loaded from the JSON file.
View Source
    @classmethod
    def load_json(cls: type[T], path: str = None) -> T:
        """
        Load the class from a JSON file.

        Args:
            path (str): The path to the JSON file. If None, the path is requested
                through ``UserPrompt.open_file``.

        Returns:
            ConfigBase: An instance of ``cls`` whose attributes are taken from the JSON file.
        """
        if path is None:
            json_filter = [("json", ".json")]
            path = UserPrompt.open_file(title=f"Open {cls.__name__} File", file_types=json_filter)

        with open(path, "r") as json_file:
            attributes = json.load(json_file)

        # The instance is created through __new__, so __init__ does not run;
        # the attributes come directly from the file contents.
        instance = cls.__new__(cls)
        vars(instance).update(attributes)
        return instance

load_pickle

def load_pickle(
    path: 'str' = None
) -> 'T'

Load the class from a pickle file.

Parameters:

Name Type Description Default
path str The path to the pickle file. None

Returns:

Type Description
ConfigBase The class loaded from the pickle file.
View Source
    @classmethod
    def load_pickle(cls: type[T], path: str = None) -> T:
        """
        Load the class from a pickle file.

        Args:
            path (str): The path to the pickle file. If None, the path is requested
                through ``UserPrompt.open_file``.

        Returns:
            The object returned by ``pickle_load_object`` for the given path.
        """
        if path is None:
            pickle_filter = [("pickle", ".pkl")]
            path = UserPrompt.open_file(title=f"Open {cls.__name__} File", file_types=pickle_filter)

        # NOTE(review): presumably this unpickles the file; only load files from trusted sources.
        loaded = pickle_load_object(path)
        return loaded

Methods

load_model

def load_model(
    self
) -> ultralytics.models.yolo.model.YOLO
View Source
    def load_model(self) -> YOLO:
        """
        Return the YOLO model, creating it from `model_path` on first use.

        Returns:
            YOLO: The cached model object; the same object is returned on later calls.
        """
        # fast path: the model has already been created
        if self.model is not None:
            return self.model

        self.model = YOLO(self.model_path, task="detect", verbose=self.verbose)
        return self.model

save_json

def save_json(
    self,
    path: 'str' = None
)

Saves the class as JSON file.

Parameters:

Name Type Description Default
path str The path to the output JSON file. None
View Source
    def save_json(self, path: str = None):
        """
        Saves the class as JSON file.

        The instance attributes are written as a JSON object with an indent of 4.

        Args:
            path (str): The path to the output JSON file. If None, the path is
                requested through ``UserPrompt.save_file``.
        """
        if path is None:
            json_filter = [("json", ".json")]
            path = UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=json_filter,
                defaultextension=".json",
            )

        with open(path, "w") as json_file:
            json.dump(vars(self), json_file, indent=4)

save_pickle

def save_pickle(
    self,
    path: 'str' = None
) -> 'None'

Saves the class as a pickle file.

Parameters:

Name Type Description Default
path str The path to the output pickle file. None
View Source
    def save_pickle(self, path: str = None) -> None:
        """
        Saves the class as a pickle file.

        Args:
            path (str): The path to the output pickle file. If None, the path is
                requested through ``UserPrompt.save_file``.
        """
        if path is None:
            pickle_filter = [("pickle", ".pkl")]
            path = UserPrompt.save_file(
                title=f"Save {type(self).__name__} As",
                file_types=pickle_filter,
                defaultextension=".pkl",
            )

        # the actual writing is delegated to the shared helper
        pickle_save_object(self, path)

YoloController

class YoloController(
    timing_config: wtracker.sim.config.TimingConfig,
    yolo_config: wtracker.sim.sim_controllers.yolo_controller.YoloConfig
)

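Simulator controller that uses a YOLO model to detect the target in camera frames and derives the platform movement from the detection. Inherits from the abstract base class for simulator controllers.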

Attributes

Name Type Description Default
timing_config TimingConfig The timing configuration for the simulator. None
View Source
class YoloController(SimController):
    """
    Simulator controller that locates the target with a YOLO detector.

    Camera frames are buffered while a cycle runs. The movement vector is the offset
    between the center of the detected bounding box and the center of the camera view.
    """

    def __init__(self, timing_config: TimingConfig, yolo_config: YoloConfig):
        """
        Args:
            timing_config (TimingConfig): The timing configuration, forwarded to the base class.
            yolo_config (YoloConfig): The YOLO settings. Its model is loaded here.
        """
        super().__init__(timing_config)
        self.yolo_config = yolo_config
        # bounded buffer: never holds more than `cycle_frame_num` frames
        self._camera_frames = deque(maxlen=timing_config.cycle_frame_num)
        self._model = yolo_config.load_model()

    def on_sim_start(self, sim: Simulator):
        """Empty the frame buffer when the simulation starts."""
        self._camera_frames.clear()

    def on_camera_frame(self, sim: Simulator):
        """Store the current camera view in the frame buffer."""
        self._camera_frames.append(sim.camera_view())

    def on_cycle_end(self, sim: Simulator):
        """Empty the frame buffer when a cycle ends."""
        self._camera_frames.clear()

    def predict(self, frames: Collection[np.ndarray]) -> np.ndarray:
        """
        Run YOLO on the given frames and return one bounding box per frame.

        Args:
            frames (Collection[np.ndarray]): A non-empty collection of frames.
                2D frames are treated as grayscale and converted to BGR.

        Returns:
            np.ndarray: An array with one row per frame. Each row is the first detected
            box converted with ``BoxConverter.to_xywh``, or NaN values if nothing was detected.
        """
        assert len(frames) > 0

        # Always hand the model a plain list: callers may pass other collections
        # (`_cycle_predict_all` passes the deque buffer), and previously such a
        # collection was only turned into a list when the first frame was grayscale.
        # The grayscale check is done per frame because YOLO expects 3-channel images.
        frames = [
            cv.cvtColor(frame, cv.COLOR_GRAY2BGR) if frame.ndim == 2 else frame
            for frame in frames
        ]

        # predict bounding boxes and format results
        results = self._model.predict(
            source=frames,
            device=self.yolo_config.device,
            max_det=1,
            verbose=self.yolo_config.verbose,
            **self.yolo_config.pred_kwargs,
        )

        bboxes = []
        for res in results:
            boxes_xyxy = res.numpy().boxes.xyxy
            if len(boxes_xyxy) == 0:
                # no detection in this frame, mark it with NaN
                bboxes.append(np.full([4], np.nan))
            else:
                bboxes.append(BoxConverter.to_xywh(boxes_xyxy[0], BoxFormat.XYXY))

        return np.stack(bboxes, axis=0)

    def begin_movement_prediction(self, sim: Simulator) -> None:
        """Do nothing; the prediction is computed in `provide_movement_vector`."""
        pass

    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Compute the movement vector from a buffered camera frame.

        Returns:
            tuple[int, int]: The rounded offset (dx, dy) from the camera center to the
            center of the detected box, or (0, 0) if the box contains non-finite values.
        """
        # NOTE(review): assumes the buffer holds at least `pred_frame_num` frames
        # and that `pred_frame_num` is >= 1 -- confirm against the simulator.
        frame = self._camera_frames[-self.timing_config.pred_frame_num]
        bbox = self.predict([frame])
        bbox = bbox[0]

        if not np.isfinite(bbox).all():
            return 0, 0

        # bbox is indexed as (x, y, w, h): the midpoint is the corner plus half the size
        bbox_mid = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
        camera_mid = sim.view.camera_size[0] / 2, sim.view.camera_size[1] / 2

        return round(bbox_mid[0] - camera_mid[0]), round(bbox_mid[1] - camera_mid[1])

    def _cycle_predict_all(self, sim: Simulator) -> np.ndarray:
        """Run `predict` on every frame currently held in the buffer."""
        return self.predict(self._camera_frames)

Ancestors (in MRO)

  • wtracker.sim.simulator.SimController
  • abc.ABC

Methods

begin_movement_prediction

def begin_movement_prediction(
    self,
    sim: wtracker.sim.simulator.Simulator
) -> None

Called when the movement prediction begins.

View Source
    def begin_movement_prediction(self, sim: Simulator) -> None:
        """Called when the movement prediction begins. This controller takes no action here."""

on_camera_frame

def on_camera_frame(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when a camera frame is captured. Happens every frame.

View Source
    def on_camera_frame(self, sim: Simulator):
        """Append the simulator's current camera view to the frame buffer."""
        current_view = sim.camera_view()
        self._camera_frames.append(current_view)

on_cycle_end

def on_cycle_end(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when a cycle ends.

View Source
    def on_cycle_end(self, sim: Simulator):
        """Discard all buffered camera frames when a cycle ends."""
        buffered_frames = self._camera_frames
        buffered_frames.clear()

on_cycle_start

def on_cycle_start(
    self,
    sim: 'Simulator'
)

Called when a new cycle starts.

View Source
    def on_cycle_start(self, sim: Simulator):
        """Called when a new cycle starts. The default implementation does nothing."""

on_imaging_end

def on_imaging_end(
    self,
    sim: 'Simulator'
)

Called when imaging phase ends.

View Source
    def on_imaging_end(self, sim: Simulator):
        """Called when imaging phase ends. The default implementation does nothing."""

on_imaging_start

def on_imaging_start(
    self,
    sim: 'Simulator'
)

Called when imaging phase starts.

View Source
    def on_imaging_start(self, sim: Simulator):
        """Called when imaging phase starts. The default implementation does nothing."""

on_micro_frame

def on_micro_frame(
    self,
    sim: 'Simulator'
)

Called when a micro frame is captured. Happens for every frame during the imaging phase.

View Source
    def on_micro_frame(self, sim: Simulator):
        """
        Called when a micro frame is captured. Happens for every frame during the imaging phase.

        The default implementation does nothing.
        """

on_movement_end

def on_movement_end(
    self,
    sim: 'Simulator'
)

Called when movement phase ends.

View Source
    def on_movement_end(self, sim: Simulator):
        """Called when movement phase ends. The default implementation does nothing."""

on_movement_start

def on_movement_start(
    self,
    sim: 'Simulator'
)

Called when movement phase starts.

View Source
    def on_movement_start(self, sim: Simulator):
        """Called when movement phase starts. The default implementation does nothing."""

on_sim_end

def on_sim_end(
    self,
    sim: 'Simulator'
)

Called when the simulation ends.

View Source
    def on_sim_end(self, sim: Simulator):
        """Called when the simulation ends. The default implementation does nothing."""

on_sim_start

def on_sim_start(
    self,
    sim: wtracker.sim.simulator.Simulator
)

Called when the simulation starts.

View Source
    def on_sim_start(self, sim: Simulator):
        """Discard all buffered camera frames when the simulation starts."""
        buffered_frames = self._camera_frames
        buffered_frames.clear()

predict

def predict(
    self,
    frames: Collection[numpy.ndarray]
) -> numpy.ndarray
View Source
    def predict(self, frames: Collection[np.ndarray]) -> np.ndarray:
        """
        Run YOLO on the given frames and return one bounding box per frame.

        Args:
            frames (Collection[np.ndarray]): A non-empty collection of frames.
                2D frames are treated as grayscale and converted to BGR.

        Returns:
            np.ndarray: An array with one row per frame. Each row is the first detected
            box converted with ``BoxConverter.to_xywh``, or NaN values if nothing was detected.
        """
        assert len(frames) > 0

        # Always hand the model a plain list: callers may pass other collections
        # (such as a deque), and previously such a collection was only turned into
        # a list when the first frame was grayscale. The grayscale check is done
        # per frame because YOLO expects 3-channel images.
        frames = [
            cv.cvtColor(frame, cv.COLOR_GRAY2BGR) if frame.ndim == 2 else frame
            for frame in frames
        ]

        # predict bounding boxes and format results
        results = self._model.predict(
            source=frames,
            device=self.yolo_config.device,
            max_det=1,
            verbose=self.yolo_config.verbose,
            **self.yolo_config.pred_kwargs,
        )

        bboxes = []
        for res in results:
            boxes_xyxy = res.numpy().boxes.xyxy
            if len(boxes_xyxy) == 0:
                # no detection in this frame, mark it with NaN
                bboxes.append(np.full([4], np.nan))
            else:
                bboxes.append(BoxConverter.to_xywh(boxes_xyxy[0], BoxFormat.XYXY))

        return np.stack(bboxes, axis=0)

provide_movement_vector

def provide_movement_vector(
    self,
    sim: wtracker.sim.simulator.Simulator
) -> tuple[int, int]

Provides the movement vector for the simulator. The platform is moved by the provided vector.

Returns:

Type Description
tuple[int, int] The movement vector in format (dx, dy).
The platform will be moved by dx pixels in the x-direction and dy pixels in the y-direction.
View Source
    def provide_movement_vector(self, sim: Simulator) -> tuple[int, int]:
        """
        Provides the movement vector for the simulator.

        The frame located `pred_frame_num` positions from the end of the frame buffer
        is passed to `predict`, and the resulting box is compared to the camera center.

        Returns:
            tuple[int, int]: The rounded offset (dx, dy) from the camera center to the
            center of the detected box, or (0, 0) if the box contains non-finite values.
        """
        lookback = self.timing_config.pred_frame_num
        box = self.predict([self._camera_frames[-lookback]])[0]

        # a box with NaN/inf entries means there is nothing to move toward
        if (~np.isfinite(box)).any():
            return 0, 0

        camera_size = sim.view.camera_size
        # box is indexed as (x, y, w, h): the midpoint is the corner plus half the size
        offset_x = (box[0] + box[2] / 2) - camera_size[0] / 2
        offset_y = (box[1] + box[3] / 2) - camera_size[1] / 2
        return round(offset_x), round(offset_y)