# Source code for bio_embeddings.embed.prottrans_xlnet_uniref100_embedder

import logging
import re
from itertools import zip_longest
from pathlib import Path
from typing import Optional, Generator, List

import torch
from numpy import ndarray
from transformers import XLNetModel, XLNetTokenizer

from bio_embeddings.embed.embedder_interfaces import EmbedderInterface

# Module-level logger named after this module (not referenced by the code shown here).
logger = logging.getLogger(__name__)

class ProtTransXLNetUniRef100Embedder(EmbedderInterface):
    """ProtTrans-XLNet-UniRef100 Embedder (ProtXLNet)

    Elnaggar, Ahmed, et al. "ProtTrans: Towards Cracking the Language of Life's
    Code Through Self-Supervised Deep Learning and High Performance Computing."
    arXiv preprint arXiv:2007.06225 (2020).
    """

    name = "prottrans_xlnet_uniref100"
    embedding_dimension = 1024
    number_of_layers = 1

    _model: XLNetModel
    _model_fallback: Optional[XLNetModel]
    necessary_directories = ["model_directory"]

    def __init__(self, **kwargs):
        """Initialize the XLNet embedder.

        :param model_directory: directory holding the pretrained XLNet weights
            (loaded with ``XLNetModel.from_pretrained``) and the sentence piece
            model file ``spm_model.model``.
        """
        super().__init__(**kwargs)

        # Get file locations from kwargs
        self.model_directory = self._options["model_directory"]

        # NOTE(review): the original comment explaining mem_len=512 was
        # truncated ("512 is from ..."); confirm and record where this value
        # comes from.
        self._model = (
            XLNetModel.from_pretrained(self.model_directory, mem_len=512)
            .to(self._device)
            .eval()
        )
        # No fallback model is loaded here; presumably the base class may
        # populate it on demand — verify against EmbedderInterface.
        self._model_fallback = None

        # sentence piece model
        # A standard text tokenizer which creates the input for NNs trained on text.
        # This one is just indexing single amino acids because we only have words of L=1.
        spm_model = str(Path(self.model_directory).joinpath("spm_model.model"))
        self._tokenizer = XLNetTokenizer.from_pretrained(spm_model, do_lower_case=False)

    def embed(self, sequence: str) -> ndarray:
        """Embed a single amino acid sequence.

        :param sequence: amino acid sequence, one letter per residue
        :return: per-residue embedding with one row per residue of ``sequence``
        """
        [embedding] = self.embed_batch([sequence])
        return embedding

    def embed_batch(self, batch: List[str]) -> Generator[ndarray, None, None]:
        """Embed a batch of amino acid sequences.

        :param batch: amino acid sequences, one letter per residue
        :return: generator yielding one per-residue embedding per input
            sequence, in input order; each has ``len(sequence)`` rows
        """
        if not batch:
            # Nothing to embed: return before tokenizing so the tokenizer and
            # the model are never called with an empty batch.
            return

        seq_lens = [len(seq) for seq in batch]
        # transformers needs spaces between the amino acids
        batch = [" ".join(seq) for seq in batch]
        # Remove rare amino acids
        batch = [re.sub(r"[UZOBX]", "<unk>", sequence) for sequence in batch]

        ids = self._tokenizer.batch_encode_plus(
            batch, add_special_tokens=True, padding="longest"
        )
        device = self._model.device
        tokenized_sequences = torch.tensor(ids["input_ids"]).to(device)
        attention_mask = torch.tensor(ids["attention_mask"]).to(device)

        with torch.no_grad():
            outputs = self._model(
                input_ids=tokenized_sequences,
                attention_mask=attention_mask,
                mems=None,
                return_dict=False,
            )
        # With return_dict=False the model returns a tuple whose length depends
        # on which optional outputs (mems, hidden states, attentions) are
        # present; only the first element (the last hidden state) is needed.
        embeddings = outputs[0].cpu().numpy()

        # Use the host-side mask lists rather than the device tensor so the
        # slice bounds are plain ints (no per-sequence device synchronization).
        for seq_len, padded_embedding, mask in zip(
            seq_lens, embeddings, ids["attention_mask"]
        ):
            attention_len = sum(mask)
            padded_seq_len = len(mask)
            # The slice assumes left padding (the attended tokens occupy the
            # last `attention_len` positions) and two trailing special tokens,
            # which are dropped.
            embedding = padded_embedding[
                padded_seq_len - attention_len : padded_seq_len - 2
            ]
            assert (
                seq_len == embedding.shape[0]
            ), f"Sequence length mismatch: {seq_len} vs {embedding.shape[0]}"
            yield embedding

    @staticmethod
    def reduce_per_protein(embedding):
        """Reduce a per-residue embedding to a single per-protein vector.

        :param embedding: per-residue embedding, one row per residue
        :return: the mean over residues (axis 0)
        """
        return embedding.mean(axis=0)