import pandas
import plotly
from pandas import DataFrame
from plotly import express as px, graph_objects as go
from plotly.graph_objs import Figure
from tqdm import tqdm
from bio_embeddings.mutagenesis import PROBABILITIES_COLUMNS, AMINO_ACIDS
from bio_embeddings.utilities import check_required, get_file_manager
def plot(probabilities: DataFrame) -> Figure:
"""Given the DataFrame from the previous stage and returns a heatmap"""
x_labels = list(
probabilities["position"].astype(str)
+ " "
+ probabilities["wild_type_amino_acid"]
)
# Only the probabilities for the amino acids
values = probabilities[list(AMINO_ACIDS)].values.T
fig = px.imshow(
values,
labels=dict(x="WT sequence", y="AA", color="Probability"),
color_continuous_scale="blues",
x=x_labels,
y=list(AMINO_ACIDS),
zmin=0,
zmax=1,
# Somehow makes the plot approximately the right size
width=max(len(x_labels), 20) * 20,
title=probabilities["original_id"].iloc[0],
)
fig.update_layout(
plot_bgcolor="rgba(0,0,0,0)",
xaxis=dict(
tickmode="linear",
),
yaxis=dict(
tickmode="linear",
),
)
fig.add_trace(
go.Scatter(x=x_labels, y=probabilities["wild_type_amino_acid"], mode="markers")
)
fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)
return fig
[docs]def plot_mutagenesis(result_kwargs):
"""BETA: visualize in-silico mutagenesis as a heatmap with plotly
mandatory:
* residue_probabilities_file
"""
required_kwargs = [
"protocol",
"prefix",
"stage_name",
"residue_probabilities_file",
]
check_required(result_kwargs, required_kwargs)
file_manager = get_file_manager()
file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])
probabilities_all = pandas.read_csv(result_kwargs["residue_probabilities_file"])
assert (
list(probabilities_all.columns) == PROBABILITIES_COLUMNS
), f"probabilities file is expected to have the following columns: {PROBABILITIES_COLUMNS}"
number_of_proteins = len(set(probabilities_all["id"]))
for sequence_id, probabilities in tqdm(
probabilities_all.groupby("id"), total=number_of_proteins
):
fig = plot(probabilities)
plotly.offline.plot(
fig,
filename=file_manager.create_file(
result_kwargs.get("prefix"),
result_kwargs.get("stage_name"),
sequence_id,
extension=".html",
),
)
return result_kwargs