Source code for bio_embeddings.extract.basic.BasicAnnotationExtractor
import logging
import torch
import collections
from typing import List, Union
from numpy import ndarray
from enum import Enum
from bio_embeddings.extract.annotations import Location, Membrane, Disorder, SecondaryStructure
from bio_embeddings.extract.basic.annotation_inference_models import SUBCELL_FNN, SECSTRUCT_CNN
from bio_embeddings.utilities import get_device, get_model_file
logger = logging.getLogger(__name__)

# Label mappings: index of an inference model's output node -> annotation enum member.
# Predictions are looked up by the index of the maximally activated output node, so
# the list order below is the models' output-node order and must not be rearranged.
_loc_labels = dict(enumerate([
    Location.CELL_MEMBRANE,
    Location.CYTOPLASM,
    Location.ENDOPLASMATIC_RETICULUM,
    Location.GOLGI_APPARATUS,
    Location.LYSOSOME_OR_VACUOLE,
    Location.MITOCHONDRION,
    Location.NUCLEUS,
    Location.PEROXISOME,
    Location.PLASTID,
    Location.EXTRACELLULAR,
]))
_mem_labels = dict(enumerate([
    Membrane.SOLUBLE,
    Membrane.MEMBRANE,
]))
# 8-state DSSP secondary structure classes.
_dssp8_labels = dict(enumerate([
    SecondaryStructure.THREE_HELIX,
    SecondaryStructure.ALPHA_HELIX,
    SecondaryStructure.FIVE_HELIX,
    SecondaryStructure.ISOLATED_BETA_BRIDGE,
    SecondaryStructure.EXTENDED_STRAND,
    SecondaryStructure.BEND,
    SecondaryStructure.TURN,
    SecondaryStructure.IRREGULAR,
]))
# 3-state DSSP secondary structure classes.
_dssp3_labels = dict(enumerate([
    SecondaryStructure.ALPHA_HELIX,
    SecondaryStructure.EXTENDED_STRAND,
    SecondaryStructure.IRREGULAR,
]))
_disor_labels = dict(enumerate([
    Disorder.ORDER,
    Disorder.DISORDER,
]))
# Result containers returned by BasicAnnotationExtractor.
# Secondary structure / disorder predictions (lists of per-position labels).
BasicSecondaryStructureResult = collections.namedtuple(
    "BasicSecondaryStructureResult", ["DSSP3", "DSSP8", "disorder"]
)
# Per-protein localization and membrane-boundness predictions.
BasicSubcellularLocalizationResult = collections.namedtuple(
    "BasicSubcellularLocalizationResult", ["localization", "membrane"]
)
# Both of the above combined, as returned by get_annotations().
BasicExtractedAnnotations = collections.namedtuple(
    "BasicExtractedAnnotations", ["DSSP3", "DSSP8", "disorder", "localization", "membrane"]
)
class BasicAnnotationExtractor(object):
    """Predict secondary structure, disorder and subcellular localization from an embedding.

    Wraps two pre-trained networks: ``SECSTRUCT_CNN`` (DSSP3, DSSP8 and disorder
    per position) and ``SUBCELL_FNN`` (localization and membrane-boundness per protein).
    """

    # Checkpoint options read from ``**kwargs``; any that are missing are fetched
    # with ``get_model_file``.
    necessary_files = ["secondary_structure_checkpoint_file", "subcellular_location_checkpoint_file"]

    # Model types for which an architecture is defined in ``__init__``.
    _supported_model_types = ("seqvec_from_publication", "bert_from_publication")

    def __init__(self, model_type: str, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Initialize annotation extractor. Checkpoint paths are passed as keyword arguments.

        :param model_type: which published models to use; one of
            "seqvec_from_publication" or "bert_from_publication"
        :param device: device the models run on; resolved with ``get_device``
        :param secondary_structure_checkpoint_file: (keyword) path of secondary structure
            inference model checkpoint file; fetched with ``get_model_file`` if omitted
        :param subcellular_location_checkpoint_file: (keyword) path of the subcellular
            location inference model checkpoint file; fetched with ``get_model_file`` if omitted
        :raises NotImplementedError: if ``model_type`` is not supported
        """
        # Validate up front so an unsupported model type fails before any device
        # resolution, download or checkpoint loading happens.
        if model_type not in self._supported_model_types:
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r} (expected one of: "
                f"{', '.join(self._supported_model_types)}). "
                "You first need to define your custom model architecture."
            )

        self._options = kwargs
        self._model_type = model_type
        self._device = get_device(device)

        # Create the un-trained (raw) subcellular location architecture.
        if self._model_type == "seqvec_from_publication":
            self._subcellular_location_model = SUBCELL_FNN().to(self._device)
        else:
            # "bert_from_publication": batch norm is dropped for ProtTrans models.
            self._subcellular_location_model = SUBCELL_FNN(use_batch_norm=False).to(self._device)

        # Fall back to the published checkpoint files (downloaded if needed).
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model=f"{self._model_type}_annotations_extractors", file=file)
        self._secondary_structure_checkpoint_file = self._options['secondary_structure_checkpoint_file']
        self._subcellular_location_checkpoint_file = self._options['subcellular_location_checkpoint_file']

        # Create the un-trained (raw) secondary structure architecture.
        self._secondary_structure_model = SECSTRUCT_CNN().to(self._device)

        # NOTE(security): torch.load unpickles the checkpoint file; only load
        # checkpoints from trusted sources.
        subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
        secondary_structure_state = torch.load(self._secondary_structure_checkpoint_file, map_location=self._device)

        # Load the pre-trained weights into the raw architectures.
        self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
        self._secondary_structure_model.load_state_dict(secondary_structure_state['state_dict'])

        # Evaluation mode is important for batch norm and dropout layers.
        self._subcellular_location_model.eval()
        self._secondary_structure_model.eval()

    def get_subcellular_location(self, raw_embedding: ndarray) -> BasicSubcellularLocalizationResult:
        """Predict the subcellular localization and membrane-boundness of one protein.

        The per-residue embedding is reduced to a single per-protein row on
        ``self._device``. This is similar to ``embedder.reduce_per_protein()``, but
        may run on the GPU. SeqVec embeddings are summed over their first axis (the
        layers) and then averaged over residues. ProtTrans ("bert") embeddings only
        hold the last layer and are averaged over residues.

        :param raw_embedding: embedding of one protein; presumably (layers, L, features)
            for SeqVec and (L, features) for ProtTrans -- TODO confirm against embedders
        :return: predicted ``localization`` and ``membrane`` labels
        :raises NotImplementedError: if the model type is not supported
        """
        embedding = torch.tensor(raw_embedding).to(self._device)
        if self._model_type == "seqvec_from_publication":
            # Sum over layers, then mean over residues -> a single row.
            embedding = embedding.sum(dim=0).mean(dim=0, keepdim=True)
        elif self._model_type == "bert_from_publication":
            # Mean over residues -> a single row.
            embedding = embedding.mean(dim=0, keepdim=True)
        else:
            raise NotImplementedError(f"Unsupported model_type {self._model_type!r}")

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            yhat_loc, yhat_mem = self._subcellular_location_model(embedding)

        # The index of the output node with maximal activation is the predicted class.
        pred_loc = _loc_labels[torch.max(yhat_loc, dim=1)[1].item()]
        pred_mem = _mem_labels[torch.max(yhat_mem, dim=1)[1].item()]
        return BasicSubcellularLocalizationResult(localization=pred_loc, membrane=pred_mem)

    def get_secondary_structure(self, raw_embedding: ndarray) -> BasicSecondaryStructureResult:
        """Predict per-position DSSP3, DSSP8 and disorder labels for one protein.

        The embedding is rearranged so the feature axis comes first, with a leading
        batch axis and a trailing singleton axis: (1, features, L, 1). SeqVec
        embeddings are first summed over their first axis (the layers).

        :param raw_embedding: embedding of one protein; presumably (layers, L, features)
            for SeqVec and (L, features) for ProtTrans -- TODO confirm against embedders
        :return: lists of predicted labels, one entry per position
        :raises NotImplementedError: if the model type is not supported
        """
        embedding = torch.tensor(raw_embedding).to(self._device)
        if self._model_type == "seqvec_from_publication":
            # (layers, L, F) -> (1, L, F) -> (1, F, L) -> (1, F, L, 1)
            embedding = embedding.sum(dim=0, keepdim=True).permute(0, 2, 1).unsqueeze(dim=-1)
        elif self._model_type == "bert_from_publication":
            # (L, F) -> (F, L) -> (1, F, L, 1)
            embedding = embedding.T[None, :, :, None]
        else:
            raise NotImplementedError(f"Unsupported model_type {self._model_type!r}")

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            yhat_dssp3, yhat_dssp8, yhat_disor = self._secondary_structure_model(embedding)

        pred_dssp3 = self._class2label(_dssp3_labels, yhat_dssp3)
        pred_dssp8 = self._class2label(_dssp8_labels, yhat_dssp8)
        pred_disor = self._class2label(_disor_labels, yhat_disor)
        return BasicSecondaryStructureResult(DSSP3=pred_dssp3, DSSP8=pred_dssp8, disorder=pred_disor)

    def get_annotations(self, raw_embedding: ndarray) -> BasicExtractedAnnotations:
        """Run both predictors on ``raw_embedding`` and merge their results into one tuple."""
        secstruct = self.get_secondary_structure(raw_embedding)
        subcell = self.get_subcellular_location(raw_embedding)
        return BasicExtractedAnnotations(disorder=secstruct.disorder, DSSP8=secstruct.DSSP8,
                                         DSSP3=secstruct.DSSP3, localization=subcell.localization,
                                         membrane=subcell.membrane)

    @staticmethod
    def _class2label(label_dict, yhat) -> List[Enum]:
        """Map the arg-max class (over dim 1 of ``yhat``) at every position to its label.

        :param label_dict: mapping from class index to label
        :param yhat: model output with classes on dim 1; every other dimension is
            flattened into the sequence of positions
        :return: one label per position, in order
        """
        # flatten() rather than squeeze(): squeeze() turns a length-1 sequence
        # into a 0-d tensor, which cannot be iterated.
        class_indices = torch.max(yhat, dim=1)[1].flatten().tolist()
        return [label_dict[class_idx] for class_idx in class_indices]