Source code for bio_embeddings.embed.prottrans_t5_embedder
import abc
import logging
import re
from itertools import zip_longest
from typing import List, Generator, Union

import torch
import transformers
from numpy import ndarray
from transformers import T5Tokenizer, T5Model, T5EncoderModel

from bio_embeddings.embed.embedder_interfaces import EmbedderWithFallback

logger = logging.getLogger(__name__)


class FilterT5DecoderWeightsWarning(logging.Filter):
    """transformers complains at length that we pass decoder weights when
    initializing only the encoder, which we can ignore"""

    def filter(self, record: logging.LogRecord) -> bool:
        return (
            "were not used when initializing T5EncoderModel: ['decoder."
            not in record.getMessage()
        )


transformers.modeling_utils.logger.addFilter(FilterT5DecoderWeightsWarning())


class ProtTransT5Embedder(EmbedderWithFallback, abc.ABC):
    """Encoder of the ProtTrans T5 model, both BFD and BFD finetuned on UniRef50.

    To embed, please pick either ProtTransT5BFDEmbedder or ProtTransT5UniRef50Embedder.

    Note that this model alone takes 13GB, so you need a GPU with a lot of memory.
    """

    _model: Union[T5Model, T5EncoderModel]
    _decoder: bool = False
    _half_precision_model: bool = False
    embedding_dimension = 1024
    number_of_layers = 1
    necessary_directories = ["model_directory"]

    def __init__(self, **kwargs):
        """
        Initialize T5 embedder.

        :param str model_directory: where the weights of the model can be found
        :param device: whether to compute on the CPU or GPU
        :type device: str or torch.device or None
        :param bool decoder: Whether to also use the decoder (default: False)
        :param bool half_precision_model: Use the model in half precision (float16) mode (default: False)
        """
        super().__init__(**kwargs)
        self._model_directory = self._options["model_directory"]
        # Until we know whether we need the decoder, let's keep it here as an undocumented option.
        # Should the need arise, we can split this class into an encoder and a decoder subclass
        # by setting one subclass to _decoder=True and the other to _decoder=False
        self._decoder = self._options.get("decoder", False)
        self._half_precision_model = self._options.get("half_precision_model", False)
        self._model = self.get_model().to(self._device).eval()
        self._model_fallback = None
        self._tokenizer = T5Tokenizer.from_pretrained(
            self._model_directory, do_lower_case=False
        )

    def get_model(self) -> Union[T5Model, T5EncoderModel]:
        if not self._decoder:
            model = T5EncoderModel.from_pretrained(self._model_directory)
        else:
            model = T5Model.from_pretrained(self._model_directory)
        # Compute in half precision, which is a lot faster and saves us half the memory
        if self._half_precision_model:
            model = model.half()
        return model

    def _get_fallback_model(self) -> Union[T5Model, T5EncoderModel]:
        """Returns the CPU model"""
        if self._half_precision_model:
            raise NotImplementedError(
                "Your sequence was too long for the GPU, "
                "but we can't fall back to the CPU with half_precision_model=True "
                "(https://github.com/huggingface/transformers/issues/11546)"
            )
        if not self._model_fallback:
            self._model_fallback = self.get_model()
        return self._model_fallback

    def _embed_batch_impl(
        self, batch: List[str], model: T5Model
    ) -> Generator[ndarray, None, None]:
        seq_lens = [len(seq) for seq in batch]
        # Remove rare amino acids
        batch = [re.sub(r"[UZOB]", "X", sequence) for sequence in batch]
        # transformers needs spaces between the amino acids
        batch = [" ".join(list(seq)) for seq in batch]

        if not batch:
            return

        ids = self._tokenizer.batch_encode_plus(
            batch, add_special_tokens=True, padding="longest"
        )

        tokenized_sequences = torch.tensor(ids["input_ids"]).to(model.device)
        attention_mask = torch.tensor(ids["attention_mask"]).to(model.device)

        with torch.no_grad():
            if not self._decoder:
                embeddings = model(
                    input_ids=tokenized_sequences, attention_mask=attention_mask
                )
            else:
                embeddings = model(
                    input_ids=tokenized_sequences,
                    attention_mask=attention_mask,
                    decoder_input_ids=tokenized_sequences,
                )

        embeddings = embeddings[0].cpu().numpy()

        for seq_num, seq_len in zip_longest(range(len(embeddings)), seq_lens):
            # slice off last position (special token)
            embedding = embeddings[seq_num][:seq_len]
            assert (
                seq_len == embedding.shape[0]
            ), f"Sequence length mismatch: {seq_len} vs {embedding.shape[0]}"

            yield embedding

    @staticmethod
    def reduce_per_protein(embedding):
        return embedding.mean(axis=0)

    def embed(self, sequence: str) -> ndarray:
        [embedding] = self.embed_batch([sequence])
        return embedding
class ProtTransT5BFDEmbedder(ProtTransT5Embedder):
    """Encoder of the ProtTrans T5 model trained on BFD.

    Consider using :class:`ProtTransT5XLU50Embedder` instead of this.

    We recommend setting `half_precision_model=True`, which on the tested GPU
    (Quadro RTX 3000) reduces memory consumption from 12GB to 7GB, while the
    effect in benchmarks is negligible (±0.1 percentage points in different
    sets, generally below standard error)
    """

    name = "prottrans_t5_bfd"
class ProtTransT5UniRef50Embedder(ProtTransT5Embedder):
    """Encoder of the ProtTrans T5 model trained on BFD and finetuned on UniRef50.

    Consider using :class:`ProtTransT5XLU50Embedder` instead of this.

    We recommend setting `half_precision_model=True`, which on the tested GPU
    (Quadro RTX 3000) reduces memory consumption from 12GB to 7GB, while the
    effect in benchmarks is negligible (±0.1 percentage points in different
    sets, generally below standard error)
    """

    name = "prottrans_t5_uniref50"
class ProtTransT5XLU50Embedder(ProtTransT5Embedder):
    """Encoder of the ProtTrans T5 model trained on BFD and finetuned on UniRef50.

    We recommend setting `half_precision_model=True`, which on the tested GPU
    (Quadro RTX 3000) reduces memory consumption from 12GB to 7GB, while the
    effect in benchmarks is negligible (±0.1 percentage points in different
    sets, generally below standard error)
    """

    name = "prottrans_t5_xl_u50"
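A minimal usage sketch, assuming the embedder classes are re-exported from `bio_embeddings.embed` and that the model weights are available locally; the path and sequence below are placeholders:

    from bio_embeddings.embed import ProtTransT5XLU50Embedder

    embedder = ProtTransT5XLU50Embedder(
        model_directory="/path/to/prottrans_t5_xl_u50",  # placeholder path
        half_precision_model=True,  # recommended above; roughly halves GPU memory
    )
    per_residue = embedder.embed("SEQVENCE")  # ndarray of shape (len(sequence), 1024)
    per_protein = embedder.reduce_per_protein(per_residue)  # ndarray of shape (1024,)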