Source code for bio_embeddings.embed.prottrans_t5_bfd_embedder

import logging
import re
from itertools import zip_longest
from typing import List, Generator, Iterable, Optional, Union

import torch
from numpy import ndarray
from transformers import T5Tokenizer, T5Model, T5EncoderModel

from bio_embeddings.embed.embedder_interfaces import EmbedderWithFallback

logger = logging.getLogger(__name__)


class ProtTransT5BFDEmbedder(EmbedderWithFallback):
    """Encoder of the ProtTrans T5 BFD model

    Note that this model alone takes 13GB, so you need a GPU with a lot of memory.
    """

    _model: Union[T5Model, T5EncoderModel]
    _decoder: bool = False

    name = "prottrans_t5_bfd"
    embedding_dimension = 1024
    number_of_layers = 1
    _necessary_directories = ["model_directory"]

    def __init__(self, **kwargs):
        """
        Initialize T5 embedder.

        :param model_directory: directory with the T5 BFD weights and tokenizer files
        """
        super().__init__(**kwargs)

        self._model_directory = self._options["model_directory"]
        # Until we know whether we need the decoder, let's keep it here as an undocumented option.
        # Should the need arise, we can split this class into an encoder and a decoder subclass
        # by setting one subclass to _decoder=True and the other to _decoder=False
        self._decoder = self._options.get("decoder", False)

        # Make the model: the full T5 model is only needed when the decoder is
        # requested; otherwise the encoder alone suffices
        if self._decoder:
            self._model = T5Model.from_pretrained(self._model_directory)
        else:
            self._model = T5EncoderModel.from_pretrained(self._model_directory)
        self._model = self._model.to(self._device).eval()
        self._model_fallback = None
        self._tokenizer = T5Tokenizer.from_pretrained(
            self._model_directory, do_lower_case=False
        )

    def _get_fallback_model(self) -> Union[T5Model, T5EncoderModel]:
        """Returns the CPU model"""
        if not self._model_fallback:
            if self._decoder:
                self._model_fallback = T5Model.from_pretrained(
                    self._model_directory
                ).eval()
            else:
                self._model_fallback = T5EncoderModel.from_pretrained(
                    self._model_directory
                ).eval()
        return self._model_fallback

    def _embed_batch_impl(
        self, batch: List[str], model: Union[T5Model, T5EncoderModel]
    ) -> Generator[ndarray, None, None]:
        seq_lens = [len(seq) for seq in batch]
        # Remove rare amino acids
        batch = [re.sub(r"[UZOB]", "X", sequence) for sequence in batch]
        # transformers needs spaces between the amino acids
        batch = [" ".join(list(seq)) for seq in batch]

        if not batch:
            return

        ids = self._tokenizer.batch_encode_plus(
            batch, add_special_tokens=True, padding="longest"
        )

        tokenized_sequences = torch.tensor(ids["input_ids"]).to(model.device)
        attention_mask = torch.tensor(ids["attention_mask"]).to(model.device)

        with torch.no_grad():
            # See comment in __init__: the encoder-only model must not be given
            # decoder inputs, while the full model requires them
            if self._decoder:
                embeddings = model(
                    input_ids=tokenized_sequences,
                    attention_mask=attention_mask,
                    decoder_input_ids=tokenized_sequences,
                )
            else:
                embeddings = model(
                    input_ids=tokenized_sequences, attention_mask=attention_mask
                )

        # The first output is the last hidden state (of the decoder for T5Model,
        # of the encoder for T5EncoderModel)
        embeddings = embeddings[0].cpu().numpy()

        for seq_num, seq_len in zip_longest(range(len(embeddings)), seq_lens):
            # slice off the padding and the trailing special token
            embedding = embeddings[seq_num][:seq_len]
            assert (
                seq_len == embedding.shape[0]
            ), f"Sequence length mismatch: {seq_len} vs {embedding.shape[0]}"
            yield embedding
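
    # Illustrative sketch (added for documentation, not part of the original
    # module): what the preprocessing in `_embed_batch_impl` does to a raw
    # sequence before tokenization. Rare amino acids are mapped to "X" and the
    # residues are space-separated, e.g.
    #
    #   re.sub(r"[UZOB]", "X", "MUKZT")  -> "MXKXT"
    #   " ".join(list("MXKXT"))          -> "M X K X T"
    #
    # The tokenizer then appends a single trailing special token, which is why
    # the per-residue embeddings are sliced back to `seq_len` positions above.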

    def embed_many(
        self, sequences: Iterable[str], batch_size: Optional[int] = None
    ) -> Generator[ndarray, None, None]:
        if batch_size is not None:
            raise RuntimeError(
                "There is a bug in batching T5, so you currently must set batch_size to `None` for T5"
            )
        return super().embed_many(sequences, None)

    @staticmethod
    def reduce_per_protein(embedding):
        return embedding.mean(axis=0)

    def embed(self, sequence: str) -> ndarray:
        [embedding] = self.embed_batch([sequence])
        return embedding
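

# Minimal usage sketch (illustrative, not part of the original module). The
# checkpoint path is an assumption and must point to a local copy of the
# ProtTrans T5 BFD weights and tokenizer.
if __name__ == "__main__":
    embedder = ProtTransT5BFDEmbedder(model_directory="/path/to/prottrans_t5_bfd")
    # One 1024-dimensional vector per residue
    per_residue = embedder.embed("SEQWENCE")
    # Mean over the residue axis gives a fixed-size per-protein vector
    per_protein = ProtTransT5BFDEmbedder.reduce_per_protein(per_residue)
    print(per_residue.shape, per_protein.shape)  # expected: (8, 1024) (1024,)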