import logging
from pathlib import Path
from transformers import BertModel, BertTokenizer
from bio_embeddings.embed.prottrans_base_embedder import ProtTransBertBaseEmbedder
logger = logging.getLogger(__name__)


class ProtTransBertBFDEmbedder(ProtTransBertBaseEmbedder):
    """ProtTrans-Bert-BFD Embedder (ProtBert-BFD)
    Elnaggar, Ahmed, et al. "ProtTrans: Towards Cracking the Language of Life's
    Code Through Self-Supervised Deep Learning and High Performance Computing."
    arXiv preprint arXiv:2007.06225 (2020). https://arxiv.org/abs/2007.06225
    """
    _model: BertModel
    name = "prottrans_bert_bfd"
    embedding_dimension = 1024
    number_of_layers = 1

    def __init__(self, **kwargs):
        """Initialize Bert embedder.
        :param model_directory:
        :param half_precision_model:
        """
        super().__init__(**kwargs)
        self._model_directory = self._options["model_directory"]
        self._half_precision_model = self._options.get("half_precision_model", False)
        # Load the pretrained model weights
        self._model = BertModel.from_pretrained(self._model_directory)
        # Compute in half precision, which is a lot faster and halves memory use
        if self._half_precision_model:
            self._model = self._model.half()
        self._model = self._model.eval().to(self._device)
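        # Lazily-created CPU copy of the model (see _get_fallback_model)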
        self._model_fallback = None
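        # Build the tokenizer from the model's vocabulary; amino acids are
        # upper-case single letters, so lower-casing must stay disabled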
        self._tokenizer = BertTokenizer(
            str(Path(self._model_directory) / "vocab.txt"), do_lower_case=False
        )

    def _get_fallback_model(self) -> BertModel:
        """Returns a copy of the model running on the CPU, created lazily on first use."""
        if self._model_fallback is None:
            self._model_fallback = BertModel.from_pretrained(
                self._model_directory
            ).eval()
        return self._model_fallback
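

# Usage sketch (illustrative, not part of the module). It assumes the
# bio_embeddings EmbedderInterface API, where `embed` returns one
# 1024-dimensional vector per residue and `reduce_per_protein` collapses
# them into a single per-protein vector; the model directory path below is
# hypothetical and must contain the ProtBert-BFD weights plus vocab.txt.
#
#     embedder = ProtTransBertBFDEmbedder(
#         model_directory="/path/to/prottrans_bert_bfd",
#         half_precision_model=True,  # optional fp16 inference
#     )
#     per_residue = embedder.embed("SEQWENCE")                # (len, 1024)
#     per_protein = embedder.reduce_per_protein(per_residue)  # (1024,)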