import plotly
import plotly.express as px
from plotly.graph_objects import Figure as _Figure
from pandas import DataFrame
[docs]def render_3D_scatter_plotly(embeddings_dataframe: DataFrame) -> _Figure:
    """
    Return a Plotly Figure (3D scatter plot) based on a DataFrame containing three components.
    :param embeddings_dataframe: the DataFrame *must* contain three numerical columns called `component_0`,
            `component_1` and `component_2`. The DataFrame index will be used to identify the points in the
            scatter plot. Optionally, the DataFrame may contain a column called `label` which will be used
            to color the points in the scatter plot.
    :return: A 3D scatter plot
    """
    if 'label' in embeddings_dataframe.columns:
        fig = px.scatter_3d(embeddings_dataframe,
                            x='component_0',
                            y='component_1',
                            z='component_2',
                            color='label',
                            symbol='label',
                            hover_name=embeddings_dataframe.index,
                            hover_data=["label"]
                            )
    else:
        fig = px.scatter_3d(embeddings_dataframe,
                            x='component_0',
                            y='component_1',
                            z='component_2',
                            hover_name=embeddings_dataframe.index,
                            )
    fig.update_layout(
        # Remove axes ticks and labels as they are usually not informative
        scene=dict(
            xaxis=dict(
                showticklabels=False,
                showspikes=False,
                title=""
            ),
            yaxis=dict(
                showticklabels=False,
                showspikes=False,
                title=""
            ),
            zaxis=dict(
                showticklabels=False,
                showspikes=False,
                title=""
            )
        ),
    )
    return fig 
[docs]def render_scatter_plotly(embeddings_dataframe: DataFrame) -> _Figure:
    """
    Return a Plotly Figure (2D scatter plot) based on a DataFrame containing three components.
    :param embeddings_dataframe: the DataFrame *must* contain two numerical columns called `component_0`
            and `component_1`. The DataFrame index will be used to identify the points in the
            scatter plot. Optionally, the DataFrame may contain a column called `label` which will be used
            to color the points in the scatter plot.
    :return: A 2D scatter plot
    """
    hover_data = []
    if "original_id" in embeddings_dataframe.columns:
        hover_data.append("original_id")
    if "label" in embeddings_dataframe.columns:
        hover_data.append("label")
    if 'label' in embeddings_dataframe.columns:
        fig = px.scatter(embeddings_dataframe,
                         x='component_0',
                         y='component_1',
                         color='label',
                         symbol='label',
                         hover_name=embeddings_dataframe.index,
                         hover_data=hover_data
                         )
    else:
        fig = px.scatter(embeddings_dataframe,
                         x='component_0',
                         y='component_1',
                         hover_name=embeddings_dataframe.index,
                         hover_data=hover_data
                         )
    fig.update_layout(
        # Remove axes ticks and labels as they are usually not informative
        scene=dict(
            xaxis=dict(
                showticklabels=False,
                showspikes=False,
                title=""
            ),
            yaxis=dict(
                showticklabels=False,
                showspikes=False,
                title=""
            )
        ),
    )
    return fig