import os
import tempfile
import typing as t
from collections import OrderedDict
import dash_bootstrap_components as dbc
import numpy as np
import plotly.graph_objects as go
from anndata import AnnData
from Bio import Phylo
from dash import Dash, State, dcc, html
from dash.dependencies import Input, Output
from dash.development.base_component import Component
from plotly.express.colors import sample_colorscale
from cellarium.cas.logging import logger
from cellarium.cas.postprocessing import (
CAS_CL_SCORES_ANNDATA_OBSM_KEY,
CAS_METADATA_ANNDATA_UNS_KEY,
CellOntologyScoresAggregationDomain,
CellOntologyScoresAggregationOp,
convert_aggregated_cell_ontology_scores_to_rooted_tree,
generate_phyloxml_from_scored_cell_ontology_tree,
get_aggregated_cas_ontology_aware_scores,
get_obs_indices_for_cluster,
)
from cellarium.cas.postprocessing.cell_ontology import CL_CELL_ROOT_NODE, CellOntologyCache
from cellarium.cas.visualization._components.circular_tree_plot import CircularTreePlot
from cellarium.cas.visualization.ui_utils import ConfigValue, find_and_kill_process
# cell type ontology terms (and all descendents) to hide from the visualization
DEFAULT_HIDDEN_CL_NAMES_SET = {}
# header component ID -> default title mapping for the panels
DEFAULT_PANEL_TITLES = {
"cell-selection-title-tree": "Cell Type Ontology View",
"cell-selection-title-umap": "UMAP View",
}
class DomainSelectionConstants:
NONE = 0
USER_SELECTION = 1
SEPARATOR = 2
# cell type ontology terms to always show as text labels in the visualization
DEFAULT_SHOWN_CL_NAMES_SET = {
"CL_0000236",
"CL_0000084",
"CL_0000789",
"CL_0000798",
"CL_0002420",
"CL_0002419",
"CL_0000786",
"CL_0000576",
"CL_0001065",
"CL_0000451",
"CL_0000094",
"CL_0000235",
"CL_0000097",
"CL_0000814",
"CL_0000827",
"CL_0000066",
"CL_0000163",
"CL_0000151",
"CL_0000064",
"CL_0000322",
"CL_0000076",
"CL_0005006",
"CL_0000148",
"CL_0000646",
"CL_0009004",
"CL_0000115",
"CL_0000125",
"CL_0002319",
"CL_0000187",
"CL_0000057",
"CL_0008034",
"CL_0000092",
"CL_0000058",
"CL_0000060",
"CL_0000136",
"CL_0000499",
"CL_0000222",
"CL_0007005",
"CL_0000039",
"CL_0000019",
"CL_0000223",
"CL_0008019",
"CL_0005026",
"CL_0000182",
"CL_0000023",
"CL_0000679",
"CL_0000126",
"CL_0000540",
"CL_0000127",
"CL_0011005",
}
[docs]
class CASCircularTreePlotUMAPDashApp:
"""
A Dash app for visualizing the results of a Cellarium CAS cell type ontology-aware analysis.
:param adata: The AnnData object containing the cell type ontology-aware analysis results.
:param cas_ontology_aware_response: The response from the Cellarium CAS cell type ontology-aware analysis. |br|
`Default:` ``None``
:param cluster_label_obs_column: The name of the observation column containing the cluster labels. |br|
`Default:` ``None``
:param aggregation_op: The aggregation operation to apply to the cell type ontology-aware scores. |br|
`Default:` ``CellOntologyScoresAggregationOp.MEAN``
:param aggregation_domain: The domain over which to aggregate the cell type ontology-aware scores. |br|
`Default:` ``CellOntologyScoresAggregationDomain.OVER_THRESHOLD``
:param score_threshold: The threshold for the cell type ontology-aware scores. |br|
`Default:` ``0.05``
:param min_cell_fraction: The minimum fraction of cells that must have a cell type ontology-aware score above the threshold. |br|
`Default:` ``0.01``
:param umap_marker_size: The size of the markers in the UMAP scatter plot. |br|
`Default:` ``3.0``
:param umap_padding: The padding to apply to the UMAP scatter plot bounds. |br|
`Default:` ``0.15``
:param umap_min_opacity: The minimum opacity for the UMAP scatter plot markers. |br|
`Default:` ``0.1``
:param umap_max_opacity: The maximum opacity for the UMAP scatter plot markers. |br|
`Default:` ``1.0``
:param umap_inactive_cell_color: The color for inactive cells in the UMAP scatter plot. |br|
`Default:` ``"rgb(180,180,180)"``
:param umap_inactive_cell_opacity: The opacity for inactive cells in the UMAP scatter plot. |br|
`Default:` ``0.5``
:param umap_active_cell_color: The color for active cells in the UMAP scatter plot. |br|
`Default:` ``"rgb(250,50,50)"``
:param umap_default_cell_color: The default color for cells in the UMAP scatter plot. |br|
`Default:` ``"rgb(180,180,180)"``
:param umap_default_opacity: The default opacity for cells in the UMAP scatter plot. |br|
`Default:` ``0.9``
:param circular_tree_plot_linecolor: The line color for the circular tree plot. |br|
`Default:` ``"rgb(200,200,200)"``
:param circular_tree_start_angle: The start angle for the circular tree plot. |br|
`Default:` ``180``
:param circular_tree_end_angle: The end angle for the circular tree plot. |br|
`Default:` ``360``
:param figure_height: The height of the figures in the Dash app. |br|
`Default:` ``400``
:param hidden_cl_names_set: The set of cell type ontology terms to hide from the visualization. |br|
`Default:` ``DEFAULT_HIDDEN_CL_NAMES_SET``
:param shown_cl_names_set: The set of cell type ontology terms to always show as text labels in the
visualization. |br|
`Default:` ``DEFAULT_SHOWN_CL_NAMES_SET``
:param score_colorscale: The colorscale to use for the cell type ontology-aware scores. |br|
`Default:` ``"Viridis"``
Example:
________
>>> from cellarium.cas._io import suppress_stderr
>>> from cellarium.cas.visualization import CASCircularTreePlotUMAPDashApp
>>> DASH_SERVER_PORT = 8050
>>> adata = ... # get your matrix
>>> cas_ontology_aware_response = cas.annotate_matrix_cell_type_ontology_aware_strategy(
>>> matrix=adata,
>>> chunk_size=500
>>> )
>>> with suppress_stderr():
>>> CASCircularTreePlotUMAPDashApp(
>>> adata,
>>> cas_ontology_aware_response,
>>> cluster_label_obs_column="cluster_label",
>>> ).run(port=DASH_SERVER_PORT, debug=False, jupyter_width="100%")
"""
ALL_CELLS_DOMAIN_KEY = "all cells"
CLUSTER_PREFIX_DOMAIN_KEY = "cluster "
def __init__(
self,
adata: AnnData,
cluster_label_obs_column: t.Optional[str] = None,
aggregation_op: CellOntologyScoresAggregationOp = CellOntologyScoresAggregationOp.MEAN,
aggregation_domain: CellOntologyScoresAggregationDomain = CellOntologyScoresAggregationDomain.OVER_THRESHOLD,
score_threshold: float = 0.05,
min_cell_fraction: float = 0.01,
umap_marker_size: float = 3.0,
umap_padding: float = 0.15,
umap_min_opacity: float = 0.1,
umap_max_opacity: float = 1.0,
umap_inactive_cell_color: str = "rgb(180,180,180)",
umap_inactive_cell_opacity: float = 0.5,
umap_active_cell_color: str = "rgb(250,50,50)",
umap_default_cell_color: str = "rgb(180,180,180)",
umap_default_opacity: float = 0.9,
circular_tree_plot_linecolor: str = "rgb(200,200,200)",
circular_tree_start_angle: int = 180,
circular_tree_end_angle: int = 360,
figure_height: int = 400,
root_node: str = CL_CELL_ROOT_NODE,
hidden_cl_names_set: set[str] = DEFAULT_HIDDEN_CL_NAMES_SET,
shown_cl_names_set: set[str] = DEFAULT_SHOWN_CL_NAMES_SET,
score_colorscale: t.Union[str, list] = "Viridis",
):
self.adata = adata
self.aggregation_op = aggregation_op
self.aggregation_domain = aggregation_domain
self.score_threshold = ConfigValue(score_threshold)
self.min_cell_fraction = ConfigValue(min_cell_fraction)
self.umap_min_opacity = umap_min_opacity
self.umap_max_opacity = umap_max_opacity
self.umap_marker_size = umap_marker_size
self.umap_padding = umap_padding
self.umap_inactive_cell_color = umap_inactive_cell_color
self.umap_inactive_cell_opacity = umap_inactive_cell_opacity
self.umap_active_cell_color = umap_active_cell_color
self.umap_default_cell_color = umap_default_cell_color
self.umap_default_opacity = umap_default_opacity
self.circular_tree_plot_linecolor = circular_tree_plot_linecolor
self.circular_tree_start_angle = circular_tree_start_angle
self.circular_tree_end_angle = circular_tree_end_angle
self.height = figure_height
self.root_node = root_node
self.hidden_cl_names_set = hidden_cl_names_set
self.shown_cl_names_set = shown_cl_names_set
self.score_colorscale = score_colorscale
assert "X_umap" in adata.obsm, (
"UMAP coordinates not found in adata.obsm['X_umap']. "
"This visualisation requires precomputed UMAP coordinates."
)
assert (CAS_CL_SCORES_ANNDATA_OBSM_KEY in adata.obsm) and (CAS_METADATA_ANNDATA_UNS_KEY in adata.uns), (
"Cell type ontology scores not found in the provided AnnData file. Please please run "
"`cellarium.cas.insert_cas_ontology_aware_response_into_adata` prior to running this visualisation."
)
# setup cell domains
self.cell_domain_map = OrderedDict()
self.cell_domain_map[self.ALL_CELLS_DOMAIN_KEY] = np.arange(adata.n_obs)
if cluster_label_obs_column is not None:
assert cluster_label_obs_column in adata.obs
for cluster_label in adata.obs[cluster_label_obs_column].cat.categories:
self.cell_domain_map[self.CLUSTER_PREFIX_DOMAIN_KEY + cluster_label] = get_obs_indices_for_cluster(
adata, cluster_label_obs_column, cluster_label
)
# default cell domain
self.selected_cell_domain_key = ConfigValue(DomainSelectionConstants.NONE)
# Selected cells (from UMAP chart)
self.selected_cells = []
# Selected cell class (from tree diagram)
self.selected_cl_name = None
# instantiate the cell type ontology cache
self.cl = CellOntologyCache()
# instantiate the Dash app
self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP])
self.server = self.app.server
self.app.layout = self.__create_layout()
self.__setup_initialization()
self.__setup_callbacks()
[docs]
def run(self, port: int = 8050, jupyter_mode: str = "inline", **kwargs):
"""
Run the Dash application on the specified port.
:param port: The port on which to run the Dash application. |br|
`Default:` ``8050``
"""
logger.info(f"Starting Dash application on port {port}...")
try:
self.app.run_server(port=port, jupyter_mode=jupyter_mode, jupyter_height=self.height + 100, **kwargs)
except OSError: # Dash raises OSError if the port is already in use
find_and_kill_process(port)
self.app.run_server(port=port, jupyter_mode=jupyter_mode, jupyter_height=self.height + 100, **kwargs)
def __instantiate_circular_tree_plot(self) -> CircularTreePlot:
# reduce scores over the provided cells
selected_cells = self.__get_effective_selected_cells()
aggregated_scores = get_aggregated_cas_ontology_aware_scores(
self.adata,
obs_indices=(
self.cell_domain_map[self.ALL_CELLS_DOMAIN_KEY] if len(selected_cells) == 0 else selected_cells
),
aggregation_op=self.aggregation_op,
aggregation_domain=self.aggregation_domain,
threshold=self.score_threshold.get(),
)
# generate a Phylo tree
rooted_tree = convert_aggregated_cell_ontology_scores_to_rooted_tree(
aggregated_scores=aggregated_scores,
cl=self.cl,
root_cl_name=self.root_node,
min_fraction=self.min_cell_fraction.get(),
hidden_cl_names_set=self.hidden_cl_names_set,
)
phyloxml_string = generate_phyloxml_from_scored_cell_ontology_tree(
rooted_tree, "Scored cell type ontology tree", self.cl, indent=3
)
with tempfile.NamedTemporaryFile(delete=False, mode="w+t") as temp_file:
temp_file_name = temp_file.name
temp_file.write(phyloxml_string)
temp_file.flush()
try:
phyloxml_tree = Phylo.read(temp_file_name, "phyloxml")
finally:
os.remove(temp_file_name)
return CircularTreePlot(
tree=phyloxml_tree,
score_colorscale=self.score_colorscale,
linecolor=self.circular_tree_plot_linecolor,
start_angle=self.circular_tree_start_angle,
end_angle=self.circular_tree_end_angle,
shown_cl_names_set=self.shown_cl_names_set,
)
def __get_padded_umap_bounds(self, umap_padding: float) -> t.Tuple[float, float, float, float]:
actual_min_x = np.min(self.adata.obsm["X_umap"][:, 0])
actual_max_x = np.max(self.adata.obsm["X_umap"][:, 0])
actual_min_y = np.min(self.adata.obsm["X_umap"][:, 1])
actual_max_y = np.max(self.adata.obsm["X_umap"][:, 1])
padded_min_x = actual_min_x - umap_padding * (actual_max_x - actual_min_x)
padded_max_x = actual_max_x + umap_padding * (actual_max_x - actual_min_x)
padded_min_y = actual_min_y - umap_padding * (actual_max_y - actual_min_y)
padded_max_y = actual_max_y + umap_padding * (actual_max_y - actual_min_y)
return padded_min_x, padded_max_x, padded_min_y, padded_max_y
def __get_scores_for_cl_name(self, cl_name: str) -> np.ndarray:
cl_index = self.cl.cl_names_to_idx_map[cl_name]
return self.adata.obsm[CAS_CL_SCORES_ANNDATA_OBSM_KEY][:, cl_index].toarray().flatten()
def __get_scatter_plot_opacity_from_scores(self, scores: np.ndarray) -> np.ndarray:
min_score = np.min(scores)
max_score = np.max(scores)
normalized_scores = (scores - min_score) / (1e-6 + max_score - min_score)
return np.maximum(
scores, self.umap_min_opacity + (self.umap_max_opacity - self.umap_min_opacity) * normalized_scores
)
def __create_layout(self):
layout = html.Div(
[
dbc.Row(dbc.Col(className="gr-spacer", width=12)),
dbc.Row(
dbc.Col(
[
html.H3(self.__render_breadcrumb(), id="selected-domain-label", className="gr-breadcrumb"),
html.Div(
[
dbc.ButtonGroup(
[
dbc.Button(
html.I(className="bi bi-gear-fill"),
id="settings-button",
n_clicks=0,
size="sm",
),
]
)
],
className="gr-settings-buttons",
),
],
className="gr-title",
width=12,
)
),
dbc.Row(
self.__render_cell_selection_panes(),
id="panel-titles",
),
dbc.Row(
[
dbc.Col(
html.Div(
[
dcc.Graph(
id="circular-tree-plot",
style={
"width": "100%",
"display": "inline-block",
"height": f"{self.height}px",
},
config={"scrollZoom": True},
),
]
),
width=6,
),
dbc.Col(
html.Div(
[
dcc.Graph(
id="umap-scatter-plot",
style={
"width": "100%",
"display": "inline-block",
"height": f"{self.height - 10}px",
},
# Zoom is very choppy with this enabled. Users should use selection zoom
config={"scrollZoom": False},
),
]
),
width=6,
),
]
),
dbc.Offcanvas(
id="settings-pane", title="Settings", is_open=False, children=self.__render_closed_settings_pane()
),
html.Div(id="init", style={"display": "none"}),
html.Div(id="no-action", style={"display": "none"}),
],
)
return layout
def __initialize_umap_scatter_plot(self) -> go.Figure:
# calculate static bounds for the UMAP scatter plot
self.umap_min_x, self.umap_max_x, self.umap_min_y, self.umap_max_y = self.__get_padded_umap_bounds(
self.umap_padding
)
fig = go.Figure()
color = self.umap_default_cell_color
selected_cells = self.__get_effective_selected_cells()
if len(selected_cells) > 0:
color = [self.umap_inactive_cell_color] * self.adata.n_obs
for i_obs in selected_cells:
color[i_obs] = self.umap_active_cell_color
fig.add_trace(
go.Scatter(
x=self.adata.obsm["X_umap"][:, 0],
y=self.adata.obsm["X_umap"][:, 1],
mode="markers",
marker=dict(
color=color,
size=self.umap_marker_size,
opacity=self.umap_default_opacity,
),
unselected=dict(marker=dict(opacity=self.umap_default_opacity)),
)
)
self.__update_umap_scatter_plot_layout(fig)
self._umap_scatter_plot_figure = fig
return fig
def __initialize_circular_tree_plot(self) -> go.Figure:
self.circular_tree_plot = self.__instantiate_circular_tree_plot()
fig = self.circular_tree_plot.plotly_figure
return fig
def __setup_initialization(self):
@self.app.callback(Output("umap-scatter-plot", "figure"), Input("init", "children"))
def __initialize_umap_scatter_plot(init):
return self.__initialize_umap_scatter_plot()
@self.app.callback(Output("circular-tree-plot", "figure"), Input("init", "children"))
def __initialize_circular_tree_plot(init):
return self.__initialize_circular_tree_plot()
def __update_umap_scatter_plot_layout(self, umap_scatter_plot_fig):
umap_scatter_plot_fig.update_layout(
# uirevision is needed to maintain pan/zoom state. It must be updated to trigger a refresh
uirevision="true",
plot_bgcolor="white",
margin=dict(l=0, r=25, t=50, b=0),
xaxis=dict(
title="UMAP 1",
showgrid=False,
zeroline=False, # Keep zero line enabled
zerolinecolor="black",
range=[self.umap_min_x, self.umap_max_x], # Set x-axis limits
showline=True, # Show axis line
linecolor="black",
linewidth=1,
tickmode="linear",
tick0=-10,
dtick=5,
),
yaxis=dict(
title="UMAP 2",
showgrid=False,
zeroline=False, # Keep zero line enabled
zerolinecolor="black",
range=[self.umap_min_y, self.umap_max_y], # Set y-axis limits
showline=True, # Show axis line
linecolor="black",
linewidth=1,
tickmode="linear",
tick0=-10,
dtick=5,
),
dragmode="pan",
)
def __render_breadcrumb(self) -> Component:
selected_cells = self.__get_effective_selected_cells()
if len(selected_cells) == 0 and self.selected_cell_domain_key.get() == DomainSelectionConstants.NONE:
label = "Viewing results for all cells"
show_clear = False
elif (
len(selected_cells) == 1 and self.selected_cell_domain_key.get() == DomainSelectionConstants.USER_SELECTION
):
label = f"Selected cell index {selected_cells[0]}"
show_clear = True
elif len(selected_cells) > 1 and self.selected_cell_domain_key.get() == DomainSelectionConstants.USER_SELECTION:
label = f"Selected {len(selected_cells)} cells"
show_clear = True
else:
modifier = "cell" if len(selected_cells) == 1 else "cells"
label = f"Selected cell domain {self.selected_cell_domain_key.get()} ({len(selected_cells)} {modifier})"
show_clear = True
children = [html.B(label, className="gr-breadcrumb-label")]
if show_clear:
children.append(
html.Div(
[html.I(className="bi bi-x-circle")],
id="reset-selection-button",
n_clicks=0,
className="btn btn-link",
title="Clear selection",
)
)
return html.Div(children)
def __render_cell_selection_panes(self) -> Component:
return [
dbc.Col(self.__render_cell_selection_title(panel_id="cell-selection-title-tree"), width=6),
dbc.Col(self.__render_cell_selection_title(panel_id="cell-selection-title-umap"), width=6),
]
def __render_cell_selection_title(self, panel_id: str) -> Component:
if self.selected_cl_name is not None:
title = f"Selected cell class: {self.cl.cl_names_to_labels_map[self.selected_cl_name]}"
else:
title = DEFAULT_PANEL_TITLES[panel_id]
return [html.Div(title, className="gr-header", id=panel_id)]
def __render_closed_settings_pane(self) -> Component:
return [
html.Div(
[
html.Label("Cell selection:", style={"margin-bottom": "5px"}),
self.__render_domain_dropdown(),
],
className="gr-form-item",
),
html.Div(
[
dbc.Label("Evidence threshold:", html_for="evidence-threshold"),
dcc.Slider(
id="evidence-threshold",
min=0,
max=1,
value=self.score_threshold.get(dirty_read=True),
marks={
0: "0",
0.25: "0.25",
0.5: "0.5",
0.75: "0.75",
1: "1",
},
tooltip={"placement": "bottom", "always_visible": True, "style": {"margin": "0 5px"}},
),
],
className="gr-form-item",
),
html.Div(
[
dbc.Label("Minimum cell fraction:", html_for="cell-fraction"),
dcc.Slider(
id="cell-fraction",
min=0,
max=1,
value=self.min_cell_fraction.get(dirty_read=True),
marks={
0: "0",
0.25: "0.25",
0.5: "0.5",
0.75: "0.75",
1: "1",
},
tooltip={"placement": "bottom", "always_visible": True, "style": {"margin": "0 5px"}},
),
],
className="gr-form-item",
),
html.Div(
[
dbc.Button(
"Cancel",
id="cancel-button",
title="Cancel the changes and close the settings pane",
n_clicks=0,
),
dbc.Button(
"Update",
id="update-button",
title="Update the graphs based on the specified configuration",
n_clicks=0,
),
],
className="gr-settings-button-bar",
),
html.A(
html.Img(
src="assets/cellarium-powered-400px.png",
),
href="https://cellarium.ai",
className="gr-powered-by",
target="_blank",
),
]
def __render_domain_dropdown(self) -> Component:
labels = [{"label": "None selected", "value": DomainSelectionConstants.NONE}]
if len(self.selected_cells) > 0:
labels.append({"label": "User selection", "value": DomainSelectionConstants.USER_SELECTION})
if len(self.cell_domain_map.keys()) > 1:
labels.append({"label": "________________", "value": None, "disabled": True})
labels.append({"label": html.Span("Provided domains"), "value": None, "disabled": True})
for k in list(self.cell_domain_map.keys())[1:]:
labels.append({"label": k, "value": k})
return dcc.Dropdown(
id="domain-dropdown",
options=labels,
value=self.selected_cell_domain_key.get(), # default to no selection
className="gr-custom-dropdown",
clearable=False,
)
def __get_effective_selected_cells(self) -> list:
# User has chosen not to show any highlighted cells
if self.selected_cell_domain_key.get() == DomainSelectionConstants.NONE:
return []
# User has chose to highlight explicitly selected cells
if self.selected_cell_domain_key.get() == DomainSelectionConstants.USER_SELECTION:
return self.selected_cells
# User has chose to highlight pre-calculated domain cells
if self.selected_cell_domain_key.get() is not None:
return self.cell_domain_map[self.selected_cell_domain_key.get()]
def __clear_cell_selection(self):
self.selected_cells = []
self.selected_cell_domain_key.reset()
def __clear_cell_class_selection(self):
self.selected_cl_name = None
self.circular_tree_plot.update_selected_nodes(selected_cl_path=[])
def __setup_callbacks(self) -> None:
# Cell selection callbacks
@self.app.callback(
Output("circular-tree-plot", "figure", allow_duplicate=True),
Output("umap-scatter-plot", "figure", allow_duplicate=True),
Output("panel-titles", "children", allow_duplicate=True),
Input("circular-tree-plot", "clickData"),
prevent_initial_call=True,
)
def __update_umap_scatter_plot_based_on_circular_tree_plot(clickData):
if clickData is None or "points" not in clickData:
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_cell_selection_panes(),
)
point = clickData["points"][0]
if "pointIndex" not in point:
return (
self._umap_scatter_plot_figure,
self._umap_scatter_plot_figure,
self.__render_cell_selection_panes(),
)
node_index = point["pointIndex"]
selected_cl_name = self.circular_tree_plot.clade_index_to_cl_name_map.get(node_index)
if selected_cl_name is not None:
# Toggle selection if the selected cell class was clicked
if selected_cl_name == self.selected_cl_name:
self.__clear_cell_class_selection()
else:
selected_cl_path = self.circular_tree_plot.get_clade_path_from_index(selected_cl_idx=node_index)
self.circular_tree_plot.update_selected_nodes(selected_cl_path)
self.selected_cl_name = selected_cl_name
selected_cells_set = set(self.__get_effective_selected_cells())
if self.selected_cl_name is None:
# Optimizing for lower code complexity over performance by always treating color and opacity as arrays in this case
color = [self.umap_default_cell_color] * self.adata.n_obs
opacity = [self.umap_default_opacity] * self.adata.n_obs
else:
scores = self.__get_scores_for_cl_name(self.selected_cl_name)
opacity = self.__get_scatter_plot_opacity_from_scores(scores)
color = sample_colorscale(self.circular_tree_plot.score_colorscale, scores)
selected_cells_set = set(self.__get_effective_selected_cells())
# if no cells are selected but a cell class is, highlight all cells
if len(selected_cells_set) == 0 and self.selected_cl_name is not None:
selected_cells_set = set(self.cell_domain_map[self.ALL_CELLS_DOMAIN_KEY])
for i_obs in range(self.adata.n_obs):
if i_obs not in selected_cells_set:
color[i_obs] = self.umap_inactive_cell_color
opacity[i_obs] = self.umap_inactive_cell_opacity
self._umap_scatter_plot_figure.update_traces(
marker=dict(
color=color,
colorscale=self.circular_tree_plot.score_colorscale,
opacity=opacity,
cmin=0.0,
cmax=1.0,
),
text=[f"{score:.5f}" for score in scores] if self.selected_cl_name is not None else None,
hovertemplate=(
"<b>Evidence score: %{text}</b><extra></extra>" if self.selected_cl_name is not None else None
),
)
self._umap_scatter_plot_figure.update_layout(
plot_bgcolor="white",
margin=dict(l=0, r=25, t=50, b=0),
# uirevision is needed to maintain pan/zoom state. It must be updated to trigger a refresh
uirevision=self._umap_scatter_plot_figure["layout"]["uirevision"],
)
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_cell_selection_panes(),
)
@self.app.callback(
Output("circular-tree-plot", "figure", allow_duplicate=True),
Output("umap-scatter-plot", "figure", allow_duplicate=True),
Output("selected-domain-label", "children", allow_duplicate=True),
Output("panel-titles", "children", allow_duplicate=True),
Input("umap-scatter-plot", "clickData"),
State("umap-scatter-plot", "selectedData"),
prevent_initial_call=True,
)
def __update_circular_tree_plot_based_on_umap_scatter_plot(clickData, selectedData):
self._umap_scatter_plot_figure.update_selections()
if clickData is None or "points" not in clickData:
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_breadcrumb(),
self.__render_cell_selection_panes(),
)
point = clickData["points"][0]
if "pointIndex" not in point:
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_breadcrumb(),
self.__render_cell_selection_panes(),
)
node_index = point["pointIndex"]
self.selected_cells = [node_index]
self.selected_cell_domain_key.set(DomainSelectionConstants.USER_SELECTION).commit()
self.__clear_cell_class_selection()
self.__initialize_circular_tree_plot()
self.__initialize_umap_scatter_plot()
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_breadcrumb(),
self.__render_cell_selection_panes(),
)
@self.app.callback(
Output("circular-tree-plot", "figure", allow_duplicate=True),
Output("umap-scatter-plot", "figure", allow_duplicate=True),
Output("umap-scatter-plot", "selectedData", allow_duplicate=True),
Output("selected-domain-label", "children", allow_duplicate=True),
Output("panel-titles", "children", allow_duplicate=True),
Input("umap-scatter-plot", "selectedData"),
prevent_initial_call=True,
)
def __update_circular_tree_plot_based_on_umap_scatter_plot_select(selectedData):
# A selection event is firing on initialization. Ignore it by only accepting selectedData with a range field or lasso field
if (
selectedData is None
or "points" not in selectedData
or ("range" not in selectedData and "lassoPoints" not in selectedData)
):
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
None,
self.__render_breadcrumb(),
self.__render_cell_selection_panes(),
)
points = selectedData["points"]
node_indexes = [point["pointIndex"] for point in points]
self.selected_cells = node_indexes
self.selected_cell_domain_key.set(DomainSelectionConstants.USER_SELECTION).commit()
self.__clear_cell_class_selection()
self.__initialize_circular_tree_plot()
self.__initialize_umap_scatter_plot()
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
selectedData,
self.__render_breadcrumb(),
self.__render_cell_selection_panes(),
)
@self.app.callback(
Output("circular-tree-plot", "figure", allow_duplicate=True),
Output("umap-scatter-plot", "figure", allow_duplicate=True),
Output("selected-domain-label", "children", allow_duplicate=True),
Output("settings-pane", "children", allow_duplicate=True),
Output("panel-titles", "children", allow_duplicate=True),
Input("reset-selection-button", "n_clicks"),
prevent_initial_call=True,
)
def __reset_selection(n_clicks):
if n_clicks != 0:
self.__clear_cell_selection()
self.__clear_cell_class_selection()
# update the figures
self.__initialize_circular_tree_plot()
self.__initialize_umap_scatter_plot()
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_breadcrumb(),
self.__render_closed_settings_pane(),
self.__render_cell_selection_panes(),
)
# Settings callbacks
@self.app.callback(
Output("circular-tree-plot", "figure", allow_duplicate=True),
Output("umap-scatter-plot", "figure", allow_duplicate=True),
Output("selected-domain-label", "children", allow_duplicate=True),
Output("settings-pane", "children", allow_duplicate=True),
Output("panel-titles", "children", allow_duplicate=True),
Output("settings-pane", "is_open", allow_duplicate=True),
Input("update-button", "n_clicks"),
prevent_initial_call=True,
)
def __save_settings(n_clicks):
if n_clicks > 0:
# If a domain selection was changed and set to None, clear all selections
if (
self.selected_cell_domain_key.is_dirty()
and self.selected_cell_domain_key.get(dirty_read=True) is DomainSelectionConstants.NONE
):
self.__clear_cell_selection()
self.selected_cell_domain_key.commit()
self.score_threshold.commit()
self.min_cell_fraction.commit()
self.__clear_cell_class_selection()
# update the figures
self.__initialize_circular_tree_plot()
self.__initialize_umap_scatter_plot()
return (
self.circular_tree_plot.plotly_figure,
self._umap_scatter_plot_figure,
self.__render_breadcrumb(),
self.__render_closed_settings_pane(),
self.__render_cell_selection_panes(),
False,
)
@self.app.callback(
Output("settings-pane", "children", allow_duplicate=True),
Output("settings-pane", "is_open", allow_duplicate=True),
Input("cancel-button", "n_clicks"),
prevent_initial_call=True,
)
def __cancel_settings(n_clicks):
self.selected_cell_domain_key.rollback()
self.score_threshold.rollback()
self.min_cell_fraction.rollback()
return self.__render_closed_settings_pane(), False
@self.app.callback(
Output("settings-pane", "is_open", allow_duplicate=True),
Input("settings-button", "n_clicks"),
[State("settings-pane", "is_open")],
prevent_initial_call=True,
)
def __toggle_settings(n_clicks, is_open):
if n_clicks:
return not is_open
return is_open
@self.app.callback(
Output("no-action", "children", allow_duplicate=True),
Input("domain-dropdown", "value"),
prevent_initial_call=True,
)
def __update_domain(domain):
# set the domain
self.selected_cell_domain_key.set(domain)
@self.app.callback(
Output("no-action", "children", allow_duplicate=True),
Input("evidence-threshold", "value"),
prevent_initial_call=True,
)
def __update_evidence_threshold(input_value):
try:
self.score_threshold.set(float(input_value))
except ValueError:
pass
return input_value
@self.app.callback(
Output("no-action", "children", allow_duplicate=True),
Input("cell-fraction", "value"),
prevent_initial_call=True,
)
def __update_cell_fraction(input_value):
try:
self.min_cell_fraction.set(float(input_value))
except ValueError:
pass
return input_value