import logging
import re
from typing import Union, Optional, List, Dict

import torch
import transformers
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM

from bio_embeddings.embed import ProtTransBertBFDEmbedder
from bio_embeddings.mutagenesis import AMINO_ACIDS
from bio_embeddings.utilities import (
    get_device,
    get_model_directories_from_zip,
)

class FilterBertForMaskedLMWeightsWarning(logging.Filter):
    """Logging filter that drops transformers' unused-weights warning for BertForMaskedLM.

    transformers complains that we don't use some of the weights
    (``cls.seq_relationship.weight`` / ``cls.seq_relationship.bias``) when loading the
    checkpoint with BertForMaskedLM instead of BertModel, which we can ignore.

    The filter only takes effect once it is attached to a logger or handler via
    ``addFilter``; this class does not register itself.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Decide whether ``record`` should be emitted.

        :param record: the log record to inspect
        :return: ``False`` (drop) if the formatted message contains the
            unused-``cls.seq_relationship`` warning, ``True`` otherwise
        """
        # getMessage() merges the record's %-style args, so the check sees the final text
        return (
            "were not used when initializing BertForMaskedLM: "
            "['cls.seq_relationship.weight', 'cls.seq_relationship.bias']"
            not in record.getMessage()
        )


class ProtTransBertBFDMutagenesis:
    """BETA: in-silico mutagenesis using BertForMaskedLM.

    Each residue of a sequence is masked in turn and the masked-LM head is asked to
    reconstruct it, yielding a probability for every amino acid in ``AMINO_ACIDS``
    at each position.
    """

    device: torch.device
    model: BertForMaskedLM
    tokenizer: BertTokenizer
    _half_precision_model: bool

    def __init__(
        self,
        device: Union[None, str, torch.device] = None,
        model_directory: Optional[str] = None,
        half_precision_model: bool = False,
    ):
        """Loads the Bert Model for Masked LM

        :param device: where to run the model; resolved through ``get_device``
        :param model_directory: directory passed to ``from_pretrained`` for both the
            tokenizer and the model. If empty/``None``, the directory is obtained from
            ``get_model_directories_from_zip`` for the ProtTrans BERT BFD model.
        :param half_precision_model: if ``True``, cast the model to half precision
        """
        self.device = get_device(device)
        self._half_precision_model = half_precision_model

        if not model_directory:
            model_directory = get_model_directories_from_zip(
                ProtTransBertBFDEmbedder.name, directory="model_directory"
            )

        self.tokenizer = BertTokenizer.from_pretrained(
            model_directory, do_lower_case=False
        )
        self.model = BertForMaskedLM.from_pretrained(model_directory)
        # Compute in half precision, which is a lot faster and saves us half the memory
        if self._half_precision_model:
            self.model = self.model.half()
        self.model = self.model.eval().to(self.device)

    def get_sequence_probabilities(
        self,
        sequence: str,
        temperature: float = 1,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        progress_bar: Optional[tqdm] = None,
    ) -> List[Dict[str, float]]:
        """Returns the likelihood for each of the 20 natural amino acids to be at residue
        positions between `start` and `stop`. The remainder of the sequence is used as
        context: BERT's mask token replaces the residue and the model reconstructs the
        corrupted sequence.

        Probabilities may be adjusted by a `temperature` factor. If set to `1` (default)
        no adjustment is made.

        :param sequence: The amino acid sequence. Please pass whole sequences, not regions
        :param temperature: temperature for the softmax computation (the scores are
            divided by it before the softmax)
        :param start: the start index (inclusive) of the region for which to compute
            residue probabilities (starting with 0). ``None`` means 0.
        :param stop: the end (exclusive) of the region for which to compute residue
            probabilities. ``None`` means the end of the sequence; ``0`` yields an
            empty region.
        :param progress_bar: optional tqdm progress bar; ``update()`` is called once
            per processed position
        :return: An ordered list with one dict per position of the region. Each dict
            maps every amino acid in ``AMINO_ACIDS`` to its probability at that
            position, and holds the 0-based residue index under ``"position"``.
        """
        # The softmax is applied to a single 1-D score vector (one entry per vocabulary item)
        softmax = torch.nn.Softmax(dim=0)
        amino_acids = list(AMINO_ACIDS)
        aa_token_ids = [
            self.tokenizer.convert_tokens_to_ids(aa) for aa in amino_acids
        ]

        # Remove rare amino acids
        current_sequence = re.sub(r"[UZOB]", "X", sequence)

        # Compare against None (not `or`) so that an explicit stop=0 means an empty region
        region_start = 0 if start is None else start
        region_stop = len(sequence) if stop is None else stop

        probabilities_list = []
        # Mask each token individually
        for i in range(region_start, region_stop):
            masked_residues = (
                list(current_sequence[:i])
                + [self.tokenizer.mask_token]
                + list(current_sequence[i + 1 :])
            )
            # Each AA is a word, so we need spaces in between
            tokenized_sequence = self.tokenizer.encode(
                " ".join(masked_residues), return_tensors="pt"
            )
            # get the position of the masked token
            # noinspection PyTypeChecker
            masked_position = torch.nonzero(
                tokenized_sequence.squeeze() == self.tokenizer.mask_token_id
            ).item()

            # TODO: can batch this!
            # Inference only: no_grad avoids building an autograd graph for every position
            with torch.no_grad():
                output = self.model(tokenized_sequence.to(self.device))
            # A single sequence was encoded, so drop the leading batch dimension
            last_hidden_state = output[0].squeeze(0)
            # Only keep the scores for the masked token; one entry per vocabulary item.
            # Cast to float32: with a half-precision model the softmax would otherwise run
            # on a float16 CPU tensor, which older PyTorch versions do not implement.
            mask_scores = last_hidden_state[masked_position].float().cpu()
            # giving a probability for each item in the vocabulary
            probabilities = softmax(mask_scores / temperature)

            result = {
                aa: probabilities[token_id].item()
                for aa, token_id in zip(amino_acids, aa_token_ids)
            }
            result["position"] = i
            probabilities_list.append(result)

            # `is not None` rather than truthiness: tqdm.__bool__ raises for bars without
            # total/iterable and is False for total=0
            if progress_bar is not None:
                progress_bar.update()

        return probabilities_list