import math
from copy import deepcopy
from typing import List, Dict, Type
import torch
from Bio import SeqIO
from pandas import DataFrame
from tqdm import tqdm
from bio_embeddings.mutagenesis.constants import PROBABILITIES_COLUMNS
try:
    from bio_embeddings.mutagenesis.protbert_bfd import ProtTransBertBFDMutagenesis
except ImportError as e:
    # Keep a reference to the original error; the except target `e` is cleared
    # when the except block exits, so it can't be referenced later in __init__
    _protbert_import_error = e

    # Fallback stub: raise a helpful error only when the protocol is actually used
    class ProtTransBertBFDMutagenesis:
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "The 'transformers' extra required for protbert_bfd_mutagenesis is missing. "
                "See https://docs.bioembeddings.com/#installation on how to install all extras"
            ) from _protbert_import_error
from bio_embeddings.utilities import check_required, get_device, get_file_manager, read_mapping_file
# list of available mutagenesis protocols
_PROTOCOLS = {
"protbert_bfd_mutagenesis": ProtTransBertBFDMutagenesis,
}


def probabilities_as_dataframe(
    mapping_file: DataFrame,
    probabilities_all: Dict[str, List[Dict[str, float]]],
    sequences: List[str],
) -> DataFrame:
    """Flatten the per-sequence, per-position probabilities into one record per
    residue so they can be written out as a single CSV."""
    records = []
    for sequence, (sequence_id, probabilities) in zip(
        sequences, probabilities_all.items()
    ):
        for wild_type_amino_acid, position_probabilities in zip(
            sequence, probabilities
        ):
            records.append(
                {
                    "id": sequence_id,
                    "original_id": mapping_file.loc[sequence_id]["original_id"],
                    "wild_type_amino_acid": wild_type_amino_acid,
                    **position_probabilities,
                }
            )
    return DataFrame(records, columns=PROBABILITIES_COLUMNS)
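
# Minimal sketch of the inputs `probabilities_as_dataframe` expects, with
# hypothetical data: `mapping_file` is indexed by the remapped sequence id and
# carries an "original_id" column, while `probabilities_all` maps each sequence
# id to one dict per residue. The exact probability keys (assumed here to be
# amino-acid one-letter codes plus a "position" entry) are defined by the
# mutagenesis model; only keys listed in PROBABILITIES_COLUMNS end up in the
# resulting DataFrame.
#
#   mapping = DataFrame({"original_id": ["my_protein"]}, index=["0"])
#   probs = {"0": [{"position": 1, "A": 0.7, "C": 0.3}]}
#   df = probabilities_as_dataframe(mapping, probs, ["A"])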


def run(**kwargs):
    """BETA: in-silico mutagenesis using BertForMaskedLM

    required:
        * protocol
        * prefix
        * stage_name
        * remapped_sequences_file
        * mapping_file

    optional:
        * model_directory
        * device
        * half_precision_model
        * temperature: temperature for softmax
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)
    result_kwargs = deepcopy(kwargs)

    if result_kwargs["protocol"] not in _PROTOCOLS:
        raise RuntimeError(
            f"Passed protocol {result_kwargs['protocol']}, but allowed are: {', '.join(_PROTOCOLS)}"
        )

    temperature = result_kwargs.setdefault("temperature", 1)
    device = get_device(result_kwargs.get("device"))
    model_class: Type[ProtTransBertBFDMutagenesis] = _PROTOCOLS[
        result_kwargs["protocol"]
    ]
    model = model_class(
        device,
        result_kwargs.get("model_directory"),
        result_kwargs.get("half_precision_model"),
    )

    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])
    # The mapping file contains the corresponding ids in the same order
    sequences = [
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    ]
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    probabilities_all = dict()
    with tqdm(total=int(mapping_file["sequence_length"].sum())) as progress_bar:
        for sequence_id, original_id, sequence in zip(
            mapping_file.index, mapping_file["original_id"], sequences
        ):
            with torch.no_grad():
                probabilities = model.get_sequence_probabilities(
                    sequence, temperature, progress_bar=progress_bar
                )

            # Each per-position dict also carries a "position" entry, which must
            # be excluded before checking that the softmax probabilities sum to 1
            for p in probabilities:
                assert math.isclose(
                    1, sum(p.values()) - p["position"], rel_tol=1e-6
                ), "softmax values should add up to 1"

            probabilities_all[sequence_id] = probabilities
    residue_probabilities = probabilities_as_dataframe(
        mapping_file, probabilities_all, sequences
    )

    probabilities_file = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "residue_probabilities_file",
        extension=".csv",
    )
    residue_probabilities.to_csv(probabilities_file, index=False)

    result_kwargs["residue_probabilities_file"] = probabilities_file

    return result_kwargs
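
# Minimal usage sketch with hypothetical paths (this stage is normally driven by
# the bio_embeddings pipeline config rather than called directly); only the
# kwargs checked via `check_required` above are mandatory, everything else is
# optional:
#
#   result = run(
#       protocol="protbert_bfd_mutagenesis",
#       prefix="my_run",
#       stage_name="protbert_bfd_mutagenesis",
#       remapped_sequences_file="my_run/remapped_sequences_file.fasta",
#       mapping_file="my_run/mapping_file.csv",
#       temperature=1,  # optional: softmax temperature
#   )
#   print(result["residue_probabilities_file"])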