
Heatmap Strategy

heatmap_strategy

Heatmap Strategy - Count-Based Matrix Visualizations.

This module implements the HeatmapStrategy for creating heatmap visualizations that show counts of unique values at the intersection of two categorical dimensions.

Classes:

  • HeatmapStrategy: Strategy for count-based heatmap generation.

Notes
  • Shows absolute aggregated values (unique counts by default), not percentages
  • Supports multiple aggregation methods (nunique, count, sum)
  • Sorts rows and columns in descending order of their totals for readability

For supported use cases, refer to the official documentation.
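To make the three aggregation modes concrete, here is a dependency-free sketch that computes cell values from plain (row, column, value) tuples instead of a DataFrame. The records and the helper name `cell_values` are hypothetical; the strategy itself does this with a pandas `groupby` in `process_data`.

```python
from collections import defaultdict

# Hypothetical long-format records: (row category, column category, value).
records = [
    ("AG1", "S1", "K001"),
    ("AG1", "S1", "K001"),  # same value twice in one cell
    ("AG1", "S1", "K002"),
    ("AG1", "S2", "K003"),
    ("AG2", "S1", "K004"),
]


def cell_values(records, aggregation="nunique"):
    """Aggregate values per (row, column) pair using one of three modes."""
    groups = defaultdict(list)
    for row, col, value in records:
        groups[(row, col)].append(value)
    if aggregation == "nunique":
        # Distinct values per cell: duplicates collapse.
        return {key: len(set(vals)) for key, vals in groups.items()}
    if aggregation == "count":
        # Every occurrence counts, duplicates included.
        return {key: len(vals) for key, vals in groups.items()}
    if aggregation == "sum":
        # Only meaningful for numeric values.
        return {key: sum(vals) for key, vals in groups.items()}
    raise ValueError(f"Unknown aggregation method: '{aggregation}'")
```

With these records, the ("AG1", "S1") cell is 2 under `nunique` (K001 and K002) but 3 under `count`, which is why the default mode is described as counting unique values.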

Classes

HeatmapStrategy

HeatmapStrategy(config: Dict[str, Any])

Bases: BasePlotStrategy

Strategy for count-based heatmap matrix visualizations.

This strategy creates heatmaps in which rows and columns represent two categorical dimensions and each cell holds the aggregated value for that (row, column) pair: by default, the number of unique values in the value column.

Parameters:

  • config (Dict[str, Any], required): Complete configuration from YAML file.

Attributes:

  • data_config (Dict[str, Any]): Data processing configuration (the "data" section of config).
  • plotly_config (Dict[str, Any]): Plotly-specific configuration (the "plotly" section of self.viz_config).
  • row_column (str): Column for heatmap rows (y-axis). Default: 'referenceAG'.
  • col_column (str): Column for heatmap columns (x-axis). Default: 'sample'.
  • value_column (str): Column whose values are aggregated in each cell; with 'nunique', its distinct values are counted. Default: 'ko'.
  • aggregation (str): Aggregation method: 'nunique' (default), 'count', or 'sum'.

Methods:

  • validate_data: Validate input data for heatmap requirements.
  • process_data: Process data and create the count matrix.
  • create_figure: Create the heatmap figure from the count matrix.

Notes
  • Supports multiple aggregation methods
  • Automatically sorts rows and columns by totals
  • Shows absolute counts (not percentages)

Initialize strategy with configuration.

Source code in src/domain/plot_strategies/charts/heatmap_strategy.py
def __init__(self, config: Dict[str, Any]):
    """
    Initialize strategy with configuration.

    Parameters
    ----------
    config : Dict[str, Any]
        Complete configuration from YAML file.
    """
    super().__init__(config)
    self.data_config = config.get("data", {})
    self.plotly_config = self.viz_config.get("plotly", {})

    # Extract strategy-specific parameters
    self.row_column = self.plotly_config.get("row_column", "referenceAG")
    self.col_column = self.plotly_config.get("col_column", "sample")
    self.value_column = self.plotly_config.get("value_column", "ko")
    self.aggregation = self.plotly_config.get("aggregation", "nunique")

    logger.info(
        f"HeatmapStrategy initialized for "
        f"{self.metadata.get('use_case_id', 'unknown')}: "
        f"rows='{self.row_column}', cols='{self.col_column}', "
        f"values='{self.value_column}', agg='{self.aggregation}'"
    )
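The four strategy parameters are read from the "plotly" section with the defaults shown above. Below is a minimal sketch of the same lookup, assuming you already hold that section as a plain dict; where it sits inside the full YAML depends on how BasePlotStrategy builds viz_config, which is not shown here. `HEATMAP_DEFAULTS` and `resolve_heatmap_params` are hypothetical names.

```python
from typing import Any, Dict

# Defaults copied from HeatmapStrategy.__init__.
HEATMAP_DEFAULTS: Dict[str, str] = {
    "row_column": "referenceAG",
    "col_column": "sample",
    "value_column": "ko",
    "aggregation": "nunique",
}


def resolve_heatmap_params(plotly_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return the four strategy parameters, falling back to the defaults."""
    return {
        key: plotly_config.get(key, default)
        for key, default in HEATMAP_DEFAULTS.items()
    }
```

For example, `resolve_heatmap_params({"col_column": "site", "aggregation": "count"})` overrides two keys and keeps 'referenceAG' and 'ko'. Note that, as in `__init__`, the aggregation name is not checked here; an unsupported value only fails later in `process_data`.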
Functions
validate_data
validate_data(df: DataFrame) -> None

Validate input data for heatmap requirements.

Parameters:

  • df (DataFrame, required): Input data to validate.

Raises:

  • ValueError: If the DataFrame is empty, required columns are missing, or no rows remain after dropping nulls in the required columns.

Source code in src/domain/plot_strategies/charts/heatmap_strategy.py
def validate_data(self, df: pd.DataFrame) -> None:
    """
    Validate input data for heatmap requirements.

    Parameters
    ----------
    df : pd.DataFrame
        Input data to validate.

    Raises
    ------
    ValueError
        If DataFrame is empty, required columns missing, or no valid
        row/column categories found.
    """
    logger.debug(
        f"Validating data - Shape: {df.shape}, Columns: {df.columns.tolist()}"
    )

    # Check DataFrame not empty
    if df.empty:
        raise ValueError("Input DataFrame is empty")

    # Required columns
    required_cols = [self.row_column, self.col_column, self.value_column]

    # Check required columns exist
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(
            f"Missing required columns: {missing_cols}. "
            f"Available columns: {df.columns.tolist()}"
        )

    # Drop rows with null values in critical columns
    df_clean = df.dropna(subset=required_cols)
    if df_clean.empty:
        raise ValueError(
            f"No valid data after removing nulls in columns: {required_cols}"
        )

    # Check at least one row and column category
    n_rows = df_clean[self.row_column].nunique()
    n_cols = df_clean[self.col_column].nunique()

    if n_rows == 0:
        raise ValueError(f"No categories found in row column '{self.row_column}'")
    if n_cols == 0:
        raise ValueError(f"No categories found in column '{self.col_column}'")

    logger.info(
        f"Data validation passed - "
        f"{n_rows} row categories, {n_cols} column categories, "
        f"{len(df_clean)} records"
    )
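The same checks can be sketched without pandas on a list of dict records, which is handy for unit-testing the rules in isolation. This is a hypothetical helper, not part of the module; `None` stands in for a null (NaN is not handled). It omits the `n_rows == 0` / `n_cols == 0` checks: once at least one null-free record remains, each required column necessarily has at least one category, so in practice those checks never fire.

```python
from typing import Any, Dict, List


def validate_records(records: List[Dict[str, Any]], required: List[str]) -> int:
    """Mirror validate_data's checks; return the count of null-free records."""
    if not records:
        raise ValueError("Input data is empty")

    # A column "exists" if any record carries that key.
    present = set().union(*(r.keys() for r in records))
    missing = [c for c in required if c not in present]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Equivalent of df.dropna(subset=required).
    clean = [r for r in records if all(r.get(c) is not None for c in required)]
    if not clean:
        raise ValueError(f"No valid data after removing nulls in columns: {required}")
    return len(clean)
```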
process_data
process_data(df: DataFrame) -> pd.DataFrame

Process data and create count matrix.

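Drops rows with nulls in any of the three required columns; strips whitespace from string (object-dtype) columns and uppercases the row and value columns; aggregates the value column per (row, column) pair; and pivots the result into a matrix whose rows and columns are sorted by descending totals.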

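Parameters:

  • df (DataFrame, required): Input data with required columns.

Returns:

  • DataFrame: Heatmap matrix with row categories as index, column categories as columns, and aggregated values as cells. (row, column) pairs with no data are filled with 0, and values are cast to int when all of them are whole numbers.

Raises:

  • ValueError: If aggregation is not one of 'nunique', 'count', or 'sum'.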
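The normalization step treats the three columns differently: every string column is stripped, but only the row and value columns are uppercased, so " ag1 " and "AG1" end up in the same row while sample names keep their case. A per-record sketch of that rule follows; `normalize_record` is a hypothetical helper, and unlike the pandas code (which checks each column's dtype once) it checks each value's type and leaves non-string values untouched.

```python
from typing import Any, Dict


def normalize_record(
    record: Dict[str, Any], row_col: str, col_col: str, value_col: str
) -> Dict[str, Any]:
    """Strip string fields; additionally uppercase the row and value fields."""
    out = dict(record)
    for key in (row_col, col_col, value_col):
        if isinstance(out[key], str):
            out[key] = out[key].strip()
            # The column dimension (e.g. sample names) keeps its original case.
            if key in (row_col, value_col):
                out[key] = out[key].upper()
    return out
```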

Source code in src/domain/plot_strategies/charts/heatmap_strategy.py
def process_data(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Process data and create count matrix.

    Cleans data, normalizes strings, groups by row and column dimensions,
    aggregates values, and creates a sorted matrix.

    Parameters
    ----------
    df : pd.DataFrame
        Input data with required columns.

    Returns
    -------
    pd.DataFrame
        Heatmap matrix with row categories as index, column categories
        as columns, and aggregated counts as values.
    """
    logger.info(f"Processing data with '{self.aggregation}' aggregation...")

    # Clean data: remove nulls
    df_clean = df.dropna(
        subset=[self.row_column, self.col_column, self.value_column]
    ).copy()

    logger.debug(
        f"After null removal: {len(df_clean)} records "
        f"({len(df) - len(df_clean)} removed)"
    )

    # Normalize string columns
    for col in [self.row_column, self.col_column, self.value_column]:
        if df_clean[col].dtype == "object":
            df_clean[col] = df_clean[col].str.strip()
            # Uppercase for row and value columns (typically categorical)
            if col in [self.row_column, self.value_column]:
                df_clean[col] = df_clean[col].str.upper()

    # Aggregate values per (row, column) pair
    if self.aggregation == "nunique":
        # Count unique values
        aggregated = df_clean.groupby([self.row_column, self.col_column])[
            self.value_column
        ].nunique()
    elif self.aggregation == "count":
        # Count all occurrences
        aggregated = df_clean.groupby([self.row_column, self.col_column])[
            self.value_column
        ].count()
    elif self.aggregation == "sum":
        # Sum values (requires numeric column)
        aggregated = df_clean.groupby([self.row_column, self.col_column])[
            self.value_column
        ].sum()
    else:
        raise ValueError(
            f"Unknown aggregation method: '{self.aggregation}'. "
            f"Supported: 'nunique', 'count', 'sum'"
        )

    logger.debug(f"Aggregated {len(aggregated)} (row, column) pairs")

    # Pivot to 2D matrix: rows = row_column, cols = col_column
    heatmap_matrix = aggregated.unstack(level=self.col_column).fillna(0)

    # Convert to int if all values are whole numbers
    if (heatmap_matrix % 1 == 0).all().all():
        heatmap_matrix = heatmap_matrix.astype(int)

    # Sort rows and columns by total counts for readability
    heatmap_matrix = heatmap_matrix.loc[
        heatmap_matrix.sum(axis=1).sort_values(ascending=False).index,
        heatmap_matrix.sum(axis=0).sort_values(ascending=False).index,
    ]

    logger.info(
        f"Heatmap matrix created - "
        f"Shape: {heatmap_matrix.shape}, "
        f"Value range: [{heatmap_matrix.min().min()}, "
        f"{heatmap_matrix.max().max()}]"
    )

    return heatmap_matrix
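The pivot-and-sort tail of `process_data` (`unstack(...).fillna(0)` followed by ordering on `sum(axis=1)` and `sum(axis=0)`) can be sketched with plain dicts. `pivot_and_sort` is a hypothetical name; breaking ties alphabetically is an assumption of this sketch, since the pandas `sort_values` call above does not specify a tie order.

```python
from typing import Dict, List, Tuple


def pivot_and_sort(
    cells: Dict[Tuple[str, str], float]
) -> Tuple[List[str], List[str], List[List[float]]]:
    """Turn {(row, col): value} into a dense matrix ordered by totals."""
    rows = {r for r, _ in cells}
    cols = {c for _, c in cells}

    # Row and column totals drive the ordering (largest first).
    row_tot = {r: sum(v for (rr, _), v in cells.items() if rr == r) for r in rows}
    col_tot = {c: sum(v for (_, cc), v in cells.items() if cc == c) for c in cols}
    row_order = sorted(rows, key=lambda r: (-row_tot[r], r))
    col_order = sorted(cols, key=lambda c: (-col_tot[c], c))

    # Missing (row, col) pairs become 0, like unstack(...).fillna(0).
    matrix = [[cells.get((r, c), 0) for c in col_order] for r in row_order]
    return row_order, col_order, matrix
```

For the cells `{("AG1", "S1"): 2, ("AG1", "S2"): 1, ("AG2", "S1"): 1}`, AG1 (total 3) comes before AG2 (total 1), and the empty AG2/S2 cell is filled with 0.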
create_figure
create_figure(processed_df: DataFrame) -> go.Figure

Create heatmap figure from count matrix.

Parameters:

  • processed_df (DataFrame, required): Heatmap matrix (rows × columns).

Returns:

  • Figure: Configured Plotly heatmap.

Source code in src/domain/plot_strategies/charts/heatmap_strategy.py
def create_figure(self, processed_df: pd.DataFrame) -> go.Figure:
    """
    Create heatmap figure from count matrix.

    Parameters
    ----------
    processed_df : pd.DataFrame
        Heatmap matrix (rows × columns).

    Returns
    -------
    go.Figure
        Configured Plotly heatmap.
    """
    logger.debug("Creating heatmap figure...")

    # Extract chart configuration
    chart_config = self.plotly_config.get("chart", {})
    layout_config = self.plotly_config.get("layout", {})

    # Handle title configuration (support both string and dict)
    title_config = chart_config.get("title", {})
    if isinstance(title_config, str):
        # Backward compatibility: string title
        show_title = True
        title_text = title_config
    else:
        # New format: dict with show, text
        show_title = title_config.get("show", True)
        title_text = (
            title_config.get("text", "Unique Value Count Heatmap")
            if show_title
            else ""
        )

    # Get axis labels
    x_label = chart_config.get("xaxis", {}).get("title", "Sample")
    y_label = chart_config.get("yaxis", {}).get("title", "Category")
    color_label = chart_config.get("color_label", "Unique Count")

    # Get text display setting
    # For count matrices, show integer values
    total_cells = processed_df.size
    show_text = total_cells <= 400  # Enable text only for smaller matrices
    text_auto = chart_config.get("text_auto", show_text)

    # Get color scale
    color_scale = chart_config.get("color_continuous_scale", "Greens")

    # Create heatmap using plotly express
    fig = px.imshow(
        processed_df,
        labels=dict(x=x_label, y=y_label, color=color_label),
        text_auto=text_auto,
        aspect="auto",
        color_continuous_scale=color_scale,
    )

    # Apply layout configuration
    template = layout_config.get("template", "simple_white")

    # Calculate dynamic height based on number of rows
    n_rows = processed_df.shape[0]
    default_height = max(500, min(1200, 40 * n_rows + 160))
    height = layout_config.get("height", default_height)
    use_autosize = layout_config.get("autosize", False)

    # Get margin configuration
    margin_config = layout_config.get("margin", {})
    margin = dict(
        l=margin_config.get("l", 80),
        r=margin_config.get("r", 30),
        t=margin_config.get("t", 70),
        b=margin_config.get("b", 60),
    )

    # Get axis angles
    xaxis_tickangle = chart_config.get("xaxis_tickangle", -45)
    yaxis_tickangle = chart_config.get("yaxis_tickangle", 0)

    # Get colorbar configuration
    colorbar_config = chart_config.get("colorbar", {})
    colorbar_title = colorbar_config.get("title", color_label)

    # Build layout update dict
    layout_update = {
        "height": height,
        "template": template,
        "xaxis_tickangle": xaxis_tickangle,
        "margin": margin,
        "plot_bgcolor": "white",
        "coloraxis_colorbar": dict(title=colorbar_title),
    }

    # Add title if enabled
    if show_title and title_text:
        layout_update["title"] = {"text": title_text, "x": 0.5, "xanchor": "center"}

    # Add autosize or width
    if use_autosize:
        layout_update["autosize"] = True
    else:
        if layout_config.get("width"):
            layout_update["width"] = layout_config.get("width")

    fig.update_layout(**layout_update)

    # Remove grid lines for cleaner look
    fig.update_xaxes(showgrid=False)

    # Update Y-axis with rotation
    fig.update_yaxes(showgrid=False, tickangle=yaxis_tickangle)

    # Update text font size if configured
    text_font_size = chart_config.get("text_font_size", 10)
    if text_auto:
        fig.update_traces(textfont_size=text_font_size)

    logger.info(
        f"Heatmap figure created - "
        f"Size: {layout_update.get('width', 'auto')}x{height}px, "
        f"Template: {template}"
    )

    return fig
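Several of the figure defaults above are derived rather than fixed: the title accepts either a legacy string or a dict, the default height grows by 40 px per row and is clamped to 500 to 1200 px, and cell labels are on by default only when the matrix has at most 400 cells. A Plotly-free sketch of just those rules, for checking a config before rendering; `figure_defaults` is a hypothetical helper, and the `chart` and `layout` dicts stand for the "chart" and "layout" sub-sections of the plotly config.

```python
from typing import Any, Dict, Tuple, Union


def figure_defaults(
    n_rows: int, n_cols: int, chart: Dict[str, Any], layout: Dict[str, Any]
) -> Tuple[str, Any, Union[bool, str]]:
    """Reproduce create_figure's title, height, and text-label defaults."""
    title_cfg = chart.get("title", {})
    if isinstance(title_cfg, str):
        title = title_cfg  # legacy string form: always shown
    elif title_cfg.get("show", True):
        title = title_cfg.get("text", "Unique Value Count Heatmap")
    else:
        title = ""

    # 40 px per row plus 160 px of chrome, clamped to [500, 1200].
    default_height = max(500, min(1200, 40 * n_rows + 160))
    height = layout.get("height", default_height)

    # Labels default to on only for matrices with at most 400 cells.
    text_auto = chart.get("text_auto", n_rows * n_cols <= 400)
    return title, height, text_auto
```

For example, a 20 × 10 matrix with empty config gets a 960 px height and labels, while a 30 × 20 matrix hits the 1200 px cap and, at 600 cells, gets no labels.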
apply_filters
apply_filters(df: DataFrame, filters: Optional[Dict[str, Any]] = None) -> pd.DataFrame

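Apply filters to data.

This shared default implementation can be overridden by subclasses. For each filter declared in the strategy's config["filters"] whose filter_id appears in filters, it looks up the bound column (data_binding.column); filters of type "range" given as a [min, max] list keep only rows whose value lies in that inclusive range. Filters whose column is missing are skipped with a warning.

Parameters:

  • df (DataFrame, required): Data to filter.
  • filters (Optional[Dict[str, Any]], default None): Filter values keyed by filter_id.

Returns:

  • DataFrame: Filtered data.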

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def apply_filters(
    self, df: pd.DataFrame, filters: Optional[Dict[str, Any]] = None
) -> pd.DataFrame:
    """
    Apply filters to data.

    This is a common implementation that can be overridden
    by subclasses if needed.

    Parameters
    ----------
    df : pd.DataFrame
        Data to filter.
    filters : Optional[Dict[str, Any]], default=None
        Filter specifications.

    Returns
    -------
    pd.DataFrame
        Filtered data.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not filters:
        logger.debug("No filters provided, returning original data")
        return df

    logger.info(
        f"Applying filters - Input shape: {df.shape}, "
        f"Columns: {df.columns.tolist()}"
    )
    logger.info(f"Filters to apply: {filters}")

    filtered_df = df.copy()

    # Get filter configurations
    filter_configs = self.config.get("filters", [])

    for filter_config in filter_configs:
        filter_id = filter_config.get("filter_id")
        filter_type = filter_config.get("type")

        if filter_id not in filters:
            continue

        filter_value = filters[filter_id]
        data_binding = filter_config.get("data_binding", {})
        column = data_binding.get("column")

        if not column or column not in filtered_df.columns:
            logger.warning(
                f"Filter '{filter_id}': Column '{column}' not found. "
                f"Available: {filtered_df.columns.tolist()}"
            )
            continue

        # Apply range filter
        if filter_type == "range" and isinstance(filter_value, list):
            min_val, max_val = filter_value
            logger.info(
                f"Applying range filter on '{column}': [{min_val}, {max_val}]"
            )
            filtered_df = filtered_df[
                (filtered_df[column] >= min_val) & (filtered_df[column] <= max_val)
            ]
            logger.info(f"After filter: {len(filtered_df)} rows remaining")

    logger.info(f"Final filtered shape: {filtered_df.shape}")
    return filtered_df
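A dependency-free sketch of the same range-filter logic on a list of dict records. The config shape (filter_id, type, data_binding.column) follows the code above; the filter_id "year_range" and column "year" in the usage below are made-up examples, and `apply_range_filters` is a hypothetical name. Bounds are inclusive on both ends, matching the `>=` / `<=` comparison.

```python
from typing import Any, Dict, List, Optional


def apply_range_filters(
    records: List[Dict[str, Any]],
    filter_configs: List[Dict[str, Any]],
    filters: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Keep records whose bound column lies inside every active [min, max] range."""
    if not filters:
        return records
    out = list(records)
    for cfg in filter_configs:
        filter_id = cfg.get("filter_id")
        if filter_id not in filters:
            continue  # filter declared but not set by the caller
        column = cfg.get("data_binding", {}).get("column")
        if not column or not all(column in r for r in out):
            continue  # the original logs a warning and skips the filter
        value = filters[filter_id]
        if cfg.get("type") == "range" and isinstance(value, list):
            min_val, max_val = value
            out = [r for r in out if min_val <= r[column] <= max_val]
    return out
```

Usage: with `filter_configs = [{"filter_id": "year_range", "type": "range", "data_binding": {"column": "year"}}]`, calling `apply_range_filters(records, filter_configs, {"year_range": [2020, 2023]})` keeps only records with 2020 <= year <= 2023.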
apply_customizations
apply_customizations(fig: Figure, customizations: Optional[Any] = None) -> go.Figure

Apply custom styling to figure.

This is a hook for future customization features (FLEXIVEL and FLEXIVEL2); the default implementation returns the figure unchanged.

Parameters:

  • fig (Figure, required): Base figure.
  • customizations (Optional[Any], default None): Customization specifications.

Returns:

  • Figure: Customized figure.

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def apply_customizations(
    self, fig: go.Figure, customizations: Optional[Any] = None
) -> go.Figure:
    """
    Apply custom styling to figure.

    This is a hook for future customization features
    (FLEXIVEL and FLEXIVEL2).

    Parameters
    ----------
    fig : go.Figure
        Base figure.
    customizations : Optional[Any], default=None
        Customization specifications.

    Returns
    -------
    go.Figure
        Customized figure.
    """
    # Hook for future implementation
    return fig
generate_plot
generate_plot(data: DataFrame, filters: Optional[Dict[str, Any]] = None, customizations: Optional[Any] = None) -> go.Figure

Generate complete plot (Template Method).

This method orchestrates the entire plot generation process:

  1. Validate input data
  2. Process data
  3. Apply filters
  4. Create figure
  5. Apply customizations

Because filters run after processing, apply_filters receives the output of process_data (for HeatmapStrategy, the pivoted count matrix) rather than the raw input.

Parameters:

  • data (DataFrame, required): Input data.
  • filters (Optional[Dict[str, Any]], default None): Filters to apply.
  • customizations (Optional[Any], default None): Customizations to apply.

Returns:

  • Figure: Complete Plotly figure.

Raises:

  • ValueError: If validation fails.

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def generate_plot(
    self,
    data: pd.DataFrame,
    filters: Optional[Dict[str, Any]] = None,
    customizations: Optional[Any] = None,
) -> go.Figure:
    """
    Generate complete plot (Template Method).

    This method orchestrates the entire plot generation process:
    1. Validate input data
    2. Process data
    3. Apply filters
    4. Create figure
    5. Apply customizations

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    filters : Optional[Dict[str, Any]], default=None
        Filters to apply.
    customizations : Optional[Any], default=None
        Customizations to apply.

    Returns
    -------
    go.Figure
        Complete Plotly figure.

    Raises
    ------
    ValueError
        If validation fails.
    """
    # 1. Validate
    self.validate_data(data)

    # 2. Process
    processed_df = self.process_data(data)

    # 3. Filter
    filtered_df = self.apply_filters(processed_df, filters)

    # 4. Create figure
    figure = self.create_figure(filtered_df)

    # 5. Apply customizations (hook for future)
    figure = self.apply_customizations(figure, customizations)

    return figure
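The Template Method structure can be seen in isolation with a stripped-down, Plotly-free sketch: the base class fixes the step order in `generate_plot`, and a subclass only fills in the abstract steps. `MiniPlotStrategy` and `CountStrategy` are hypothetical stand-ins for BasePlotStrategy and a concrete strategy; the `calls` list exists only to make the step order observable.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class MiniPlotStrategy(ABC):
    """Hypothetical stand-in for BasePlotStrategy."""

    @abstractmethod
    def validate_data(self, data: Any) -> None: ...

    @abstractmethod
    def process_data(self, data: Any) -> Any: ...

    @abstractmethod
    def create_figure(self, processed: Any) -> Any: ...

    def apply_filters(self, data: Any, filters: Optional[Dict[str, Any]] = None) -> Any:
        return data  # overridable default: no filtering

    def apply_customizations(self, figure: Any, customizations: Optional[Any] = None) -> Any:
        return figure  # hook: no-op by default

    def generate_plot(self, data: Any, filters=None, customizations=None) -> Any:
        # Same step order as BasePlotStrategy.generate_plot.
        self.validate_data(data)
        processed = self.process_data(data)
        filtered = self.apply_filters(processed, filters)
        figure = self.create_figure(filtered)
        return self.apply_customizations(figure, customizations)


class CountStrategy(MiniPlotStrategy):
    """Toy subclass: the 'figure' is just a dict with a record count."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def validate_data(self, data: Any) -> None:
        self.calls.append("validate")
        if not data:
            raise ValueError("Input data is empty")

    def process_data(self, data: Any) -> Any:
        self.calls.append("process")
        return len(data)

    def create_figure(self, processed: Any) -> Any:
        self.calls.append("figure")
        return {"total": processed}
```

Calling `CountStrategy().generate_plot([1, 2, 3])` runs validate, process, and figure in that order and returns `{"total": 3}`; subclasses never reimplement the orchestration itself.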