Source code for bio_embeddings.visualize.mutagenesis

import pandas
import plotly
from pandas import DataFrame
from plotly import express as px, graph_objects as go
from plotly.graph_objs import Figure
from tqdm import tqdm

from bio_embeddings.mutagenesis import PROBABILITIES_COLUMNS, AMINO_ACIDS
from bio_embeddings.utilities import check_required, get_file_manager


def plot(probabilities: DataFrame) -> Figure:
    """Build a plotly heatmap of per-residue amino acid probabilities.

    Columns of the heatmap are the residues of the wild-type sequence, each
    labelled "<position> <wild type amino acid>". Rows are the amino acids in
    ``AMINO_ACIDS``. The colour of each cell is the probability, on a fixed
    0-1 scale. The wild-type residue at every position is overlaid as a
    scatter marker. Both axes have their range fixed.

    :param probabilities: rows for one protein (the caller passes one
        ``groupby("id")`` group). Must contain the columns ``position``,
        ``wild_type_amino_acid`` and ``original_id``, plus one column per
        amino acid in ``AMINO_ACIDS``. The first ``original_id`` becomes the
        figure title.
    :return: the plotly figure
    """
    amino_acids = list(AMINO_ACIDS)
    wild_type = probabilities["wild_type_amino_acid"]

    position_text = probabilities["position"].astype(str)
    column_labels = list(position_text + " " + wild_type)

    # Transpose so that amino acids end up on the y axis and residues on x
    heatmap_values = probabilities[amino_acids].values.T
    # Empirical sizing: roughly 20px per residue, never narrower than 20 residues
    plot_width = max(len(column_labels), 20) * 20
    plot_title = probabilities["original_id"].iloc[0]

    fig = px.imshow(
        heatmap_values,
        x=column_labels,
        y=amino_acids,
        labels=dict(x="WT sequence", y="AA", color="Probability"),
        color_continuous_scale="blues",
        zmin=0,
        zmax=1,
        width=plot_width,
        title=plot_title,
    )

    # Transparent background, and a tick for every residue / amino acid
    fig.update_layout(
        plot_bgcolor="rgba(0,0,0,0)",
        xaxis=dict(tickmode="linear"),
        yaxis=dict(tickmode="linear"),
    )

    # Mark the wild-type amino acid in each column
    wild_type_markers = go.Scatter(x=column_labels, y=wild_type, mode="markers")
    fig.add_trace(wild_type_markers)

    # Lock both axis ranges (plotly `fixedrange`)
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(fixedrange=True)

    return fig


def plot_mutagenesis(result_kwargs):
    """BETA: visualize in-silico mutagenesis as a heatmap with plotly

    Reads the residue probabilities CSV and writes one HTML heatmap per
    protein into this stage's directory. Proteins are grouped by the ``id``
    column, and each HTML file is named after its ``id``.

    mandatory:
        * residue_probabilities_file

    :param result_kwargs: stage parameters. Must contain ``protocol``,
        ``prefix``, ``stage_name`` and ``residue_probabilities_file``.
    :return: ``result_kwargs``, unchanged
    :raises AssertionError: if the CSV's columns are not exactly
        ``PROBABILITIES_COLUMNS``
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "residue_probabilities_file",
    ]
    check_required(result_kwargs, required_kwargs)

    prefix = result_kwargs["prefix"]
    stage_name = result_kwargs["stage_name"]

    file_manager = get_file_manager()
    file_manager.create_stage(prefix, stage_name)

    probabilities_all = pandas.read_csv(result_kwargs["residue_probabilities_file"])
    # Raise explicitly instead of using `assert`, because `assert` is stripped
    # under `python -O`. AssertionError is kept so callers see the same
    # exception type as before.
    if list(probabilities_all.columns) != PROBABILITIES_COLUMNS:
        raise AssertionError(
            f"probabilities file is expected to have the following columns: {PROBABILITIES_COLUMNS}"
        )

    grouped = probabilities_all.groupby("id")
    # ngroups is exactly the number of groups iterated below. A set of the raw
    # ids would also count NaN ids, which groupby drops.
    for sequence_id, probabilities in tqdm(grouped, total=grouped.ngroups):
        fig = plot(probabilities)
        plotly.offline.plot(
            fig,
            filename=file_manager.create_file(
                prefix,
                stage_name,
                sequence_id,
                extension=".html",
            ),
            # plotly defaults to auto_open=True, which would open a browser
            # tab for every protein processed by this batch stage.
            auto_open=False,
        )

    return result_kwargs