
Sankey Diagram Strategy

sankey_strategy

Sankey Strategy - Alluvial/Flow Diagram Visualizations.

This module implements the SankeyStrategy for generating Sankey (alluvial) diagrams that visualize flow relationships between multiple categorical levels.

Classes:

  • SankeyStrategy: Strategy for Sankey diagram generation using Plotly.

Notes
  • Visualizes flow/transition between categorical stages
  • Shows proportional relationships across multiple levels
  • Flow thickness proportional to count/aggregated value
  • Supports flexible multi-level flow configuration

For supported use cases, refer to the official documentation.

Classes

SankeyStrategy

SankeyStrategy(config: Dict[str, Any])

Bases: BasePlotStrategy

Strategy for Sankey diagram flow/alluvial visualizations.

This strategy creates Sankey diagrams showing flow relationships where nodes represent unique values and links show connections between stages.

Parameters:

  • config (Dict[str, Any], required): Complete configuration from YAML file.

Attributes:

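  • data_config (Dict[str, Any]): Data processing configuration (the "data" section of config).
  • plotly_config (Dict[str, Any]): Plotly-specific configuration (the "plotly" section of the visualization config).
  • flow_columns (List[str]): Columns defining the flow stages, in order. Default: [].
  • value_column (Optional[str]): Column used for value aggregation; None counts occurrences. Default: None.
  • aggregation (str): Aggregation method: 'count' (default), 'sum', or 'nunique'. Applied only when value_column is set; any other value falls back to counting rows.
  • node_pad (int): Padding between nodes. Default: 15.
  • node_thickness (int): Thickness of node rectangles. Default: 20.
  • color_by_stage (bool): If True, color nodes by their stage level. Default: True.
  • color_by_first_level (bool): If True, color links by their source node at the first level. Default: False.
  • link_opacity (float): Opacity applied to links. Default: 0.5.
  • node_colors (List[str]): Palette used for node colors. Default: DEFAULT_NODE_COLORS.
  • link_color (str): Link color. Default: 'rgba(180, 180, 180, 0.5)'.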

Methods:

  • validate_data: Validate input data for Sankey diagram requirements.
  • process_data: Process data for the Sankey diagram.
  • create_figure: Create the Sankey diagram figure from processed flow data.
  • get_stage_statistics: Calculate statistics for each stage in the flow.

Notes
  • Flow thickness proportional to count or aggregated value
  • Supports multiple aggregation methods
  • Customizable node and link colors

Initialize strategy with configuration.

Parameters:

  • config (Dict[str, Any], required): Complete configuration from YAML file.
Source code in src/domain/plot_strategies/charts/sankey_strategy.py
def __init__(self, config: Dict[str, Any]):
    """
    Initialize strategy with configuration.

    Parameters
    ----------
    config : Dict[str, Any]
        Complete configuration from YAML file.
    """
    super().__init__(config)
    self.data_config = config.get("data", {})
    self.plotly_config = self.viz_config.get("plotly", {})

    # Flow configuration
    self.flow_columns: List[str] = self.plotly_config.get("flow_columns", [])

    # Value/aggregation configuration
    self.value_column: Optional[str] = self.plotly_config.get("value_column", None)
    self.aggregation: str = self.plotly_config.get("aggregation", "count")

    # Node configuration
    self.node_pad: int = self.plotly_config.get("node_pad", 15)
    self.node_thickness: int = self.plotly_config.get("node_thickness", 20)

    # Color configuration
    self.color_by_stage: bool = self.plotly_config.get("color_by_stage", True)
    self.color_by_first_level: bool = self.plotly_config.get(
        "color_by_first_level", False
    )
    self.link_opacity: float = self.plotly_config.get("link_opacity", 0.5)
    self.node_colors: List[str] = self.plotly_config.get(
        "node_colors", DEFAULT_NODE_COLORS
    )

    # Link color configuration
    self.link_color: str = self.plotly_config.get(
        "link_color", "rgba(180, 180, 180, 0.5)"
    )

    logger.info(
        f"SankeyStrategy initialized for "
        f"{self.metadata.get('use_case_id', 'unknown')}: "
        f"flow_columns={self.flow_columns}, "
        f"aggregation='{self.aggregation}'"
    )
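The constructor reads every option from the plotly section and falls back to hard-coded defaults. The sketch below restates those defaults as a plain dict and merges a user-supplied section over them, which makes it easy to see what a minimal YAML section must contain. SANKEY_DEFAULTS, resolve_sankey_options, and the example column names are hypothetical; node_colors is omitted because its default, DEFAULT_NODE_COLORS, is defined elsewhere in the module.

```python
import copy
from typing import Any, Dict

# Hypothetical restatement of the defaults used in SankeyStrategy.__init__
# (node_colors is omitted: its default lives in DEFAULT_NODE_COLORS).
SANKEY_DEFAULTS: Dict[str, Any] = {
    "flow_columns": [],
    "value_column": None,
    "aggregation": "count",
    "node_pad": 15,
    "node_thickness": 20,
    "color_by_stage": True,
    "color_by_first_level": False,
    "link_opacity": 0.5,
    "link_color": "rgba(180, 180, 180, 0.5)",
}


def resolve_sankey_options(plotly_section: Dict[str, Any]) -> Dict[str, Any]:
    """Merge a user-supplied plotly section over the defaults (hypothetical helper)."""
    # Deep copy so the shared default list for flow_columns is never mutated.
    resolved = copy.deepcopy(SANKEY_DEFAULTS)
    # Only constructor keys are taken over; "chart" and "layout" are read
    # later by create_figure, not by __init__.
    for key in SANKEY_DEFAULTS:
        if key in plotly_section:
            resolved[key] = plotly_section[key]
    return resolved


# Example plotly section, as it might look after loading the YAML file.
example_section = {
    "flow_columns": ["channel", "product", "outcome"],
    "value_column": "revenue",
    "aggregation": "sum",
    "layout": {"height": 700},
}
options = resolve_sankey_options(example_section)
print(options["aggregation"], options["node_pad"])  # sum 15
```

Like dict.get in the constructor, an explicit null in the YAML overrides the default rather than falling back to it.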
Functions
validate_data
validate_data(df: DataFrame) -> None

Validate input data for Sankey diagram requirements.

Parameters:

  • df (DataFrame, required): Input data to validate.

Raises:

  • ValueError: If the DataFrame is empty, fewer than 2 flow columns are configured, any flow column is missing from the data, or the configured value_column is missing.

Source code in src/domain/plot_strategies/charts/sankey_strategy.py
def validate_data(self, df: pd.DataFrame) -> None:
    """
    Validate input data for Sankey diagram requirements.

    Parameters
    ----------
    df : pd.DataFrame
        Input data to validate.

    Raises
    ------
    ValueError
        If the DataFrame is empty, fewer than 2 flow columns are
        configured, a flow column is missing, or the configured
        value_column is missing.
    """
    logger.debug(
        f"Validating data - Shape: {df.shape}, " f"Columns: {df.columns.tolist()}"
    )

    if df.empty:
        raise ValueError("Input DataFrame is empty")

    # Validate flow columns exist
    if not self.flow_columns:
        raise ValueError(
            "No flow_columns specified in configuration. "
            "At least 2 columns are required for a Sankey diagram."
        )

    if len(self.flow_columns) < 2:
        raise ValueError(
            f"At least 2 flow columns required, got {len(self.flow_columns)}"
        )

    missing_cols = [c for c in self.flow_columns if c not in df.columns]
    if missing_cols:
        raise ValueError(
            f"Missing flow columns: {missing_cols}. "
            f"Available: {df.columns.tolist()}"
        )

    # Validate value column if specified
    if self.value_column and self.value_column not in df.columns:
        raise ValueError(
            f"Value column '{self.value_column}' not found. "
            f"Available: {df.columns.tolist()}"
        )

    logger.info(
        f"Data validation passed - {len(df)} records, "
        f"{len(self.flow_columns)} flow stages"
    )
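The checks run in a fixed order, so the first failing rule determines which error is raised. The dependency-free sketch below mirrors that order on plain column lists instead of a DataFrame. check_sankey_columns is a hypothetical helper; it uses n_rows as a stand-in for df.empty and merges the "no flow_columns" and "fewer than 2" checks, which the method reports with two different messages.

```python
from typing import List, Optional


def check_sankey_columns(
    n_rows: int,
    columns: List[str],
    flow_columns: List[str],
    value_column: Optional[str] = None,
) -> None:
    """Hypothetical mirror of the rule order in SankeyStrategy.validate_data."""
    if n_rows == 0:
        raise ValueError("Input DataFrame is empty")
    if len(flow_columns) < 2:
        # Covers both an empty flow_columns list and a single flow column.
        raise ValueError(f"At least 2 flow columns required, got {len(flow_columns)}")
    missing = [c for c in flow_columns if c not in columns]
    if missing:
        raise ValueError(f"Missing flow columns: {missing}. Available: {columns}")
    if value_column and value_column not in columns:
        raise ValueError(f"Value column '{value_column}' not found. Available: {columns}")


cols = ["channel", "product", "revenue"]
check_sankey_columns(10, cols, ["channel", "product"], "revenue")  # passes silently
try:
    check_sankey_columns(10, cols, ["channel", "outcome"])
except ValueError as exc:
    print(exc)  # Missing flow columns: ['outcome']. Available: ['channel', 'product', 'revenue']
```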
process_data
process_data(df: DataFrame) -> pd.DataFrame

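Process data for the Sankey diagram: clean the flow columns and aggregate one value per unique path.

Drops rows that have nulls or placeholder values (such as "#N/A", "N/D", or "nan") in any flow column, then groups by all flow columns at once. Adjacent-stage links are derived later, when the figure is built.

Parameters:

  • df (DataFrame, required): Input data with flow columns.

Returns:

  • DataFrame: One row per unique path, containing the flow columns plus a value column holding the row count or aggregated value.

Raises:

  • ValueError: If no valid rows remain after cleaning.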

Source code in src/domain/plot_strategies/charts/sankey_strategy.py
def process_data(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Process data for Sankey diagram: aggregate one value per unique path.

    Cleans data, removes placeholders, and aggregates the flow for each
    unique combination of flow-column values. Adjacent-stage links are
    built later, in create_figure.

    Parameters
    ----------
    df : pd.DataFrame
        Input data with flow columns.

    Returns
    -------
    pd.DataFrame
        One row per unique path: the flow columns plus a "value" column.
    """
    logger.info(f"Processing Sankey data with {len(self.flow_columns)} stages...")

    # Select only flow columns and drop nulls
    df_sankey = df[self.flow_columns].copy()
    initial_count = len(df_sankey)
    df_sankey = df_sankey.dropna()

    # Clean string values
    for col in self.flow_columns:
        df_sankey[col] = df_sankey[col].astype(str).str.strip()

    # Remove placeholder values
    placeholders = ["#N/D", "#N/A", "N/D", "", "nan", "None", "NaN"]
    for col in self.flow_columns:
        df_sankey = df_sankey[~df_sankey[col].isin(placeholders)]

    cleaned_count = len(df_sankey)
    logger.info(f"Data cleaned: {initial_count} -> {cleaned_count} rows")

    if df_sankey.empty:
        raise ValueError("No valid data remaining after cleaning flow columns.")

    # Aggregate per unique path: count rows, sum values, or count distinct values
    if self.value_column and self.value_column in df.columns:
        # Include value column for aggregation
        df_sankey["_value"] = df[self.value_column].loc[df_sankey.index]
        if self.aggregation == "sum":
            df_grouped = (
                df_sankey.groupby(self.flow_columns)["_value"]
                .sum()
                .reset_index(name="value")
            )
        elif self.aggregation == "nunique":
            df_grouped = (
                df_sankey.groupby(self.flow_columns)["_value"]
                .nunique()
                .reset_index(name="value")
            )
        else:
            # 'count' (or any unrecognized method): number of rows per path
            df_grouped = (
                df_sankey.groupby(self.flow_columns)
                .size()
                .reset_index(name="value")
            )
    else:
        # Count occurrences
        df_grouped = (
            df_sankey.groupby(self.flow_columns).size().reset_index(name="value")
        )

    logger.info(f"Aggregated to {len(df_grouped)} unique paths")

    return df_grouped
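The result holds one row per complete path (all flow columns plus value), not one row per adjacent-stage link. The sketch below reproduces the cleaning and aggregation steps with the standard library on a list of records so the output shape is easy to inspect. aggregate_paths and the sample records are hypothetical, and the code mimics, rather than calls, pandas groupby; as in pandas, missing entries in the value column are skipped by 'sum' and 'nunique'.

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

# Same placeholder strings that process_data drops after str() + strip().
PLACEHOLDERS = {"#N/D", "#N/A", "N/D", "", "nan", "None", "NaN"}


def aggregate_paths(
    records: List[Dict[str, Any]],
    flow_columns: List[str],
    value_column: Optional[str] = None,
    aggregation: str = "count",
) -> Dict[Tuple[str, ...], float]:
    """Hypothetical stdlib mirror of SankeyStrategy.process_data."""
    groups: Dict[Tuple[str, ...], List[Any]] = defaultdict(list)
    for rec in records:
        raw = [rec.get(c) for c in flow_columns]
        if any(v is None for v in raw):  # like dropna() on the flow columns
            continue
        path = tuple(str(v).strip() for v in raw)  # like astype(str).str.strip()
        if any(p in PLACEHOLDERS for p in path):  # placeholder removal
            continue
        groups[path].append(rec.get(value_column) if value_column else None)
    if not groups:
        raise ValueError("No valid data remaining after cleaning flow columns.")

    result: Dict[Tuple[str, ...], float] = {}
    for path, vals in groups.items():
        present = [v for v in vals if v is not None]
        if value_column and aggregation == "sum":
            result[path] = sum(present)
        elif value_column and aggregation == "nunique":
            result[path] = len(set(present))
        else:  # 'count' or anything unrecognized: rows per path
            result[path] = len(vals)
    return result


records = [
    {"channel": "Web", "product": "A", "revenue": 10},
    {"channel": "Web", "product": "A", "revenue": 5},
    {"channel": " Store ", "product": "B", "revenue": 7},
    {"channel": "#N/A", "product": "B", "revenue": 3},  # dropped: placeholder
    {"channel": "Web", "product": None, "revenue": 1},  # dropped: null
]
print(aggregate_paths(records, ["channel", "product"]))
# {('Web', 'A'): 2, ('Store', 'B'): 1}
print(aggregate_paths(records, ["channel", "product"], "revenue", "sum"))
# {('Web', 'A'): 15, ('Store', 'B'): 7}
```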
create_figure
create_figure(processed_df: DataFrame) -> go.Figure

Create Sankey diagram figure from processed flow data.

Builds node and link structures, applies styling, and creates Plotly Sankey visualization.
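The chart title can be configured either as a plain string (the older format) or as a dict with show, text, and font.size. The sketch below isolates that normalization; normalize_title is a hypothetical helper that returns the (show_title, title_text, title_font_size) triple the method computes inline. Omitting the title entirely yields the default "Sankey Diagram", and a title is added to the layout only when show is true and the text is non-empty.

```python
from typing import Any, Dict, Tuple, Union


def normalize_title(title_config: Union[str, Dict[str, Any]]) -> Tuple[bool, str, int]:
    """Hypothetical helper mirroring the title handling in create_figure."""
    if isinstance(title_config, str):
        # Backward-compatible form: a bare string is always shown at size 16.
        return True, title_config, 16
    show = title_config.get("show", True)
    text = title_config.get("text", "Sankey Diagram") if show else ""
    size = title_config.get("font", {}).get("size", 16)
    return show, text, size


print(normalize_title("Customer journeys"))           # (True, 'Customer journeys', 16)
print(normalize_title({}))                            # (True, 'Sankey Diagram', 16)
print(normalize_title({"show": False, "text": "x"}))  # (False, '', 16)
```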

Parameters:

  • processed_df (DataFrame, required): Processed data with flow columns and values.

Returns:

  • Figure: Configured Plotly Sankey figure.

Source code in src/domain/plot_strategies/charts/sankey_strategy.py
def create_figure(self, processed_df: pd.DataFrame) -> go.Figure:
    """
    Create Sankey diagram figure from processed flow data.

    Builds node and link structures, applies styling, and creates
    Plotly Sankey visualization.

    Parameters
    ----------
    processed_df : pd.DataFrame
        Processed data with flow columns and values.

    Returns
    -------
    go.Figure
        Configured Plotly Sankey figure.
    """
    logger.debug("Creating Sankey diagram figure...")

    chart_config = self.plotly_config.get("chart", {})
    layout_config = self.plotly_config.get("layout", {})

    # Build node and link structures
    all_labels, label_to_id, node_colors = self._build_nodes(processed_df)
    sources, targets, values, link_colors = self._build_links(
        processed_df, label_to_id, all_labels
    )

    # Handle title configuration
    title_config = chart_config.get("title", {})
    if isinstance(title_config, str):
        # Backward compatibility: string title
        show_title = True
        title_text = title_config
        title_font_size = 16
    else:
        # New format: dict with show, text, font
        show_title = title_config.get("show", True)
        title_text = (
            title_config.get("text", "Sankey Diagram") if show_title else ""
        )
        title_font_size = title_config.get("font", {}).get("size", 16)

    # Create Sankey trace
    sankey_trace = go.Sankey(
        node=dict(
            pad=self.node_pad,
            thickness=self.node_thickness,
            line=dict(color="black", width=0.5),
            label=all_labels,
            color=node_colors,
        ),
        link=dict(source=sources, target=targets, value=values, color=link_colors),
    )

    # Create figure
    fig = go.Figure(data=[sankey_trace])

    # Layout configuration
    height = layout_config.get("height", 900)
    use_autosize = layout_config.get("autosize", True)
    template = layout_config.get("template", "simple_white")

    # Get font size (global label font size)
    font_size = layout_config.get("font_size", 10)

    # Get margin configuration
    margin_config = layout_config.get("margin", {})
    margin = dict(
        l=margin_config.get("l", 10),
        r=margin_config.get("r", 10),
        t=margin_config.get("t", 60),
        b=margin_config.get("b", 20),
    )

    # Get paper background color
    paper_bgcolor = layout_config.get("paper_bgcolor", "white")

    # Build layout update dict
    layout_update = dict(
        font_size=font_size,
        height=height,
        template=template,
        margin=margin,
        paper_bgcolor=paper_bgcolor,
    )

    # Add title if enabled
    if show_title and title_text:
        layout_update["title"] = dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(size=title_font_size),
        )

    # Add autosize or width
    if use_autosize:
        layout_update["autosize"] = True
    else:
        if layout_config.get("width"):
            layout_update["width"] = layout_config.get("width")

    fig.update_layout(**layout_update)

    # Log statistics
    n_nodes = len(all_labels)
    n_links = len(sources)
    logger.info(
        f"Sankey diagram created - {n_nodes} nodes, {n_links} links, "
        f"{len(self.flow_columns)} stages"
    )

    return fig
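The helpers _build_nodes and _build_links are not shown on this page. As an illustration only, the sketch below shows one common way to turn the per-path rows from process_data into the source/target/value index lists that go.Sankey expects: each path is split into adjacent-stage pairs, and pairs with the same endpoints are summed. paths_to_links is hypothetical, and keying nodes by (stage, label), so that a value appearing in two stages becomes two separate nodes, is an assumption rather than a description of what the strategy does.

```python
from typing import Dict, List, Tuple


def paths_to_links(
    paths: Dict[Tuple[str, ...], float],
) -> Tuple[List[str], List[int], List[int], List[float]]:
    """Hypothetical: split full paths into adjacent-stage links for go.Sankey."""
    node_index: Dict[Tuple[int, str], int] = {}
    labels: List[str] = []
    link_totals: Dict[Tuple[int, int], float] = {}

    def node_id(stage: int, label: str) -> int:
        # One node per (stage, label) pair; labels keeps only the display text.
        key = (stage, label)
        if key not in node_index:
            node_index[key] = len(labels)
            labels.append(label)
        return node_index[key]

    for path, value in paths.items():
        for stage in range(len(path) - 1):
            src = node_id(stage, path[stage])
            tgt = node_id(stage + 1, path[stage + 1])
            # Paths that share a stage-to-stage pair add up on the same link.
            link_totals[(src, tgt)] = link_totals.get((src, tgt), 0) + value

    sources = [s for s, _ in link_totals]
    targets = [t for _, t in link_totals]
    values = list(link_totals.values())
    return labels, sources, targets, values


paths = {("Web", "A", "Won"): 2, ("Web", "B", "Won"): 1, ("Store", "A", "Lost"): 3}
labels, sources, targets, values = paths_to_links(paths)
print(labels)  # ['Web', 'A', 'Won', 'B', 'Store', 'Lost']
print(list(zip(sources, targets, values)))
# [(0, 1, 2), (1, 2, 2), (0, 3, 1), (3, 2, 1), (4, 1, 3), (1, 5, 3)]
```

The four lists have the same shape as the arguments create_figure passes to go.Sankey: labels for node.label, and sources, targets, values for link.source, link.target, link.value.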
get_stage_statistics
get_stage_statistics(df: DataFrame) -> Dict[str, Any]

Calculate statistics for each stage in the flow.

Parameters:

  • df (DataFrame, required): Processed flow data.

Returns:

  • Dict[str, Any]: For each flow column, a dict with unique_count (number of distinct values) and top_values (up to five most frequent values with their row counts; rows are not weighted by the value column).

Source code in src/domain/plot_strategies/charts/sankey_strategy.py
def get_stage_statistics(self, df: pd.DataFrame) -> Dict[str, Any]:
    """
    Calculate statistics for each stage in the flow.

    Parameters
    ----------
    df : pd.DataFrame
        Processed flow data.

    Returns
    -------
    Dict[str, Any]
        Statistics per stage including unique counts and top values.
    """
    stats = {}
    for col in self.flow_columns:
        unique_count = df[col].nunique()
        top_values = df[col].value_counts().head(5).to_dict()
        stats[col] = {"unique_count": unique_count, "top_values": top_values}
    return stats
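When this method receives the output of process_data, each row is one unique path, so top_values ranks values by the number of distinct paths they appear in, not by flow volume. The stdlib sketch below contrasts the two readings on the same rows. stage_stats and its weighted flag are hypothetical; the weighted variant is an alternative for comparison, not something get_stage_statistics does.

```python
from collections import Counter
from typing import Any, Dict, List


def stage_stats(
    rows: List[Dict[str, Any]], flow_columns: List[str], weighted: bool = False
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical per-stage unique counts and top-5 values.

    weighted=False counts rows, as get_stage_statistics does;
    weighted=True sums the "value" column instead (alternative reading).
    """
    stats: Dict[str, Dict[str, Any]] = {}
    for col in flow_columns:
        counter: Counter = Counter()
        for row in rows:
            counter[row[col]] += row["value"] if weighted else 1
        stats[col] = {
            "unique_count": len(counter),
            "top_values": dict(counter.most_common(5)),
        }
    return stats


rows = [
    {"channel": "Web", "product": "A", "value": 2},
    {"channel": "Web", "product": "B", "value": 1},
    {"channel": "Store", "product": "A", "value": 30},
]
print(stage_stats(rows, ["channel"])["channel"]["top_values"])
# {'Web': 2, 'Store': 1}
print(stage_stats(rows, ["channel"], weighted=True)["channel"]["top_values"])
# {'Store': 30, 'Web': 3}
```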
apply_filters
apply_filters(df: DataFrame, filters: Optional[Dict[str, Any]] = None) -> pd.DataFrame

Apply filters to data.

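This is a common implementation that subclasses can override if needed. Only filters that are declared in the configuration's filters list and whose filter_id appears in filters are applied. Currently only the range type is handled: the value must be a [min, max] list, and both bounds are inclusive. A filter whose bound column is missing from the data is skipped with a warning.

Parameters:

  • df (DataFrame, required): Data to filter.
  • filters (Optional[Dict[str, Any]], default None): Mapping of filter_id to filter value. If empty or None, df is returned unchanged.

Returns:

  • DataFrame: Filtered data.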

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def apply_filters(
    self, df: pd.DataFrame, filters: Optional[Dict[str, Any]] = None
) -> pd.DataFrame:
    """
    Apply filters to data.

    This is a common implementation that can be overridden
    by subclasses if needed.

    Parameters
    ----------
    df : pd.DataFrame
        Data to filter.
    filters : Optional[Dict[str, Any]], default=None
        Filter specifications.

    Returns
    -------
    pd.DataFrame
        Filtered data.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not filters:
        logger.debug("No filters provided, returning original data")
        return df

    logger.info(
        f"Applying filters - Input shape: {df.shape}, "
        f"Columns: {df.columns.tolist()}"
    )
    logger.info(f"Filters to apply: {filters}")

    filtered_df = df.copy()

    # Get filter configurations
    filter_configs = self.config.get("filters", [])

    for filter_config in filter_configs:
        filter_id = filter_config.get("filter_id")
        filter_type = filter_config.get("type")

        if filter_id not in filters:
            continue

        filter_value = filters[filter_id]
        data_binding = filter_config.get("data_binding", {})
        column = data_binding.get("column")

        if not column or column not in filtered_df.columns:
            logger.warning(
                f"Filter '{filter_id}': Column '{column}' not found. "
                f"Available: {filtered_df.columns.tolist()}"
            )
            continue

        # Apply range filter
        if filter_type == "range" and isinstance(filter_value, list):
            min_val, max_val = filter_value
            logger.info(
                f"Applying range filter on '{column}': " f"[{min_val}, {max_val}]"
            )
            filtered_df = filtered_df[
                (filtered_df[column] >= min_val) & (filtered_df[column] <= max_val)
            ]
            logger.info(f"After filter: {len(filtered_df)} rows remaining")

    logger.info(f"Final filtered shape: {filtered_df.shape}")
    return filtered_df
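Filters are matched by filter_id between the configuration's filters list and the runtime filters mapping, and only the range type is applied, with inclusive [min, max] bounds. The sketch below shows a configuration entry of that shape and the same inclusive comparison on plain records. The year_range id, the year column, and apply_range_filters are hypothetical.

```python
from typing import Any, Dict, List

# Hypothetical entry from the configuration's "filters" list.
filter_configs: List[Dict[str, Any]] = [
    {"filter_id": "year_range", "type": "range", "data_binding": {"column": "year"}},
]


def apply_range_filters(
    rows: List[Dict[str, Any]],
    configs: List[Dict[str, Any]],
    filters: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Hypothetical mirror of the range branch of apply_filters on a list of dicts."""
    result = list(rows)
    for cfg in configs:
        fid = cfg.get("filter_id")
        if fid not in filters:
            continue  # declared in config but not requested at runtime
        column = cfg.get("data_binding", {}).get("column")
        if not column or any(column not in r for r in result):
            continue  # apply_filters logs a warning and skips in this case
        value = filters[fid]
        if cfg.get("type") == "range" and isinstance(value, list):
            low, high = value
            result = [r for r in result if low <= r[column] <= high]  # inclusive
    return result


rows = [{"year": 2019}, {"year": 2021}, {"year": 2023}]
print(apply_range_filters(rows, filter_configs, {"year_range": [2020, 2023]}))
# [{'year': 2021}, {'year': 2023}]
```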
apply_customizations
apply_customizations(fig: Figure, customizations: Optional[Any] = None) -> go.Figure

Apply custom styling to figure.

This is a hook for future customization features (FLEXIVEL and FLEXIVEL2); the base implementation returns the figure unchanged.

Parameters:

  • fig (Figure, required): Base figure.
  • customizations (Optional[Any], default None): Customization specifications.

Returns:

  • Figure: Customized figure (currently the input figure, unmodified).

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def apply_customizations(
    self, fig: go.Figure, customizations: Optional[Any] = None
) -> go.Figure:
    """
    Apply custom styling to figure.

    This is a hook for future customization features
    (FLEXIVEL and FLEXIVEL2).

    Parameters
    ----------
    fig : go.Figure
        Base figure.
    customizations : Optional[Any], default=None
        Customization specifications.

    Returns
    -------
    go.Figure
        Customized figure.
    """
    # Hook for future implementation
    return fig
generate_plot
generate_plot(data: DataFrame, filters: Optional[Dict[str, Any]] = None, customizations: Optional[Any] = None) -> go.Figure

Generate complete plot (Template Method).

This method orchestrates the entire plot generation process:

  1. Validate input data
  2. Process data
  3. Apply filters
  4. Create figure
  5. Apply customizations

Parameters:

  • data (DataFrame, required): Input data.
  • filters (Optional[Dict[str, Any]], default None): Filters to apply.
  • customizations (Optional[Any], default None): Customizations to apply.

Returns:

  • Figure: Complete Plotly figure.

Raises:

  • ValueError: If validation fails.

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def generate_plot(
    self,
    data: pd.DataFrame,
    filters: Optional[Dict[str, Any]] = None,
    customizations: Optional[Any] = None,
) -> go.Figure:
    """
    Generate complete plot (Template Method).

    This method orchestrates the entire plot generation process:
    1. Validate input data
    2. Process data
    3. Apply filters
    4. Create figure
    5. Apply customizations

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    filters : Optional[Dict[str, Any]], default=None
        Filters to apply.
    customizations : Optional[Any], default=None
        Customizations to apply.

    Returns
    -------
    go.Figure
        Complete Plotly figure.

    Raises
    ------
    ValueError
        If validation fails.
    """
    # 1. Validate
    self.validate_data(data)

    # 2. Process
    processed_df = self.process_data(data)

    # 3. Filter
    filtered_df = self.apply_filters(processed_df, filters)

    # 4. Create figure
    figure = self.create_figure(filtered_df)

    # 5. Apply customizations (hook for future)
    figure = self.apply_customizations(figure, customizations)

    return figure
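generate_plot is a Template Method: the base class fixes the order of the steps, and subclasses such as SankeyStrategy supply them. The minimal sketch below uses hypothetical TemplatePlot and RecordingPlot classes with no Plotly dependency; it records the call order and overrides apply_customizations as a hook. Its default apply_filters is a simplified no-op, unlike the real one. One consequence worth checking in the real class: filters run on the output of process_data, which for SankeyStrategy contains only the flow columns and value, so a range filter bound to any other column is skipped with a warning.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class TemplatePlot(ABC):
    """Hypothetical stand-in for the template method in BasePlotStrategy."""

    def generate_plot(self, data: Any, filters: Optional[dict] = None,
                      customizations: Optional[Any] = None) -> Any:
        # Fixed step order; subclasses only fill in the steps.
        self.validate_data(data)
        processed = self.process_data(data)
        filtered = self.apply_filters(processed, filters)
        figure = self.create_figure(filtered)
        return self.apply_customizations(figure, customizations)

    @abstractmethod
    def validate_data(self, data: Any) -> None: ...

    @abstractmethod
    def process_data(self, data: Any) -> Any: ...

    @abstractmethod
    def create_figure(self, processed: Any) -> Any: ...

    def apply_filters(self, data: Any, filters: Optional[dict] = None) -> Any:
        return data  # simplified default: no filtering

    def apply_customizations(self, figure: Any, customizations: Optional[Any] = None) -> Any:
        return figure  # hook: unchanged by default


class RecordingPlot(TemplatePlot):
    def __init__(self) -> None:
        self.calls: List[str] = []

    def validate_data(self, data: Any) -> None:
        self.calls.append("validate")

    def process_data(self, data: Any) -> Any:
        self.calls.append("process")
        return data

    def apply_filters(self, data: Any, filters: Optional[dict] = None) -> Any:
        self.calls.append("filter")
        return data

    def create_figure(self, processed: Any) -> Any:
        self.calls.append("figure")
        return {"title": "base"}

    def apply_customizations(self, figure: Any, customizations: Optional[Any] = None) -> Any:
        self.calls.append("customize")
        if customizations:
            figure = {**figure, **customizations}  # overriding the hook
        return figure


plot = RecordingPlot()
fig = plot.generate_plot([1, 2, 3], customizations={"title": "custom"})
print(plot.calls)  # ['validate', 'process', 'filter', 'figure', 'customize']
print(fig)         # {'title': 'custom'}
```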