Module wtracker.eval.plotter
View Source
from __future__ import annotations
import pandas as pd
import seaborn as sns
from typing import Callable
class Plotter:
    """
    A class for plotting experiment log data.
    The experiment data was previously analyzed by the DataAnalyzer class.
    Supports analysis of multiple logs at once.

    Args:
        data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log.
            The dataframes themselves are not modified; the combined data stored in ``self.data`` gets an
            extra ``log_num`` column holding the index of each log in ``data_list``.
        plot_height (int, optional): The height of the plot.
        palette (str, optional): The color palette to use for the plots.

    Raises:
        ValueError: If ``data_list`` is empty.
    """

    # Maps each supported ``error_kind`` to the data column holding that error.
    _ERROR_COLUMNS = {
        "bbox": "bbox_error",
        "dist": "worm_deviation",
        "precise": "precise_error",
    }

    def __init__(
        self,
        data_list: list[pd.DataFrame],
        plot_height: int = 7,
        palette: str = "viridis",
    ) -> None:
        if not data_list:
            raise ValueError("data_list must contain at least one DataFrame")
        self.plot_height = plot_height
        self.palette = palette
        # `assign` works on a copy, so the caller's dataframes are left untouched.
        self.data = pd.concat(
            [data.assign(log_num=i) for i, data in enumerate(data_list)],
            ignore_index=True,
        )

    def _get_error_column(self, error_kind: str) -> str:
        """
        Map an error kind to the name of the data column holding that error.

        Args:
            error_kind (str): One of "bbox", "dist", or "precise".

        Returns:
            str: The column name.

        Raises:
            ValueError: If ``error_kind`` is not one of the supported kinds.
        """
        try:
            return self._ERROR_COLUMNS[error_kind]
        except KeyError:
            raise ValueError(f"Invalid error kind: {error_kind}") from None

    def plot_speed(
        self,
        log_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "hist",
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Plot the speed distribution of the worm.

        Args:
            log_wise (bool, optional): Whether to plot each log separately.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_distplot` function.
                For "hist" plots a KDE curve is overlaid unless ``kde=False`` is passed.

        Returns:
            sns.FacetGrid: The plot object.
        """
        if plot_kind == "hist":
            # `kde` is only understood by seaborn's histogram plot; passing it to the
            # "kde"/"ecdf" kinds is an invalid keyword. setdefault lets callers override it.
            kwargs.setdefault("kde", True)
        return self.create_distplot(
            x_col="wrm_speed",
            x_label="speed",
            title="Worm Speed Distribution",
            plot_kind=plot_kind,
            log_wise=log_wise,
            condition=condition,
            **kwargs,
        )

    def plot_error(
        self,
        error_kind: str = "bbox",
        log_wise: bool = False,
        cycle_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "hist",
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Plot the error distribution.

        Args:
            error_kind (str, optional): The kind of error to plot. Can be "bbox", "dist", or "precise".
            log_wise (bool, optional): Whether to plot each log separately.
            cycle_wise (bool, optional): If True, reduce the data to one value per cycle of each log
                (the maximum error within that cycle) before plotting. ``condition`` is then applied
                to this reduced data, which only holds the ``log_num``, ``cycle`` and error columns.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_distplot` function.

        Returns:
            sns.FacetGrid: The plot object.
        """
        error_col = self._get_error_column(error_kind)
        data = self.data
        if cycle_wise:
            data = data.groupby(["log_num", "cycle"])[error_col].max().reset_index()
        return self.create_distplot(
            x_col=error_col,
            x_label=f"{error_kind} error",
            title=f"{error_kind} Error Distribution",
            plot_kind=plot_kind,
            log_wise=log_wise,
            condition=condition,
            data=data,
            **kwargs,
        )

    def plot_cycle_error(
        self,
        error_kind: str = "bbox",
        log_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "boxen",
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Plot the error as a function of the cycle step.

        Args:
            error_kind (str, optional): The kind of error to plot. Can be "bbox", "dist", or "precise".
            log_wise (bool, optional): Whether to plot each log separately.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_catplot` function.

        Returns:
            sns.FacetGrid: The plot object (as returned by `Plotter.create_catplot`).
        """
        error_col = self._get_error_column(error_kind)
        return self.create_catplot(
            x_col="cycle_step",
            y_col=error_col,
            x_label="cycle step",
            y_label=f"{error_kind} error",
            title=f"{error_kind} error as function of cycle step",
            plot_kind=plot_kind,
            log_wise=log_wise,
            condition=condition,
            **kwargs,
        )

    def plot_speed_vs_error(
        self,
        error_kind: str = "bbox",
        cycle_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        kind: str = "hist",
        **kwargs,
    ) -> sns.JointGrid:
        """
        Plot the speed of the worm vs the error.

        Args:
            error_kind (str, optional): The kind of error to plot. Can be "bbox", "dist", or "precise".
            cycle_wise (bool, optional): If True, reduce the data to one point per cycle of each log
                (maximum error and mean speed within that cycle) before plotting. ``condition`` is then
                applied to this reduced data.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.

        Returns:
            sns.JointGrid: The plot object.
        """
        error_col = self._get_error_column(error_kind)
        data = self.data
        if cycle_wise:
            data = (
                data.groupby(["log_num", "cycle"])[[error_col, "wrm_speed"]]
                .aggregate({error_col: "max", "wrm_speed": "mean"})
                .reset_index()
            )
        return self.create_jointplot(
            x_col="wrm_speed",
            y_col=error_col,
            plot_kind=kind,
            x_label="speed",
            y_label=f"{error_kind} error",
            title=f"Speed vs {error_kind} Error",
            condition=condition,
            data=data,
            **kwargs,
        )

    def plot_trajectory(
        self,
        hue_col: str | None = "log_num",
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        **kwargs,
    ) -> sns.JointGrid:
        """
        Plot the trajectory of the worm.

        The marginal axes are removed and the y-axis is inverted.

        Args:
            hue_col (str, optional): The column to use for coloring the plot.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.
                ``alpha`` (default 1) and ``linewidth`` (default 0) may be overridden.

        Returns:
            sns.JointGrid: The plot object.
        """
        # setdefault lets callers override these without a duplicate-keyword TypeError.
        kwargs.setdefault("alpha", 1)
        kwargs.setdefault("linewidth", 0)
        plot = self.create_jointplot(
            x_col="wrm_center_x",
            y_col="wrm_center_y",
            x_label="X",
            y_label="Y",
            title="Worm Trajectory",
            hue_col=hue_col,
            plot_kind="scatter",
            condition=condition,
            **kwargs,
        )
        plot.ax_marg_x.remove()
        plot.ax_marg_y.remove()
        plot.ax_joint.invert_yaxis()
        return plot

    def plot_head_size(
        self,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "hist",
        **kwargs,
    ) -> sns.JointGrid:
        """
        Plot the size of the worm head (``wrm_w`` width vs ``wrm_h`` height).

        Args:
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.

        Returns:
            sns.JointGrid: The plot object.
        """
        return self.create_jointplot(
            x_col="wrm_w",
            y_col="wrm_h",
            x_label="width",
            y_label="height",
            title="Worm Head Size",
            plot_kind=plot_kind,
            condition=condition,
            **kwargs,
        )

    def create_distplot(
        self,
        x_col: str,
        y_col: str = None,
        hue_col: str = None,
        log_wise: bool = False,
        plot_kind: str = "hist",
        x_label: str = "",
        y_label: str = "",
        title: str | None = None,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
        data: pd.DataFrame = None,
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Create a distribution plot from the data.

        Processing order: ``transform`` is applied first, then ``condition``, then rows missing
        a value in any plotted column are dropped.

        Args:
            x_col (str): The column to plot on the x-axis.
            y_col (str, optional): The column to plot on the y-axis.
            hue_col (str, optional): The column to use for coloring the plot.
            log_wise (bool, optional): Whether to plot each log separately.
            plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            x_label (str, optional): The x-axis label.
            y_label (str, optional): The y-axis label.
            title (str, optional): The title of the plot.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
            data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
            **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function.

        Returns:
            sns.FacetGrid: The plot object.

        Raises:
            ValueError: If ``plot_kind`` is not one of the supported kinds.
        """
        valid_kinds = ("hist", "kde", "ecdf")
        if plot_kind not in valid_kinds:
            raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

        if data is None:
            data = self.data
        if transform is not None:
            data = transform(data)
        if condition is not None:
            data = data[condition(data)]

        facet_col = "log_num" if log_wise else None
        # Drop only rows missing a plotted value; a bare dropna() would also discard
        # rows that merely have NaNs in unrelated columns.
        plotted_cols = [c for c in (x_col, y_col, hue_col, facet_col) if c is not None]
        palette = self.palette if hue_col is not None else None

        plot = sns.displot(
            data=data.dropna(subset=plotted_cols),
            x=x_col,
            y=y_col,
            hue=hue_col,
            col=facet_col,
            kind=plot_kind,
            height=self.plot_height,
            palette=palette,
            **kwargs,
        )
        plot.set_xlabels(x_label.capitalize())
        plot.set_ylabels(y_label.capitalize())
        if title is not None:
            if log_wise:
                # `{col_name}` is filled in per facet by FacetGrid.set_titles.
                plot.set_titles(f"Log {{col_name}} :: {title.title()}")
            else:
                plot.figure.suptitle(title.title())
        plot.tight_layout()
        return plot

    def create_catplot(
        self,
        x_col: str,
        y_col: str = None,
        hue_col: str = None,
        log_wise: bool = False,
        plot_kind: str = "strip",
        x_label: str = "",
        y_label: str = "",
        title: str | None = None,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
        data: pd.DataFrame = None,
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Create a categorical plot from the data.

        Processing order: ``transform`` is applied first, then ``condition``, then rows missing
        a value in any plotted column are dropped.

        Args:
            x_col (str): The column to plot on the x-axis.
            y_col (str, optional): The column to plot on the y-axis.
            hue_col (str, optional): The column to use for coloring the plot.
            log_wise (bool, optional): Whether to plot each log separately.
            plot_kind (str, optional): The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count".
            x_label (str, optional): The x-axis label.
            y_label (str, optional): The y-axis label.
            title (str, optional): The title of the plot.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
            data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
            **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function.

        Returns:
            sns.FacetGrid: The plot object.

        Raises:
            ValueError: If ``plot_kind`` is not one of the supported kinds.
        """
        valid_kinds = ("strip", "box", "violin", "boxen", "bar", "count")
        if plot_kind not in valid_kinds:
            raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

        if data is None:
            data = self.data
        if transform is not None:
            data = transform(data)
        if condition is not None:
            data = data[condition(data)]

        facet_col = "log_num" if log_wise else None
        # Drop only rows missing a plotted value; a bare dropna() would also discard
        # rows that merely have NaNs in unrelated columns.
        plotted_cols = [c for c in (x_col, y_col, hue_col, facet_col) if c is not None]
        palette = self.palette if hue_col is not None else None

        plot = sns.catplot(
            data=data.dropna(subset=plotted_cols),
            x=x_col,
            y=y_col,
            hue=hue_col,
            col=facet_col,
            kind=plot_kind,
            height=self.plot_height,
            palette=palette,
            **kwargs,
        )
        plot.set_xlabels(x_label.capitalize())
        plot.set_ylabels(y_label.capitalize())
        if title is not None:
            if log_wise:
                # `{col_name}` is filled in per facet by FacetGrid.set_titles.
                plot.set_titles(f"Log {{col_name}} :: {title.title()}")
            else:
                plot.figure.suptitle(title.title())
        plot.tight_layout()
        return plot

    def create_jointplot(
        self,
        x_col: str,
        y_col: str,
        hue_col: str = None,
        plot_kind: str = "scatter",
        x_label: str = "",
        y_label: str = "",
        title: str | None = "",
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
        data: pd.DataFrame = None,
        **kwargs,
    ) -> sns.JointGrid:
        """
        Create a joint plot from the data.

        Processing order: ``transform`` is applied first, then ``condition``, then rows missing
        a value in any plotted column are dropped.

        Args:
            x_col (str): The column to plot on the x-axis.
            y_col (str): The column to plot on the y-axis.
            hue_col (str, optional): The column to use for coloring the plot.
            plot_kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
            x_label (str, optional): The x-axis label.
            y_label (str, optional): The y-axis label.
            title (str, optional): The title of the plot. If None, no title is set.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
            data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
            **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function.
                A ``marginal_kws`` dict is merged with (and overrides) the palette set here.

        Returns:
            sns.JointGrid: The plot object.

        Raises:
            ValueError: If ``plot_kind`` is not one of the supported kinds.
        """
        valid_kinds = ("scatter", "kde", "hist", "hex", "reg", "resid")
        if plot_kind not in valid_kinds:
            raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

        if data is None:
            data = self.data
        if transform is not None:
            data = transform(data)
        if condition is not None:
            data = data[condition(data)]

        # Drop only rows missing a plotted value; a bare dropna() would also discard
        # rows that merely have NaNs in unrelated columns.
        plotted_cols = [c for c in (x_col, y_col, hue_col) if c is not None]
        palette = self.palette if hue_col is not None else None
        # Merge a caller-supplied marginal_kws instead of clashing with ours.
        marginal_kws = {"palette": palette, **(kwargs.pop("marginal_kws", None) or {})}

        plot = sns.jointplot(
            data=data.dropna(subset=plotted_cols),
            x=x_col,
            y=y_col,
            hue=hue_col,
            kind=plot_kind,
            height=self.plot_height,
            palette=palette,
            marginal_kws=marginal_kws,
            **kwargs,
        )
        plot.set_axis_labels(x_label.capitalize(), y_label.capitalize())
        if title is not None:
            plot.figure.suptitle(title.title())
        plot.figure.tight_layout()
        return plot
Classes
Plotter
class Plotter(
data_list: 'list[pd.DataFrame]',
plot_height: 'int' = 7,
palette: 'str' = 'viridis'
)
A class for plotting experiment log data.
The experiment data was previously analyzed by the DataAnalyzer class. Supports analysis of multiple logs at once.
Parameters
Name | Type | Description | Default |
---|---|---|---|
data_list | list[pd.DataFrame] | A list of dataframes, each holding the data of a single experiment log. | required |
plot_height | int | The height of the plot. | 7 |
palette | str | The color palette to use for the plots. | "viridis" |
View Source
class Plotter:
    """
    A class for plotting experiment log data.
    The experiment data was previously analyzed by the DataAnalyzer class.
    Supports analysis of multiple logs at once.

    Args:
        data_list (list[pd.DataFrame]): A list of dataframes, each holding the data of a single experiment log.
            The dataframes themselves are not modified; the combined data stored in ``self.data`` gets an
            extra ``log_num`` column holding the index of each log in ``data_list``.
        plot_height (int, optional): The height of the plot.
        palette (str, optional): The color palette to use for the plots.

    Raises:
        ValueError: If ``data_list`` is empty.
    """

    # Maps each supported ``error_kind`` to the data column holding that error.
    _ERROR_COLUMNS = {
        "bbox": "bbox_error",
        "dist": "worm_deviation",
        "precise": "precise_error",
    }

    def __init__(
        self,
        data_list: list[pd.DataFrame],
        plot_height: int = 7,
        palette: str = "viridis",
    ) -> None:
        if not data_list:
            raise ValueError("data_list must contain at least one DataFrame")
        self.plot_height = plot_height
        self.palette = palette
        # `assign` works on a copy, so the caller's dataframes are left untouched.
        self.data = pd.concat(
            [data.assign(log_num=i) for i, data in enumerate(data_list)],
            ignore_index=True,
        )

    def _get_error_column(self, error_kind: str) -> str:
        """
        Map an error kind to the name of the data column holding that error.

        Args:
            error_kind (str): One of "bbox", "dist", or "precise".

        Returns:
            str: The column name.

        Raises:
            ValueError: If ``error_kind`` is not one of the supported kinds.
        """
        try:
            return self._ERROR_COLUMNS[error_kind]
        except KeyError:
            raise ValueError(f"Invalid error kind: {error_kind}") from None

    def plot_speed(
        self,
        log_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "hist",
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Plot the speed distribution of the worm.

        Args:
            log_wise (bool, optional): Whether to plot each log separately.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_distplot` function.
                For "hist" plots a KDE curve is overlaid unless ``kde=False`` is passed.

        Returns:
            sns.FacetGrid: The plot object.
        """
        if plot_kind == "hist":
            # `kde` is only understood by seaborn's histogram plot; passing it to the
            # "kde"/"ecdf" kinds is an invalid keyword. setdefault lets callers override it.
            kwargs.setdefault("kde", True)
        return self.create_distplot(
            x_col="wrm_speed",
            x_label="speed",
            title="Worm Speed Distribution",
            plot_kind=plot_kind,
            log_wise=log_wise,
            condition=condition,
            **kwargs,
        )

    def plot_error(
        self,
        error_kind: str = "bbox",
        log_wise: bool = False,
        cycle_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "hist",
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Plot the error distribution.

        Args:
            error_kind (str, optional): The kind of error to plot. Can be "bbox", "dist", or "precise".
            log_wise (bool, optional): Whether to plot each log separately.
            cycle_wise (bool, optional): If True, reduce the data to one value per cycle of each log
                (the maximum error within that cycle) before plotting. ``condition`` is then applied
                to this reduced data, which only holds the ``log_num``, ``cycle`` and error columns.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_distplot` function.

        Returns:
            sns.FacetGrid: The plot object.
        """
        error_col = self._get_error_column(error_kind)
        data = self.data
        if cycle_wise:
            data = data.groupby(["log_num", "cycle"])[error_col].max().reset_index()
        return self.create_distplot(
            x_col=error_col,
            x_label=f"{error_kind} error",
            title=f"{error_kind} Error Distribution",
            plot_kind=plot_kind,
            log_wise=log_wise,
            condition=condition,
            data=data,
            **kwargs,
        )

    def plot_cycle_error(
        self,
        error_kind: str = "bbox",
        log_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "boxen",
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Plot the error as a function of the cycle step.

        Args:
            error_kind (str, optional): The kind of error to plot. Can be "bbox", "dist", or "precise".
            log_wise (bool, optional): Whether to plot each log separately.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_catplot` function.

        Returns:
            sns.FacetGrid: The plot object (as returned by `Plotter.create_catplot`).
        """
        error_col = self._get_error_column(error_kind)
        return self.create_catplot(
            x_col="cycle_step",
            y_col=error_col,
            x_label="cycle step",
            y_label=f"{error_kind} error",
            title=f"{error_kind} error as function of cycle step",
            plot_kind=plot_kind,
            log_wise=log_wise,
            condition=condition,
            **kwargs,
        )

    def plot_speed_vs_error(
        self,
        error_kind: str = "bbox",
        cycle_wise: bool = False,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        kind: str = "hist",
        **kwargs,
    ) -> sns.JointGrid:
        """
        Plot the speed of the worm vs the error.

        Args:
            error_kind (str, optional): The kind of error to plot. Can be "bbox", "dist", or "precise".
            cycle_wise (bool, optional): If True, reduce the data to one point per cycle of each log
                (maximum error and mean speed within that cycle) before plotting. ``condition`` is then
                applied to this reduced data.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.

        Returns:
            sns.JointGrid: The plot object.
        """
        error_col = self._get_error_column(error_kind)
        data = self.data
        if cycle_wise:
            data = (
                data.groupby(["log_num", "cycle"])[[error_col, "wrm_speed"]]
                .aggregate({error_col: "max", "wrm_speed": "mean"})
                .reset_index()
            )
        return self.create_jointplot(
            x_col="wrm_speed",
            y_col=error_col,
            plot_kind=kind,
            x_label="speed",
            y_label=f"{error_kind} error",
            title=f"Speed vs {error_kind} Error",
            condition=condition,
            data=data,
            **kwargs,
        )

    def plot_trajectory(
        self,
        hue_col: str | None = "log_num",
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        **kwargs,
    ) -> sns.JointGrid:
        """
        Plot the trajectory of the worm.

        The marginal axes are removed and the y-axis is inverted.

        Args:
            hue_col (str, optional): The column to use for coloring the plot.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.
                ``alpha`` (default 1) and ``linewidth`` (default 0) may be overridden.

        Returns:
            sns.JointGrid: The plot object.
        """
        # setdefault lets callers override these without a duplicate-keyword TypeError.
        kwargs.setdefault("alpha", 1)
        kwargs.setdefault("linewidth", 0)
        plot = self.create_jointplot(
            x_col="wrm_center_x",
            y_col="wrm_center_y",
            x_label="X",
            y_label="Y",
            title="Worm Trajectory",
            hue_col=hue_col,
            plot_kind="scatter",
            condition=condition,
            **kwargs,
        )
        plot.ax_marg_x.remove()
        plot.ax_marg_y.remove()
        plot.ax_joint.invert_yaxis()
        return plot

    def plot_head_size(
        self,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        plot_kind: str = "hist",
        **kwargs,
    ) -> sns.JointGrid:
        """
        Plot the size of the worm head (``wrm_w`` width vs ``wrm_h`` height).

        Args:
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            plot_kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
            **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.

        Returns:
            sns.JointGrid: The plot object.
        """
        return self.create_jointplot(
            x_col="wrm_w",
            y_col="wrm_h",
            x_label="width",
            y_label="height",
            title="Worm Head Size",
            plot_kind=plot_kind,
            condition=condition,
            **kwargs,
        )

    def create_distplot(
        self,
        x_col: str,
        y_col: str = None,
        hue_col: str = None,
        log_wise: bool = False,
        plot_kind: str = "hist",
        x_label: str = "",
        y_label: str = "",
        title: str | None = None,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
        data: pd.DataFrame = None,
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Create a distribution plot from the data.

        Processing order: ``transform`` is applied first, then ``condition``, then rows missing
        a value in any plotted column are dropped.

        Args:
            x_col (str): The column to plot on the x-axis.
            y_col (str, optional): The column to plot on the y-axis.
            hue_col (str, optional): The column to use for coloring the plot.
            log_wise (bool, optional): Whether to plot each log separately.
            plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            x_label (str, optional): The x-axis label.
            y_label (str, optional): The y-axis label.
            title (str, optional): The title of the plot.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
            data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
            **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function.

        Returns:
            sns.FacetGrid: The plot object.

        Raises:
            ValueError: If ``plot_kind`` is not one of the supported kinds.
        """
        valid_kinds = ("hist", "kde", "ecdf")
        if plot_kind not in valid_kinds:
            raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

        if data is None:
            data = self.data
        if transform is not None:
            data = transform(data)
        if condition is not None:
            data = data[condition(data)]

        facet_col = "log_num" if log_wise else None
        # Drop only rows missing a plotted value; a bare dropna() would also discard
        # rows that merely have NaNs in unrelated columns.
        plotted_cols = [c for c in (x_col, y_col, hue_col, facet_col) if c is not None]
        palette = self.palette if hue_col is not None else None

        plot = sns.displot(
            data=data.dropna(subset=plotted_cols),
            x=x_col,
            y=y_col,
            hue=hue_col,
            col=facet_col,
            kind=plot_kind,
            height=self.plot_height,
            palette=palette,
            **kwargs,
        )
        plot.set_xlabels(x_label.capitalize())
        plot.set_ylabels(y_label.capitalize())
        if title is not None:
            if log_wise:
                # `{col_name}` is filled in per facet by FacetGrid.set_titles.
                plot.set_titles(f"Log {{col_name}} :: {title.title()}")
            else:
                plot.figure.suptitle(title.title())
        plot.tight_layout()
        return plot

    def create_catplot(
        self,
        x_col: str,
        y_col: str = None,
        hue_col: str = None,
        log_wise: bool = False,
        plot_kind: str = "strip",
        x_label: str = "",
        y_label: str = "",
        title: str | None = None,
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
        data: pd.DataFrame = None,
        **kwargs,
    ) -> sns.FacetGrid:
        """
        Create a categorical plot from the data.

        Processing order: ``transform`` is applied first, then ``condition``, then rows missing
        a value in any plotted column are dropped.

        Args:
            x_col (str): The column to plot on the x-axis.
            y_col (str, optional): The column to plot on the y-axis.
            hue_col (str, optional): The column to use for coloring the plot.
            log_wise (bool, optional): Whether to plot each log separately.
            plot_kind (str, optional): The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count".
            x_label (str, optional): The x-axis label.
            y_label (str, optional): The y-axis label.
            title (str, optional): The title of the plot.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
            data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
            **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function.

        Returns:
            sns.FacetGrid: The plot object.

        Raises:
            ValueError: If ``plot_kind`` is not one of the supported kinds.
        """
        valid_kinds = ("strip", "box", "violin", "boxen", "bar", "count")
        if plot_kind not in valid_kinds:
            raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

        if data is None:
            data = self.data
        if transform is not None:
            data = transform(data)
        if condition is not None:
            data = data[condition(data)]

        facet_col = "log_num" if log_wise else None
        # Drop only rows missing a plotted value; a bare dropna() would also discard
        # rows that merely have NaNs in unrelated columns.
        plotted_cols = [c for c in (x_col, y_col, hue_col, facet_col) if c is not None]
        palette = self.palette if hue_col is not None else None

        plot = sns.catplot(
            data=data.dropna(subset=plotted_cols),
            x=x_col,
            y=y_col,
            hue=hue_col,
            col=facet_col,
            kind=plot_kind,
            height=self.plot_height,
            palette=palette,
            **kwargs,
        )
        plot.set_xlabels(x_label.capitalize())
        plot.set_ylabels(y_label.capitalize())
        if title is not None:
            if log_wise:
                # `{col_name}` is filled in per facet by FacetGrid.set_titles.
                plot.set_titles(f"Log {{col_name}} :: {title.title()}")
            else:
                plot.figure.suptitle(title.title())
        plot.tight_layout()
        return plot

    def create_jointplot(
        self,
        x_col: str,
        y_col: str,
        hue_col: str = None,
        plot_kind: str = "scatter",
        x_label: str = "",
        y_label: str = "",
        title: str | None = "",
        condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
        transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
        data: pd.DataFrame = None,
        **kwargs,
    ) -> sns.JointGrid:
        """
        Create a joint plot from the data.

        Processing order: ``transform`` is applied first, then ``condition``, then rows missing
        a value in any plotted column are dropped.

        Args:
            x_col (str): The column to plot on the x-axis.
            y_col (str): The column to plot on the y-axis.
            hue_col (str, optional): The column to use for coloring the plot.
            plot_kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
            x_label (str, optional): The x-axis label.
            y_label (str, optional): The y-axis label.
            title (str, optional): The title of the plot. If None, no title is set.
            condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
                its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
            transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
            data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
            **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot` function.
                A ``marginal_kws`` dict is merged with (and overrides) the palette set here.

        Returns:
            sns.JointGrid: The plot object.

        Raises:
            ValueError: If ``plot_kind`` is not one of the supported kinds.
        """
        valid_kinds = ("scatter", "kde", "hist", "hex", "reg", "resid")
        if plot_kind not in valid_kinds:
            raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

        if data is None:
            data = self.data
        if transform is not None:
            data = transform(data)
        if condition is not None:
            data = data[condition(data)]

        # Drop only rows missing a plotted value; a bare dropna() would also discard
        # rows that merely have NaNs in unrelated columns.
        plotted_cols = [c for c in (x_col, y_col, hue_col) if c is not None]
        palette = self.palette if hue_col is not None else None
        # Merge a caller-supplied marginal_kws instead of clashing with ours.
        marginal_kws = {"palette": palette, **(kwargs.pop("marginal_kws", None) or {})}

        plot = sns.jointplot(
            data=data.dropna(subset=plotted_cols),
            x=x_col,
            y=y_col,
            hue=hue_col,
            kind=plot_kind,
            height=self.plot_height,
            palette=palette,
            marginal_kws=marginal_kws,
            **kwargs,
        )
        plot.set_axis_labels(x_label.capitalize(), y_label.capitalize())
        if title is not None:
            plot.figure.suptitle(title.title())
        plot.figure.tight_layout()
        return plot
Methods
create_catplot
def create_catplot(
self,
x_col: 'str',
y_col: 'str' = None,
hue_col: 'str' = None,
log_wise: 'bool' = False,
plot_kind: 'str' = 'strip',
x_label: 'str' = '',
y_label: 'str' = '',
title: 'str | None' = None,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
transform: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
data: 'pd.DataFrame' = None,
**kwargs
) -> 'sns.FacetGrid'
Create a categorical plot from the data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_col | str | The column to plot on the x-axis. | required |
y_col | str | The column to plot on the y-axis. | None |
hue_col | str | The column to use for coloring the plot. | None |
log_wise | bool | Whether to plot each log separately. | False |
plot_kind | str | The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count". | "strip" |
x_label | str | The x-axis label. | "" |
y_label | str | The y-axis label. | "" |
title | str \| None | The title of the plot. | None |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
transform | Callable[[pd.DataFrame], pd.DataFrame] | A function to transform the data. | None |
data | pd.DataFrame | Custom data to plot from. If None, the data passed to the constructor of the class is used. | None |
**kwargs | | Additional keyword arguments to pass to the seaborn.catplot function. | |
Returns:
Type | Description |
---|---|
sns.FacetGrid | The plot object. |
View Source
def create_catplot(
    self,
    x_col: str,
    y_col: str = None,
    hue_col: str = None,
    log_wise: bool = False,
    plot_kind: str = "strip",
    x_label: str = "",
    y_label: str = "",
    title: str | None = None,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
    data: pd.DataFrame = None,
    **kwargs,
) -> sns.FacetGrid:
    """
    Create a categorical plot from the data.

    Processing order: ``transform`` is applied first, then ``condition``, then rows missing
    a value in any plotted column are dropped.

    Args:
        x_col (str): The column to plot on the x-axis.
        y_col (str, optional): The column to plot on the y-axis.
        hue_col (str, optional): The column to use for coloring the plot.
        log_wise (bool, optional): Whether to plot each log separately.
        plot_kind (str, optional): The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count".
        x_label (str, optional): The x-axis label.
        y_label (str, optional): The y-axis label.
        title (str, optional): The title of the plot.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
            its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
        transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
        data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
        **kwargs: Additional keyword arguments to pass to the `seaborn.catplot` function.

    Returns:
        sns.FacetGrid: The plot object.

    Raises:
        ValueError: If ``plot_kind`` is not one of the supported kinds.
    """
    valid_kinds = ("strip", "box", "violin", "boxen", "bar", "count")
    if plot_kind not in valid_kinds:
        raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

    if data is None:
        data = self.data
    if transform is not None:
        data = transform(data)
    if condition is not None:
        data = data[condition(data)]

    facet_col = "log_num" if log_wise else None
    # Drop only rows missing a plotted value; a bare dropna() would also discard
    # rows that merely have NaNs in unrelated columns.
    plotted_cols = [c for c in (x_col, y_col, hue_col, facet_col) if c is not None]
    palette = self.palette if hue_col is not None else None

    plot = sns.catplot(
        data=data.dropna(subset=plotted_cols),
        x=x_col,
        y=y_col,
        hue=hue_col,
        col=facet_col,
        kind=plot_kind,
        height=self.plot_height,
        palette=palette,
        **kwargs,
    )
    plot.set_xlabels(x_label.capitalize())
    plot.set_ylabels(y_label.capitalize())
    if title is not None:
        if log_wise:
            # `{col_name}` is filled in per facet by FacetGrid.set_titles.
            plot.set_titles(f"Log {{col_name}} :: {title.title()}")
        else:
            plot.figure.suptitle(title.title())
    plot.tight_layout()
    return plot
create_distplot
def create_distplot(
self,
x_col: 'str',
y_col: 'str' = None,
hue_col: 'str' = None,
log_wise: 'bool' = False,
plot_kind: 'str' = 'hist',
x_label: 'str' = '',
y_label: 'str' = '',
title: 'str | None' = None,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
transform: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
data: 'pd.DataFrame' = None,
**kwargs
) -> 'sns.FacetGrid'
Create a distribution plot from the data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_col | str | The column to plot on the x-axis. | required |
y_col | str | The column to plot on the y-axis. | None |
hue_col | str | The column to use for coloring the plot. | None |
log_wise | bool | Whether to plot each log separately. | False |
plot_kind | str | The kind of plot to create. Can be "hist", "kde", or "ecdf". | "hist" |
x_label | str | The x-axis label. | "" |
y_label | str | The y-axis label. | "" |
title | str \| None | The title of the plot. | None |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
transform | Callable[[pd.DataFrame], pd.DataFrame] | A function to transform the data. | None |
data | pd.DataFrame | Custom data to plot from. If None, the data passed to the constructor of the class is used. | None |
**kwargs | | Additional keyword arguments to pass to the seaborn.displot function. | |
Returns:
Type | Description |
---|---|
sns.FacetGrid | The plot object. |
View Source
def create_distplot(
    self,
    x_col: str,
    y_col: str = None,
    hue_col: str = None,
    log_wise: bool = False,
    plot_kind: str = "hist",
    x_label: str = "",
    y_label: str = "",
    title: str | None = None,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
    data: pd.DataFrame = None,
    **kwargs,
) -> sns.FacetGrid:
    """
    Create a distribution plot from the data.

    Processing order: ``transform`` is applied first, then ``condition``, then rows missing
    a value in any plotted column are dropped.

    Args:
        x_col (str): The column to plot on the x-axis.
        y_col (str, optional): The column to plot on the y-axis.
        hue_col (str, optional): The column to use for coloring the plot.
        log_wise (bool, optional): Whether to plot each log separately.
        plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
        x_label (str, optional): The x-axis label.
        y_label (str, optional): The y-axis label.
        title (str, optional): The title of the plot.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data;
            its result indexes the data (``data[condition(data)]``), e.g. a boolean row mask.
        transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to transform the data.
        data (pd.DataFrame, optional): Custom data to plot from. If None, the data passed to the constructor of the class is used.
        **kwargs: Additional keyword arguments to pass to the `seaborn.displot` function.

    Returns:
        sns.FacetGrid: The plot object.

    Raises:
        ValueError: If ``plot_kind`` is not one of the supported kinds.
    """
    valid_kinds = ("hist", "kde", "ecdf")
    if plot_kind not in valid_kinds:
        raise ValueError(f"Invalid plot kind: {plot_kind}. Expected one of: {', '.join(valid_kinds)}")

    if data is None:
        data = self.data
    if transform is not None:
        data = transform(data)
    if condition is not None:
        data = data[condition(data)]

    facet_col = "log_num" if log_wise else None
    # Drop only rows missing a plotted value; a bare dropna() would also discard
    # rows that merely have NaNs in unrelated columns.
    plotted_cols = [c for c in (x_col, y_col, hue_col, facet_col) if c is not None]
    palette = self.palette if hue_col is not None else None

    plot = sns.displot(
        data=data.dropna(subset=plotted_cols),
        x=x_col,
        y=y_col,
        hue=hue_col,
        col=facet_col,
        kind=plot_kind,
        height=self.plot_height,
        palette=palette,
        **kwargs,
    )
    plot.set_xlabels(x_label.capitalize())
    plot.set_ylabels(y_label.capitalize())
    if title is not None:
        if log_wise:
            # `{col_name}` is filled in per facet by FacetGrid.set_titles.
            plot.set_titles(f"Log {{col_name}} :: {title.title()}")
        else:
            plot.figure.suptitle(title.title())
    plot.tight_layout()
    return plot
create_jointplot
def create_jointplot(
self,
x_col: 'str',
y_col: 'str',
hue_col: 'str' = None,
plot_kind: 'str' = 'scatter',
x_label: 'str' = '',
y_label: 'str' = '',
title: 'str' = '',
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
transform: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
data: 'pd.DataFrame' = None,
**kwargs
) -> 'sns.JointGrid'
Create a joint plot from the data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_col | str | The column to plot on the x-axis. | required |
y_col | str | The column to plot on the y-axis. | required |
hue_col | str | The column to use for coloring the plot. | None |
plot_kind | str | The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid". | 'scatter' |
x_label | str | The x-axis label. | '' |
y_label | str | The y-axis label. | '' |
title | str | The title of the plot. | '' |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
transform | Callable[[pd.DataFrame], pd.DataFrame] | A function to transform the data. | None |
data | pd.DataFrame | Custom data to plot from. If None, the data passed to the constructor of the class is used. | None |
**kwargs | | Additional keyword arguments to pass to the seaborn.jointplot function. | {} |
Returns:
Type | Description |
---|---|
sns.JointGrid | The plot object. |
View Source
def create_jointplot(
    self,
    x_col: str,
    y_col: str,
    hue_col: str = None,
    plot_kind: str = "scatter",
    x_label: str = "",
    y_label: str = "",
    title: str = "",
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    transform: Callable[[pd.DataFrame], pd.DataFrame] = None,
    data: pd.DataFrame = None,
    **kwargs,
) -> sns.JointGrid:
    """
    Create a joint plot from the data.

    The data is optionally transformed, then filtered, then rows with a missing
    value in any of the plotted columns are dropped before the result is drawn
    with `seaborn.jointplot`.

    Args:
        x_col (str): The column to plot on the x-axis.
        y_col (str): The column to plot on the y-axis.
        hue_col (str, optional): The column to use for coloring the plot.
        plot_kind (str, optional): The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid".
        x_label (str, optional): The x-axis label (capitalized before use).
        y_label (str, optional): The y-axis label (capitalized before use).
        title (str, optional): The title of the plot. None skips the title.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function
            returning a boolean mask over the (transformed) data; only rows where
            the mask is True are plotted.
        transform (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to
            transform the data. Applied before `condition`.
        data (pd.DataFrame, optional): Custom data to plot from. If None, the data
            passed to the constructor of the class is used.
        **kwargs: Additional keyword arguments to pass to the `seaborn.jointplot`
            function. `palette` and `height` override the instance defaults, and a
            `marginal_kws` dict is merged with the default marginal palette.

    Returns:
        sns.JointGrid: The plot object.

    Raises:
        ValueError: If `plot_kind` is not one of the supported kinds.
    """
    valid_kinds = ("scatter", "kde", "hist", "hex", "reg", "resid")
    if plot_kind not in valid_kinds:
        raise ValueError(f"Invalid plot kind: {plot_kind!r}, expected one of {valid_kinds}")

    if data is None:
        data = self.data
    if transform is not None:
        data = transform(data)
    if condition is not None:
        data = data[condition(data)]

    # Drop rows only where a *plotted* column is missing, not where some
    # unrelated column happens to be NaN.
    plotted_cols = [
        col
        for col in (x_col, y_col, hue_col)
        if isinstance(col, str) and col in data.columns
    ]
    if plotted_cols:
        data = data.dropna(subset=plotted_cols)

    # The palette only matters when there is a hue column. Popping from kwargs
    # lets callers override palette/height without a duplicate-keyword TypeError.
    palette = kwargs.pop("palette", self.palette if hue_col is not None else None)
    height = kwargs.pop("height", self.plot_height)
    # Merge a caller-supplied marginal_kws instead of passing the keyword twice;
    # caller entries win over the default palette.
    marginal_kws = {"palette": palette, **(kwargs.pop("marginal_kws", None) or {})}

    plot = sns.jointplot(
        data=data,
        x=x_col,
        y=y_col,
        hue=hue_col,
        kind=plot_kind,
        height=height,
        palette=palette,
        marginal_kws=marginal_kws,
        **kwargs,
    )
    plot.set_axis_labels(x_label.capitalize(), y_label.capitalize())
    if title is not None:
        plot.figure.suptitle(title.title())
    plot.figure.tight_layout()
    return plot
plot_cycle_error
def plot_cycle_error(
self,
error_kind: 'str' = 'bbox',
log_wise: 'bool' = False,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
plot_kind: 'str' = 'boxen',
**kwargs
) -> 'sns.JointGrid'
Plot the error as a function of the cycle step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error_kind | str | The kind of error to plot. Can be "bbox", "dist", or "precise". | 'bbox' |
log_wise | bool | Whether to plot each log separately. | False |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
plot_kind | str | The kind of plot to create. Can be "strip", "box", "violin", "boxen", "bar", or "count". | 'boxen' |
**kwargs | | Additional keyword arguments to pass to the Plotter.create_catplot function. | {} |
Returns:
Type | Description |
---|---|
sns.JointGrid | The plot object. |
View Source
def plot_cycle_error(
    self,
    error_kind: str = "bbox",
    log_wise: bool = False,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    plot_kind: str = "boxen",
    **kwargs,
) -> sns.JointGrid:
    """
    Plot how the tracking error behaves over the steps of a cycle.

    The error column selected by `error_kind` is drawn against the
    "cycle_step" column as a categorical plot via `Plotter.create_catplot`.

    Args:
        error_kind (str, optional): Which error to plot: "bbox", "dist", or "precise".
        log_wise (bool, optional): If True, each log gets its own subplot.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): Row filter applied to the data.
        plot_kind (str, optional): Categorical plot type: "strip", "box", "violin", "boxen", "bar", or "count".
        **kwargs: Forwarded unchanged to `Plotter.create_catplot`.

    Returns:
        The plot object produced by `Plotter.create_catplot`.
    """
    # NOTE(review): seaborn's catplot returns a FacetGrid, so the JointGrid
    # annotation above is probably inaccurate; confirm against create_catplot.
    axis_label = f"{error_kind} error"
    catplot_args = dict(
        x_col="cycle_step",
        y_col=self._get_error_column(error_kind),
        x_label="cycle step",
        y_label=axis_label,
        title=f"{error_kind} error as function of cycle step",
        plot_kind=plot_kind,
        log_wise=log_wise,
        condition=condition,
    )
    return self.create_catplot(**catplot_args, **kwargs)
plot_error
def plot_error(
self,
error_kind: 'str' = 'bbox',
log_wise: 'bool' = False,
cycle_wise: 'bool' = False,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
plot_kind: 'str' = 'hist',
**kwargs
) -> 'sns.FacetGrid'
Plot the error distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error_kind | str | The kind of error to plot. Can be "bbox", "dist", or "precise". | 'bbox' |
log_wise | bool | Whether to plot each log separately. | False |
cycle_wise | bool | Whether to reduce the data to one value per cycle (the cycle's maximum error) before plotting. | False |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
plot_kind | str | The kind of plot to create. Can be "hist", "kde", or "ecdf". | 'hist' |
**kwargs | | Additional keyword arguments to pass to the Plotter.create_distplot function. | {} |
Returns:
Type | Description |
---|---|
sns.FacetGrid | The plot object. |
View Source
def plot_error(
    self,
    error_kind: str = "bbox",
    log_wise: bool = False,
    cycle_wise: bool = False,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    plot_kind: str = "hist",
    **kwargs,
) -> sns.FacetGrid:
    """
    Plot the distribution of the tracking error.

    Args:
        error_kind (str, optional): Which error to plot: "bbox", "dist", or "precise".
        log_wise (bool, optional): If True, each log gets its own subplot.
        cycle_wise (bool, optional): If True, the data is first reduced to one
            value per (log, cycle): the maximum error within that cycle.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): Row filter
            applied to the (possibly per-cycle) data.
        plot_kind (str, optional): Distribution plot type: "hist", "kde", or "ecdf".
        **kwargs: Forwarded unchanged to `Plotter.create_distplot`.

    Returns:
        sns.FacetGrid: The plot object.
    """
    column = self._get_error_column(error_kind)
    if not cycle_wise:
        source = self.data
    else:
        # One row per (log_num, cycle), holding the largest error in that cycle.
        per_cycle = self.data.groupby(["log_num", "cycle"])[column]
        source = per_cycle.max().reset_index()
    return self.create_distplot(
        x_col=column,
        x_label=f"{error_kind} error",
        title=f"{error_kind} Error Distribution",
        plot_kind=plot_kind,
        log_wise=log_wise,
        condition=condition,
        data=source,
        **kwargs,
    )
plot_head_size
def plot_head_size(
self,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
plot_kind: 'str' = 'hist',
**kwargs
) -> 'sns.JointGrid'
Plot the size of the worm head.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
plot_kind | str | The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid". | 'hist' |
**kwargs | | Additional keyword arguments to pass to the Plotter.create_jointplot function. | {} |
Returns:
Type | Description |
---|---|
sns.JointGrid | The plot object. |
View Source
def plot_head_size(
    self,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    plot_kind: str = "hist",
    **kwargs,
) -> sns.JointGrid:
    """
    Plot the worm head's width ("wrm_w") against its height ("wrm_h").

    Args:
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): Row filter applied to the data.
        plot_kind (str, optional): Joint plot type: "scatter", "kde", "hist", "hex", "reg", or "resid".
        **kwargs: Forwarded unchanged to `Plotter.create_jointplot`.

    Returns:
        sns.JointGrid: The plot object.
    """
    size_axes = dict(x_col="wrm_w", y_col="wrm_h", x_label="width", y_label="height")
    return self.create_jointplot(
        **size_axes,
        title="Worm Head Size",
        plot_kind=plot_kind,
        condition=condition,
        **kwargs,
    )
plot_speed
def plot_speed(
self,
log_wise: 'bool' = False,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
plot_kind: 'str' = 'hist',
**kwargs
) -> 'sns.FacetGrid'
Plot the speed distribution of the worm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_wise | bool | Whether to plot each log separately. | False |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
plot_kind | str | The kind of plot to create. Can be "hist", "kde", or "ecdf". | 'hist' |
**kwargs | | Additional keyword arguments to pass to the Plotter.create_distplot function. | {} |
Returns:
Type | Description |
---|---|
sns.FacetGrid | The plot object. |
View Source
def plot_speed(
    self,
    log_wise: bool = False,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    plot_kind: str = "hist",
    **kwargs,
) -> sns.FacetGrid:
    """
    Plot the speed distribution of the worm (column "wrm_speed").

    Args:
        log_wise (bool, optional): Whether to plot each log separately.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data.
        plot_kind (str, optional): The kind of plot to create. Can be "hist", "kde", or "ecdf".
            Histograms get a KDE overlay by default (pass `kde=False` to disable it).
        **kwargs: Additional keyword arguments to pass to the `Plotter.create_distplot` function.

    Returns:
        sns.FacetGrid: The plot object.
    """
    # `kde` is a histogram-only option of seaborn.displot. For "kde"/"ecdf"
    # plots seaborn forwards extra keywords to the matplotlib artists, which
    # reject `kde`, so only add it for histograms. setdefault also lets a
    # caller-supplied `kde` win instead of raising a duplicate-keyword TypeError.
    if plot_kind == "hist":
        kwargs.setdefault("kde", True)
    return self.create_distplot(
        x_col="wrm_speed",
        x_label="speed",
        title="Worm Speed Distribution",
        plot_kind=plot_kind,
        log_wise=log_wise,
        condition=condition,
        **kwargs,
    )
plot_speed_vs_error
def plot_speed_vs_error(
self,
error_kind: 'str' = 'bbox',
cycle_wise: 'bool' = False,
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
kind: 'str' = 'hist',
**kwargs
) -> 'sns.JointGrid'
Plot the speed of the worm vs the error.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error_kind | str | The kind of error to plot. Can be "bbox", "dist", or "precise". | 'bbox' |
cycle_wise | bool | Whether to reduce the data to one point per cycle (the cycle's maximum error and mean speed) before plotting. | False |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
kind | str | The kind of plot to create. Can be "scatter", "kde", "hist", "hex", "reg", or "resid". | 'hist' |
**kwargs | | Additional keyword arguments to pass to the Plotter.create_jointplot function. | {} |
Returns:
Type | Description |
---|---|
sns.JointGrid | The plot object. |
View Source
def plot_speed_vs_error(
    self,
    error_kind: str = "bbox",
    cycle_wise: bool = False,
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    kind: str = "hist",
    **kwargs,
) -> sns.JointGrid:
    """
    Plot the worm's speed against the tracking error.

    Args:
        error_kind (str, optional): Which error to plot: "bbox", "dist", or "precise".
        cycle_wise (bool, optional): If True, the data is first reduced to one
            point per (log, cycle): the cycle's maximum error and mean speed.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): Row filter
            applied to the (possibly per-cycle) data.
        kind (str, optional): Joint plot type: "scatter", "kde", "hist", "hex", "reg", or "resid".
        **kwargs: Forwarded unchanged to `Plotter.create_jointplot`.

    Returns:
        sns.JointGrid: The plot object.
    """
    column = self._get_error_column(error_kind)
    frame = self.data
    if cycle_wise:
        # Per (log_num, cycle): worst (max) error and average speed.
        reducers = {column: "max", "wrm_speed": "mean"}
        grouped = self.data.groupby(["log_num", "cycle"])[list(reducers)]
        frame = grouped.agg(reducers).reset_index()
    return self.create_jointplot(
        x_col="wrm_speed",
        y_col=column,
        plot_kind=kind,
        x_label="speed",
        y_label=f"{error_kind} error",
        title=f"Speed vs {error_kind} Error",
        condition=condition,
        data=frame,
        **kwargs,
    )
plot_trajectory
def plot_trajectory(
self,
hue_col='log_num',
condition: 'Callable[[pd.DataFrame], pd.DataFrame]' = None,
**kwargs
) -> 'sns.JointGrid'
Plot the trajectory of the worm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hue_col | str | The column to use for coloring the plot. | 'log_num' |
condition | Callable[[pd.DataFrame], pd.DataFrame] | A function to filter the data. | None |
**kwargs | | Additional keyword arguments to pass to the Plotter.create_jointplot function. | {} |
Returns:
Type | Description |
---|---|
sns.JointGrid | The plot object. |
View Source
def plot_trajectory(
    self,
    hue_col: str | None = "log_num",
    condition: Callable[[pd.DataFrame], pd.DataFrame] = None,
    **kwargs,
) -> sns.JointGrid:
    """
    Plot the trajectory of the worm as a scatter of its center positions
    ("wrm_center_x", "wrm_center_y"), without marginal plots.

    Args:
        hue_col (str, optional): The column to use for coloring the plot.
        condition (Callable[[pd.DataFrame], pd.DataFrame], optional): A function to filter the data.
        **kwargs: Additional keyword arguments to pass to the `Plotter.create_jointplot` function.
            `alpha` (default 1) and `linewidth` (default 0) may be overridden here.

    Returns:
        sns.JointGrid: The plot object.
    """
    # Opaque, edge-less markers by default. setdefault lets callers override
    # either one without a duplicate-keyword TypeError.
    kwargs.setdefault("alpha", 1)
    kwargs.setdefault("linewidth", 0)
    plot = self.create_jointplot(
        x_col="wrm_center_x",
        y_col="wrm_center_y",
        x_label="X",
        y_label="Y",
        title="Worm Trajectory",
        hue_col=hue_col,
        plot_kind="scatter",
        condition=condition,
        **kwargs,
    )
    # Only the joint scatter is shown; remove the marginal axes.
    plot.ax_marg_x.remove()
    plot.ax_marg_y.remove()
    # NOTE(review): presumably the centers are in image coordinates (y grows
    # downward), so the y-axis is inverted to match the frame; confirm.
    plot.ax_joint.invert_yaxis()
    return plot