
Chord Diagram Strategy

chord_strategy

Chord Diagram Strategy.

This module implements the ChordStrategy class following the Strategy Pattern. It provides the logic for generating chord diagrams, which visualize relationships between categorical entities as chords connecting nodes arranged around a circle.

Classes:

ChordStrategy: Concrete strategy for chord diagram generation using Plotly

Notes

Chord diagrams are ideal for:
- Visualizing relationships between entities (source-target pairs)
- Showing interaction strength between categories
- Displaying sample similarity networks
- Mapping set intersections and overlaps

Processing modes:

1. Direct Aggregation Mode (mode='aggregation'):
   - Groups by the source and target columns
   - Aggregates the count of interactions per pair (value_column can supply pre-aggregated values instead)
2. Pairwise Similarity Mode (mode='pairwise'):
   - Computes the entities (shared_column values) shared between each pair of groups (group_by_column values)
   - Creates links weighted by intersection size
3. Set Intersection Mode (mode='set_intersection'):
   - Computes pairwise intersections between named sets (set_column) using their members (element_column)
   - Links represent the overlap size between sets
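A minimal pandas sketch of the three modes follows. The private helpers that implement them (`_process_aggregation`, `_process_pairwise`, `_process_set_intersection`) are not shown on this page, so `aggregation_links` and `overlap_links` below are hypothetical stand-ins. Summing `value_column` when it is set and dropping zero-overlap pairs are assumptions; the real implementation may handle self-links, ordering, or normalization differently. In this sketch the pairwise and set-intersection modes reduce to the same computation over different columns.

```python
from itertools import combinations
from typing import Optional

import pandas as pd


def aggregation_links(
    df: pd.DataFrame, source_col: str, target_col: str, value_col: Optional[str] = None
) -> pd.DataFrame:
    """Aggregation mode (sketch): one link per (source, target) pair."""
    grouped = df.groupby([source_col, target_col])
    if value_col:
        # Assumption: pre-aggregated values are summed per pair.
        values = grouped[value_col].sum()
    else:
        # Otherwise the link value is the number of rows (interactions) per pair.
        values = grouped.size()
    links = values.reset_index(name="value")
    return links.rename(columns={source_col: "source", target_col: "target"})


def overlap_links(df: pd.DataFrame, group_col: str, member_col: str) -> pd.DataFrame:
    """Pairwise / set-intersection mode (sketch): link value = shared members."""
    # Collect the distinct members of each group (or named set).
    members = {name: set(values) for name, values in df.groupby(group_col)[member_col]}
    rows = []
    for a, b in combinations(sorted(members), 2):
        shared = len(members[a] & members[b])
        if shared:  # assumption: pairs with no overlap produce no link
            rows.append({"source": a, "target": b, "value": shared})
    return pd.DataFrame(rows, columns=["source", "target", "value"])
```

For example, `overlap_links(df, "sample", "compoundname")` would give one link per pair of samples that share at least one compound.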

For supported use cases, refer to the official documentation.

Version: 1.0.0

Classes

ChordStrategy

ChordStrategy(config: Dict[str, Any])

Bases: BasePlotStrategy

Chord diagram strategy for network relationship visualizations.

This strategy creates chord diagrams showing relationships between categorical entities, where:
- Nodes: categories drawn as arcs around a circle
- Chords: connections between nodes, with line width scaled by link value
- Colors: node coloring by category; each chord takes the color of its source node

Attributes:

data_config (Dict[str, Any]): Data processing configuration
plotly_config (Dict[str, Any]): Plotly-specific configuration
mode (str): Processing mode: 'aggregation', 'pairwise', or 'set_intersection'
source_column (str): Column name for source entities (aggregation mode)
target_column (str): Column name for target entities (aggregation mode)
value_column (Optional[str]): Column holding pre-aggregated values, if any
group_by_column (Optional[str]): Column defining the groups compared in pairwise mode
shared_column (Optional[str]): Column whose values are counted as shared entities in pairwise mode
set_column (Optional[str]): Column naming each set in set_intersection mode
element_column (Optional[str]): Column holding set members in set_intersection mode

Notes

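Expected YAML configuration structure (every key under plotly has a default in __init__, but the pairwise and set_intersection modes fail validation unless their column keys are set):

visualization:
  strategy: "ChordStrategy"
  plotly:
    mode: "aggregation"  # or "pairwise", "set_intersection"
    source_column: "sample"
    target_column: "compoundclass"
    # For pairwise mode:
    # group_by_column: "sample"
    # shared_column: "compoundname"
    # For set_intersection mode:
    # set_column: "<column naming each set>"
    # element_column: "<column holding set members>"
    colorscale: "Category20"  # read from plotly, not from chart
    chart:
      title:
        text: "Sample-Compound Interactions"
    layout:
      height: 800
      width: 800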

Refer to the official documentation for supported use cases and detailed configuration examples.

Initialize chord diagram strategy.

Parameters:

config (Dict[str, Any], required): Complete configuration from YAML file
Source code in src/domain/plot_strategies/charts/chord_strategy.py
def __init__(self, config: Dict[str, Any]):
    """
    Initialize chord diagram strategy.

    Parameters
    ----------
    config : Dict[str, Any]
        Complete configuration from YAML file
    """
    super().__init__(config)
    self.data_config = config.get("data", {})
    self.plotly_config = self.viz_config.get("plotly", {})

    # Processing mode
    self.mode: str = self.plotly_config.get("mode", "aggregation")

    # Column configuration
    self.source_column: str = self.plotly_config.get("source_column", "source")
    self.target_column: str = self.plotly_config.get("target_column", "target")
    self.value_column: Optional[str] = self.plotly_config.get("value_column", None)

    # Pairwise mode configuration
    self.group_by_column: Optional[str] = self.plotly_config.get(
        "group_by_column", None
    )
    self.shared_column: Optional[str] = self.plotly_config.get(
        "shared_column", None
    )

    # Set intersection mode configuration
    self.set_column: Optional[str] = self.plotly_config.get("set_column", None)
    self.element_column: Optional[str] = self.plotly_config.get(
        "element_column", None
    )

    # Visual configuration
    self.colorscale: str = self.plotly_config.get("colorscale", "Category20")
    self.min_link_value: int = self.plotly_config.get("min_link_value", 1)

    # Circle arcs configuration
    circle_arc_config = self.plotly_config.get("circle_arcs", {})
    self.circle_arcs_enabled: bool = circle_arc_config.get("enabled", True)
    self.circle_arc_width: int = circle_arc_config.get("width", 15)
    self.circle_arc_gap: float = circle_arc_config.get("gap", 0.02)
    self.circle_arc_radius: float = circle_arc_config.get("radius", 1.0)
    self.circle_arc_proportional: bool = circle_arc_config.get("proportional", True)

    # Labels configuration
    label_config = self.plotly_config.get("labels", {})
    self.labels_enabled: bool = label_config.get("enabled", True)
    self.labels_radius_offset: float = label_config.get("radius_offset", 1.15)
    self.labels_font: Dict[str, Any] = label_config.get(
        "font", {"size": 11, "color": "#2c3e50"}
    )

    # Chord configuration
    chord_config = self.plotly_config.get("chords", {})
    self.chord_anchor: str = chord_config.get("anchor", "arc_midpoint")

    logger.info(
        f"ChordStrategy initialized for "
        f"{self.metadata.get('use_case_id', 'unknown')}: "
        f"mode='{self.mode}', "
        f"source='{self.source_column}', target='{self.target_column}', "
        f"circle_arcs={'enabled' if self.circle_arcs_enabled else 'disabled'}"
    )
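The YAML example above does not show the nested `circle_arcs`, `labels`, and `chords` sections read here. The hypothetical helper below mirrors a subset of the `.get()` defaults in `__init__`, so the effect of a partial configuration can be checked without instantiating the strategy. It assumes that `self.viz_config` (set by `BasePlotStrategy`, not shown on this page) is the `visualization` section of the configuration.

```python
from typing import Any, Dict


def resolve_chord_settings(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring a subset of ChordStrategy.__init__ defaults."""
    # Assumption: viz_config == config["visualization"].
    plotly_cfg = config.get("visualization", {}).get("plotly", {})
    arcs = plotly_cfg.get("circle_arcs", {})
    labels = plotly_cfg.get("labels", {})
    chords = plotly_cfg.get("chords", {})
    return {
        "mode": plotly_cfg.get("mode", "aggregation"),
        "source_column": plotly_cfg.get("source_column", "source"),
        "target_column": plotly_cfg.get("target_column", "target"),
        "colorscale": plotly_cfg.get("colorscale", "Category20"),
        "min_link_value": plotly_cfg.get("min_link_value", 1),
        "circle_arcs_enabled": arcs.get("enabled", True),
        "circle_arc_gap": arcs.get("gap", 0.02),
        "circle_arc_radius": arcs.get("radius", 1.0),
        "labels_radius_offset": labels.get("radius_offset", 1.15),
        "chord_anchor": chords.get("anchor", "arc_midpoint"),
    }


# A partial config: everything not listed falls back to the defaults.
example = {
    "visualization": {
        "plotly": {
            "mode": "pairwise",
            "group_by_column": "sample",
            "shared_column": "compoundname",
            "min_link_value": 2,
            "circle_arcs": {"radius": 1.2},
            # Any anchor other than "arc_midpoint" attaches chords at the arc start
            # angle; "arc_start" is only an illustrative value.
            "chords": {"anchor": "arc_start"},
        }
    }
}
settings = resolve_chord_settings(example)
```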
Functions
validate_data
validate_data(df: DataFrame) -> None

Validate input data for chord diagram requirements.

Parameters:

df (DataFrame, required): Input data to validate

Raises:

ValueError: If the DataFrame is empty, the mode is unknown, a mode-specific column setting is missing, or a required column is absent from the data

Source code in src/domain/plot_strategies/charts/chord_strategy.py
def validate_data(self, df: pd.DataFrame) -> None:
    """
    Validate input data for chord diagram requirements.

    Parameters
    ----------
    df : pd.DataFrame
        Input data to validate

    Raises
    ------
    ValueError
        If any validation rule fails
    """
    logger.debug(f"Validating data - Shape: {df.shape}, Columns: {df.columns.tolist()}")

    if df.empty:
        raise ValueError("Input DataFrame is empty")

    # Mode-specific validation
    if self.mode == "aggregation":
        required_cols = [self.source_column, self.target_column]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"Missing columns for aggregation mode: {missing}. "
                f"Available: {df.columns.tolist()}"
            )

    elif self.mode == "pairwise":
        if not self.group_by_column or not self.shared_column:
            raise ValueError(
                "Pairwise mode requires 'group_by_column' and "
                "'shared_column' configuration."
            )
        required_cols = [self.group_by_column, self.shared_column]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"Missing columns for pairwise mode: {missing}. "
                f"Available: {df.columns.tolist()}"
            )

    elif self.mode == "set_intersection":
        if not self.set_column or not self.element_column:
            raise ValueError(
                "Set intersection mode requires 'set_column' and "
                "'element_column' configuration."
            )
        required_cols = [self.set_column, self.element_column]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"Missing columns for set_intersection mode: {missing}. "
                f"Available: {df.columns.tolist()}"
            )

    else:
        raise ValueError(
            f"Unknown mode: '{self.mode}'. "
            f"Supported: 'aggregation', 'pairwise', 'set_intersection'"
        )

    logger.info(f"Data validation passed - {len(df)} records, mode='{self.mode}'")
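The three mode branches differ only in which configuration keys name the required columns. A table-driven sketch of the same checks follows; `REQUIRED_KEYS` and `missing_columns` are hypothetical names, and unlike the real method the helper returns the missing columns instead of raising for them, which makes the rule easy to test in isolation.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Config keys naming the columns each mode needs (taken from validate_data).
REQUIRED_KEYS: Dict[str, Tuple[str, str]] = {
    "aggregation": ("source_column", "target_column"),
    "pairwise": ("group_by_column", "shared_column"),
    "set_intersection": ("set_column", "element_column"),
}


def missing_columns(
    mode: str, settings: Dict[str, Optional[str]], columns: Sequence[str]
) -> List[str]:
    """Return configured column names absent from the data (hypothetical helper)."""
    if mode not in REQUIRED_KEYS:
        raise ValueError(f"Unknown mode: '{mode}'")
    keys = REQUIRED_KEYS[mode]
    # Unset keys are a configuration error, as in validate_data.
    unset = [key for key in keys if not settings.get(key)]
    if unset:
        raise ValueError(f"Mode '{mode}' requires configuration: {unset}")
    return [settings[key] for key in keys if settings[key] not in columns]
```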
process_data
process_data(df: DataFrame) -> pd.DataFrame

Process data and create links DataFrame for chord diagram.

Parameters:

df (DataFrame, required): Input data with the columns required by the configured mode

Returns:

DataFrame: Links DataFrame with columns source, target, value

Raises:

ValueError: If the mode is unknown or no links remain after min_link_value filtering

Source code in src/domain/plot_strategies/charts/chord_strategy.py
def process_data(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Process data and create links DataFrame for chord diagram.

    Parameters
    ----------
    df : pd.DataFrame
        Input data with required columns

    Returns
    -------
    pd.DataFrame
        Links DataFrame with columns: source, target, value
    """
    logger.info(f"Processing data in '{self.mode}' mode...")

    if self.mode == "aggregation":
        links = self._process_aggregation(df)
    elif self.mode == "pairwise":
        links = self._process_pairwise(df)
    elif self.mode == "set_intersection":
        links = self._process_set_intersection(df)
    else:
        raise ValueError(f"Unknown mode: {self.mode}")

    # Filter by minimum link value
    if self.min_link_value > 1:
        initial_count = len(links)
        links = links[links["value"] >= self.min_link_value]
        logger.debug(
            f"Filtered links by min_value={self.min_link_value}: "
            f"{initial_count} -> {len(links)}"
        )

    if links.empty:
        raise ValueError(
            "No valid links after processing. "
            "Check data or adjust min_link_value."
        )

    logger.info(
        f"Links created - {len(links)} connections, "
        f"Value range: [{links['value'].min()}, {links['value'].max()}]"
    )

    return links
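The minimum-value step can be reproduced in isolation; `filter_links` below is a hypothetical name. Note that, as in `process_data`, the filter only runs when `min_link_value > 1`: with the default of 1, links whose value is below 1 (possible when `value_column` carries fractional weights) are kept, and a fractional threshold such as 0.5 is never applied.

```python
import pandas as pd


def filter_links(links: pd.DataFrame, min_link_value: float = 1) -> pd.DataFrame:
    """Mirror the min_link_value filter and empty-result check of process_data."""
    if min_link_value > 1:  # same guard as process_data: thresholds <= 1 are skipped
        links = links[links["value"] >= min_link_value]
    if links.empty:
        raise ValueError(
            "No valid links after processing. Check data or adjust min_link_value."
        )
    return links
```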
create_figure
create_figure(processed_df: DataFrame) -> go.Figure

Create chord diagram figure from processed links data.

This implementation uses Plotly's graph_objects to create a custom chord diagram since Plotly Express doesn't have native chord support.
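Since the diagram is assembled from plain traces, its geometry has to be computed explicitly. The helpers that do this (`_calculate_arc_spans`, `_count_connections`) are not shown on this page, so the sketch below rests on assumptions: each node receives an arc proportional to its connection total, `gap` is the fraction of the full circle left empty after each arc, and chords and labels attach at arc midpoints (the `arc_midpoint` anchor), with labels placed at `radius * radius_offset`. The names `arc_spans` and `point_at` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def arc_spans(
    node_values: Sequence[float], gap: float = 0.02, proportional: bool = True
) -> List[Tuple[float, float]]:
    """Split the circle into one (start, end) angle pair per node (sketch)."""
    n = len(node_values)
    if n == 0 or gap * n >= 1:
        raise ValueError("Need at least one node and gap * n_nodes < 1")
    usable = 2 * math.pi * (1 - gap * n)  # angle left after all gaps
    total = sum(node_values)
    spans, angle = [], 0.0
    for value in node_values:
        # Proportional arcs when possible, otherwise equal arcs.
        share = value / total if proportional and total > 0 else 1 / n
        spans.append((angle, angle + usable * share))
        angle += usable * share + 2 * math.pi * gap
    return spans


def point_at(span: Tuple[float, float], radius: float = 1.0) -> Tuple[float, float]:
    """Cartesian point at the arc midpoint, as used for the arc_midpoint anchor."""
    mid = (span[0] + span[1]) / 2
    return radius * math.cos(mid), radius * math.sin(mid)


spans = arc_spans([1, 1, 2], gap=0.0)
label_xy = point_at(spans[2], radius=1.0 * 1.15)  # radius * labels radius_offset
```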

Parameters:

processed_df (DataFrame, required): Processed links data with source, target, value columns

Returns:

Figure: Configured Plotly figure with the chord diagram

Source code in src/domain/plot_strategies/charts/chord_strategy.py
def create_figure(self, processed_df: pd.DataFrame) -> go.Figure:
    """
    Create chord diagram figure from processed links data.

    This implementation uses Plotly's graph_objects to create a custom
    chord diagram since Plotly Express doesn't have native chord support.

    Parameters
    ----------
    processed_df : pd.DataFrame
        Processed links data with source, target, value columns

    Returns
    -------
    go.Figure
        Configured Plotly figure with chord diagram
    """
    logger.debug("Creating chord diagram figure...")

    chart_config = self.plotly_config.get("chart", {})
    layout_config = self.plotly_config.get("layout", {})

    # Title
    title_config = chart_config.get("title", {})
    show_title = title_config.get("show", True)
    title_text = title_config.get("text", "Chord Diagram") if show_title else ""

    # Get all unique nodes
    all_nodes = list(
        set(
            processed_df["source"].unique().tolist()
            + processed_df["target"].unique().tolist()
        )
    )
    all_nodes.sort()
    n_nodes = len(all_nodes)

    # Create node index mapping
    node_to_idx = {node: idx for idx, node in enumerate(all_nodes)}

    # Get colorscale
    colors = self._get_color_palette(n_nodes)

    # Calculate node connection values
    node_values = [
        self._count_connections(node, processed_df) for node in all_nodes
    ]
    total_value = sum(node_values)

    # Calculate arc spans for each node
    arc_spans = self._calculate_arc_spans(n_nodes, node_values, total_value)

    # Calculate value range for normalization
    min_value = processed_df["value"].min()
    max_value = processed_df["value"].max()

    logger.debug(f"Value range for line width: min={min_value}, max={max_value}")

    # Create traces
    traces = []

    # Add circle arc traces for each node
    if self.circle_arcs_enabled:
        for idx, node in enumerate(all_nodes):
            start_angle, end_angle = arc_spans[idx]
            arc_trace = self._create_circle_arc_trace(
                start_angle,
                end_angle,
                self.circle_arc_radius,
                colors[idx],
                node,
                node_values[idx],
            )
            traces.append(arc_trace)

    # Add chord (connection) traces for each link
    for _, row in processed_df.iterrows():
        source_idx = node_to_idx[row["source"]]
        target_idx = node_to_idx[row["target"]]
        value = row["value"]

        # Calculate connection points based on anchor setting
        if self.chord_anchor == "arc_midpoint":
            source_angle = (arc_spans[source_idx][0] + arc_spans[source_idx][1]) / 2
            target_angle = (arc_spans[target_idx][0] + arc_spans[target_idx][1]) / 2
        else:
            # Default to simple angle positions
            source_angle = arc_spans[source_idx][0]
            target_angle = arc_spans[target_idx][0]

        source_x = self.circle_arc_radius * np.cos(source_angle)
        source_y = self.circle_arc_radius * np.sin(source_angle)
        target_x = self.circle_arc_radius * np.cos(target_angle)
        target_y = self.circle_arc_radius * np.sin(target_angle)

        # Create bezier curve for the chord
        chord_trace = self._create_chord_trace(
            source_x,
            source_y,
            target_x,
            target_y,
            value,
            colors[source_idx],
            row["source"],
            row["target"],
            min_value,
            max_value,
        )
        traces.append(chord_trace)

    # Add label traces
    if self.labels_enabled:
        label_radius = self.circle_arc_radius * self.labels_radius_offset
        label_angles = [
            (arc_spans[i][0] + arc_spans[i][1]) / 2 for i in range(n_nodes)
        ]
        label_x = label_radius * np.cos(label_angles)
        label_y = label_radius * np.sin(label_angles)

        label_trace = go.Scatter(
            x=label_x.tolist(),
            y=label_y.tolist(),
            mode="text",
            text=all_nodes,
            textfont=self.labels_font,
            hoverinfo="skip",
            showlegend=False,
        )
        traces.append(label_trace)

    # Create figure
    fig = go.Figure(data=traces)

    # Layout configuration
    height = layout_config.get("height", 800)
    use_autosize = layout_config.get("autosize", False)

    layout_update = {
        "title": dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=title_config.get("font", dict(size=16)),
        ),
        "showlegend": False,
        "hovermode": "closest",
        "xaxis": dict(
            showgrid=False, zeroline=False, showticklabels=False, range=[-1.5, 1.5]
        ),
        "yaxis": dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[-1.5, 1.5],
            scaleanchor="x",
            scaleratio=1,
        ),
        "height": height,
        "template": layout_config.get("template", "simple_white"),
        "margin": layout_config.get("margin", dict(l=50, r=50, t=80, b=50)),
    }

    # Add autosize if enabled
    if use_autosize:
        layout_update["autosize"] = True
    else:
        width = layout_config.get("width")
        if width:
            layout_update["width"] = width

    fig.update_layout(**layout_update)

    logger.info(f"Chord diagram created - {n_nodes} nodes, {len(processed_df)} links")

    return fig
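`_create_chord_trace` is not shown on this page either. A common way to draw a chord, and the one sketched here under that assumption, is a quadratic Bezier curve from the source point to the target point with its control point at the circle's center, with line width scaled linearly between `min_value` and `max_value` (which `create_figure` passes in for normalization). The function names and the width bounds of 1 and 8 are hypothetical.

```python
import math
from typing import List, Tuple


def chord_path(
    source_angle: float, target_angle: float, radius: float = 1.0, n_points: int = 50
) -> Tuple[List[float], List[float]]:
    """Points on a quadratic Bezier chord bent through the circle center (sketch)."""
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    x0, y0 = radius * math.cos(source_angle), radius * math.sin(source_angle)
    x2, y2 = radius * math.cos(target_angle), radius * math.sin(target_angle)
    cx, cy = 0.0, 0.0  # control point: the circle center
    xs, ys = [], []
    for i in range(n_points):
        t = i / (n_points - 1)
        # B(t) = (1 - t)^2 * P0 + 2 * (1 - t) * t * C + t^2 * P2
        xs.append((1 - t) ** 2 * x0 + 2 * (1 - t) * t * cx + t**2 * x2)
        ys.append((1 - t) ** 2 * y0 + 2 * (1 - t) * t * cy + t**2 * y2)
    return xs, ys


def chord_width(
    value: float, min_value: float, max_value: float,
    min_width: float = 1.0, max_width: float = 8.0,
) -> float:
    """Linearly map a link value onto a line width (assumed normalization)."""
    if max_value == min_value:
        return (min_width + max_width) / 2
    fraction = (value - min_value) / (max_value - min_value)
    return min_width + fraction * (max_width - min_width)
```

The resulting `xs` and `ys` lists could then be passed to `go.Scatter(x=xs, y=ys, mode="lines", line=dict(width=...))`, one trace per link, which matches the one-trace-per-row loop in `create_figure`.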
apply_filters
apply_filters(df: DataFrame, filters: Optional[Dict[str, Any]] = None) -> pd.DataFrame

Apply filters to data.

This is a common implementation that can be overridden by subclasses if needed.

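Parameters:

df (DataFrame, required): Data to filter.
filters (Optional[Dict[str, Any]], default None): Selected filter values keyed by filter_id. Only filters declared under the top-level "filters" configuration key with type "range" are applied; entries whose bound column is missing from the data are skipped with a warning.

Returns:

DataFrame: Filtered data.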

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def apply_filters(
    self, df: pd.DataFrame, filters: Optional[Dict[str, Any]] = None
) -> pd.DataFrame:
    """
    Apply filters to data.

    This is a common implementation that can be overridden
    by subclasses if needed.

    Parameters
    ----------
    df : pd.DataFrame
        Data to filter.
    filters : Optional[Dict[str, Any]], default=None
        Filter specifications.

    Returns
    -------
    pd.DataFrame
        Filtered data.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not filters:
        logger.debug("No filters provided, returning original data")
        return df

    logger.info(
        f"Applying filters - Input shape: {df.shape}, "
        f"Columns: {df.columns.tolist()}"
    )
    logger.info(f"Filters to apply: {filters}")

    filtered_df = df.copy()

    # Get filter configurations
    filter_configs = self.config.get("filters", [])

    for filter_config in filter_configs:
        filter_id = filter_config.get("filter_id")
        filter_type = filter_config.get("type")

        if filter_id not in filters:
            continue

        filter_value = filters[filter_id]
        data_binding = filter_config.get("data_binding", {})
        column = data_binding.get("column")

        if not column or column not in filtered_df.columns:
            logger.warning(
                f"Filter '{filter_id}': Column '{column}' not found. "
                f"Available: {filtered_df.columns.tolist()}"
            )
            continue

        # Apply range filter
        if filter_type == "range" and isinstance(filter_value, list):
            min_val, max_val = filter_value
            logger.info(f"Applying range filter on '{column}': [{min_val}, {max_val}]")
            filtered_df = filtered_df[
                (filtered_df[column] >= min_val) & (filtered_df[column] <= max_val)
            ]
            logger.info(f"After filter: {len(filtered_df)} rows remaining")

    logger.info(f"Final filtered shape: {filtered_df.shape}")
    return filtered_df
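Because `generate_plot` calls `apply_filters` on the output of `process_data`, a filter used with `ChordStrategy` must be bound to a column of the links DataFrame (`source`, `target`, or `value`); a binding to a raw input column such as `sample` is skipped with a warning. The sketch below shows the configuration shape the method reads and an equivalent standalone range filter; the filter id `link_value` and the helper `apply_range_filters` are hypothetical.

```python
import pandas as pd

# Hypothetical configuration: filters are declared under the top-level "filters" key.
config = {
    "filters": [
        {"filter_id": "link_value", "type": "range", "data_binding": {"column": "value"}},
    ]
}
# Runtime selections are keyed by filter_id; a range is [min, max].
selected = {"link_value": [2, 10]}


def apply_range_filters(df: pd.DataFrame, config: dict, selected: dict) -> pd.DataFrame:
    """Standalone equivalent of the range-filter loop in apply_filters (sketch)."""
    out = df
    for spec in config.get("filters", []):
        filter_id = spec.get("filter_id")
        column = spec.get("data_binding", {}).get("column")
        if filter_id not in selected or column not in out.columns:
            continue  # apply_filters also skips unselected ids and missing columns
        if spec.get("type") == "range":
            low, high = selected[filter_id]
            out = out[out[column].between(low, high)]  # inclusive, like >= and <=
    return out
```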
apply_customizations
apply_customizations(fig: Figure, customizations: Optional[Any] = None) -> go.Figure

Apply custom styling to figure.

This is currently a no-op hook that returns the figure unchanged; it is reserved for future customization features (FLEXIVEL and FLEXIVEL2).

Parameters:

Name Type Description Default
fig Figure

Base figure.

required
customizations Optional[Any]

Customization specifications.

None

Returns:

Type Description
Figure

Customized figure.

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def apply_customizations(
    self, fig: go.Figure, customizations: Optional[Any] = None
) -> go.Figure:
    """
    Apply custom styling to figure.

    This is a hook for future customization features
    (FLEXIVEL and FLEXIVEL2).

    Parameters
    ----------
    fig : go.Figure
        Base figure.
    customizations : Optional[Any], default=None
        Customization specifications.

    Returns
    -------
    go.Figure
        Customized figure.
    """
    # Hook for future implementation
    return fig
generate_plot
generate_plot(data: DataFrame, filters: Optional[Dict[str, Any]] = None, customizations: Optional[Any] = None) -> go.Figure

Generate complete plot (Template Method).

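This method orchestrates the entire plot generation process:
1. Validate input data
2. Process data
3. Apply filters
4. Create figure
5. Apply customizations

Because filters are applied after processing, they act on the processed output (for ChordStrategy, the source/target/value links DataFrame) rather than on the raw input.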

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
filters Optional[Dict[str, Any]]

Filters to apply.

None
customizations Optional[Any]

Customizations to apply.

None

Returns:

Type Description
Figure

Complete Plotly figure.

Raises:

Type Description
ValueError

If validation fails.

Source code in src/domain/plot_strategies/base/base_plot_strategy.py
def generate_plot(
    self,
    data: pd.DataFrame,
    filters: Optional[Dict[str, Any]] = None,
    customizations: Optional[Any] = None,
) -> go.Figure:
    """
    Generate complete plot (Template Method).

    This method orchestrates the entire plot generation process:
    1. Validate input data
    2. Process data
    3. Apply filters
    4. Create figure
    5. Apply customizations

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    filters : Optional[Dict[str, Any]], default=None
        Filters to apply.
    customizations : Optional[Any], default=None
        Customizations to apply.

    Returns
    -------
    go.Figure
        Complete Plotly figure.

    Raises
    ------
    ValueError
        If validation fails.
    """
    # 1. Validate
    self.validate_data(data)

    # 2. Process
    processed_df = self.process_data(data)

    # 3. Filter
    filtered_df = self.apply_filters(processed_df, filters)

    # 4. Create figure
    figure = self.create_figure(filtered_df)

    # 5. Apply customizations (hook for future)
    figure = self.apply_customizations(figure, customizations)

    return figure
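`generate_plot` is a Template Method: the base class fixes the order of the steps, and concrete strategies such as `ChordStrategy` supply `validate_data`, `process_data`, and `create_figure`. The stdlib-only sketch below reproduces that control flow with a toy strategy that returns a dictionary instead of a Plotly figure. `PlotTemplate` and `PairCounter` are hypothetical and are not the real `BasePlotStrategy`; declaring the three steps abstract is an assumption of this sketch.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple


class PlotTemplate(ABC):
    """Sketch of the generate_plot template method (not the real BasePlotStrategy)."""

    def generate_plot(
        self,
        data: Any,
        filters: Optional[Dict[str, Any]] = None,
        customizations: Optional[Any] = None,
    ) -> Any:
        # Fixed step order, as in BasePlotStrategy.generate_plot.
        self.validate_data(data)
        processed = self.process_data(data)
        filtered = self.apply_filters(processed, filters)
        figure = self.create_figure(filtered)
        return self.apply_customizations(figure, customizations)

    @abstractmethod
    def validate_data(self, data: Any) -> None:
        """Raise ValueError for unusable input."""

    @abstractmethod
    def process_data(self, data: Any) -> Any:
        """Turn raw input into plot-ready data."""

    @abstractmethod
    def create_figure(self, processed: Any) -> Any:
        """Build the figure object."""

    def apply_filters(self, processed: Any, filters: Optional[Dict[str, Any]] = None) -> Any:
        return processed  # default: no filtering

    def apply_customizations(self, figure: Any, customizations: Optional[Any] = None) -> Any:
        return figure  # default hook: no-op


class PairCounter(PlotTemplate):
    """Toy strategy: counts (source, target) pairs and returns a dict 'figure'."""

    def validate_data(self, data: List[Tuple[str, str]]) -> None:
        if not data:
            raise ValueError("Input data is empty")

    def process_data(self, data: List[Tuple[str, str]]) -> Dict[Tuple[str, str], int]:
        counts: Dict[Tuple[str, str], int] = {}
        for pair in data:
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def create_figure(self, processed: Dict[Tuple[str, str], int]) -> Dict[str, Any]:
        return {"n_links": len(processed), "links": processed}
```

Swapping `PairCounter` for another subclass changes what is produced without touching `generate_plot`, which is the same separation the plot strategies rely on.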