Source code for bio_embeddings.embed.prottrans_bert_bfd_embedder

import logging
from pathlib import Path

from transformers import BertModel, BertTokenizer

from bio_embeddings.embed.prottrans_base_embedder import ProtTransBertBaseEmbedder

# Logger scoped to this module's dotted import path.
logger = logging.getLogger(name=__name__)

class ProtTransBertBFDEmbedder(ProtTransBertBaseEmbedder):
    """ProtTrans-Bert-BFD Embedder (ProtBert-BFD)

    Elnaggar, Ahmed, et al. "ProtTrans: Towards Cracking the Language of
    Life's Code Through Self-Supervised Deep Learning and High Performance
    Computing." arXiv preprint arXiv:2007.06225 (2020).
    """

    _model: BertModel
    name = "prottrans_bert_bfd"
    embedding_dimension = 1024
    number_of_layers = 1

    def __init__(self, **kwargs):
        """Initialize Bert embedder.

        Loads the BERT model from ``model_directory``, optionally converts it
        to half precision, switches it to eval mode and moves it to
        ``self._device``. Also builds a ``BertTokenizer`` from the
        ``vocab.txt`` file in the same directory.

        NOTE(review): ``self._options`` and ``self._device`` are read here but
        not set in this class; presumably the base-class constructor sets
        them from ``kwargs`` -- confirm.

        :param model_directory: directory passed to
            ``BertModel.from_pretrained``; it must also contain ``vocab.txt``
            for the tokenizer. Required (read with ``[]``, so a missing key
            raises ``KeyError``).
        :param half_precision_model: if truthy, the main model is converted
            with ``.half()`` before being moved to the device. Defaults to
            ``False``.
        """
        super().__init__(**kwargs)

        self._model_directory = self._options["model_directory"]
        self._half_precision_model = self._options.get("half_precision_model", False)

        # make model
        self._model = BertModel.from_pretrained(self._model_directory)
        # Compute in half precision, which is a lot faster and saves us half the memory
        if self._half_precision_model:
            self._model = self._model.half()
        self._model = self._model.eval().to(self._device)

        # The fallback model is created lazily by _get_fallback_model(), so
        # no second copy of the weights is loaded unless it is actually needed.
        self._model_fallback = None

        # Case-sensitive vocabulary: input is passed to the tokenizer without
        # lower-casing.
        self._tokenizer = BertTokenizer(
            str(Path(self._model_directory) / "vocab.txt"), do_lower_case=False
        )

    def _get_fallback_model(self) -> BertModel:
        """Return the CPU fallback model, loading it on first call.

        The fallback is a separate copy loaded from ``model_directory`` and put
        in eval mode. Unlike the main model, it is never passed through
        ``.half()`` or ``.to(self._device)``. The instance is cached on
        ``self._model_fallback`` and reused on later calls.
        """
        # Compare against the None sentinel explicitly rather than relying on
        # the truthiness of the model object.
        if self._model_fallback is None:
            self._model_fallback = BertModel.from_pretrained(
                self._model_directory
            ).eval()
        return self._model_fallback