Source code for bio_embeddings.project.pb_tucker

from pathlib import Path
from typing import Union

import torch
from numpy import ndarray
from torch import nn


class PBTuckerModel(nn.Module):
    """This is the torch module behind :class:`PBTucker`"""

    def __init__(self):
        super(PBTuckerModel, self).__init__()
        # Two-layer MLP that projects 1024-dimensional prottrans_bert_bfd
        # embeddings down to 128 dimensions
        self.tucker = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Tanh(),
            nn.Linear(512, 128),
        )

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.tucker(data)
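As a quick sanity check, the module can be exercised with random data: a batch of 1024-dimensional inputs should come out as 128-dimensional projections. This is an illustrative sketch, not part of the library:

    # Sketch: verify PBTuckerModel's input/output dimensions with random data
    model = PBTuckerModel()
    dummy = torch.randn(4, 1024)  # batch of 4 fake prottrans_bert_bfd embeddings
    with torch.no_grad():
        projected = model(dummy)
    print(projected.shape)  # torch.Size([4, 128])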
class PBTucker:
    """Tucker is a contrastive learning model trained to distinguish CATH
    superfamilies. It consumes prottrans_bert_bfd embeddings and reduces the
    embedding dimensionality from 1024 to 128.

    See https://www.biorxiv.org/content/10.1101/2021.01.21.427551v1

    To use it outside of the pipeline, first instantiate it with
    `pb_tucker = PBTucker("/path/to/model", device)`, then project your
    reduced BERT embedding with
    `pb_tucker.project_reduced_embedding(bert_embedding)`.
    """

    _device: torch.device
    name: str = "pb_tucker"

    def __init__(self, model_file: Union[str, Path], device: torch.device):
        self._device = device
        self.model = PBTuckerModel()
        # The checkpoint stores the model weights under the "state_dict" key
        self.model.load_state_dict(
            torch.load(model_file, map_location=device)["state_dict"]
        )
        self.model.eval()
        self.model = self.model.to(self._device)

    def project_reduced_embedding(self, reduced_embedding: ndarray) -> ndarray:
        with torch.no_grad():
            reduced_embedding_tensor = torch.tensor(
                reduced_embedding, device=self._device
            )
            return self.model.tucker(reduced_embedding_tensor).cpu().numpy()
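The following usage sketch follows the docstring above. The model path is a placeholder for a downloaded pb_tucker checkpoint, and `bert_embedding` stands in for a per-protein, 1024-dimensional prottrans_bert_bfd embedding:

    # Usage sketch; "/path/to/model" is a placeholder checkpoint path and
    # bert_embedding is a stand-in for a real per-protein BERT embedding.
    import numpy

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pb_tucker = PBTucker("/path/to/model", device)
    bert_embedding = numpy.random.rand(1024).astype(numpy.float32)
    tucker_embedding = pb_tucker.project_reduced_embedding(bert_embedding)
    print(tucker_embedding.shape)  # (128,)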