class UniRepEmbedder(EmbedderInterface):
    """UniRep Embedder

    Alley, E.C., Khimulya, G., Biswas, S. et al. Unified rational protein
    engineering with sequence-based deep representation learning.
    Nat Methods 16, 1315–1322 (2019).
    https://doi.org/10.1038/s41592-019-0598-1

    We use a reimplementation of unirep:

    Ma, Eric, and Arkadij Kummer. "Reimplementing Unirep in JAX."
    bioRxiv (2020). https://doi.org/10.1101/2020.05.11.088344
    """

    name = "unirep"
    # An integer representing the size of the embedding.
    embedding_dimension = 1900
    # An integer representing the number of layers from the RAW output of the LM.
    number_of_layers = 1

    # Weights returned by `jax_unirep.utils.load_params_1900()`.
    _params: Dict[str, Any]
    # Forward function imported from `jax_unirep.featurize`; it is called with
    # the parameters first and the embedded sequence second (see `embed`).
    _apply_fun: Callable

    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
        """Load the UniRep weights and forward function.

        Args:
            device: Must be left unset (``None``); any truthy value is
                rejected because the device cannot be configured here.
            **kwargs: Forwarded unchanged to ``EmbedderInterface.__init__``.

        Raises:
            NotImplementedError: If a truthy ``device`` is passed.
        """
        # Validate first so that an unsupported configuration fails fast,
        # before the model weights are loaded.
        if device:
            raise NotImplementedError("UniRep does not allow configuring the device")

        # Imported lazily so that jax_unirep is only required when this
        # embedder is actually instantiated.
        from jax_unirep.featurize import apply_fun
        from jax_unirep.utils import load_params_1900

        self._params = load_params_1900()
        self._apply_fun = apply_fun

        # For v2
        # https://github.com/ElArkk/jax-unirep/issues/107
        # from jax_unirep.utils import load_params
        # from jax_unirep.layers import mLSTM
        # from jax_unirep.utils import validate_mLSTM_params
        # self._params = load_params()[1]
        # _, self._apply_fun = mLSTM(output_dim=self.embedding_dimension)
        # validate_mLSTM_params(self._params, n_outputs=self.embedding_dimension)

        super().__init__(device, **kwargs)

    def embed(self, sequence: str) -> ndarray:
        """Compute the per-residue UniRep embedding of one sequence.

        Args:
            sequence: The sequence to embed. May be empty.

        Returns:
            An array of shape ``(len(sequence), embedding_dimension)``
            holding all hidden states; for an empty sequence this is an
            all-zero array of shape ``(0, embedding_dimension)``.
        """
        # `jax.partial` was only an accidental re-export of
        # `functools.partial` and is gone from current JAX releases, so take
        # it from the standard library instead.
        from functools import partial

        from jax import vmap
        from jax_unirep.utils import get_embeddings

        # https://github.com/sacdallago/bio_embeddings/issues/117
        if not sequence:
            return numpy.zeros((0, self.embedding_dimension))

        # Unirep only allows batching with sequences of the same length, so we
        # don't do batching at all
        embedded_seqs = get_embeddings([sequence])

        # The forward function returns (h_final, c_final, h): h and c refer to
        # hidden and cell state. h contains all the hidden states, while
        # h_final and c_final contain only the last state, which is not
        # needed here.
        forward = vmap(partial(self._apply_fun, self._params))
        _, _, hidden_states = forward(embedded_seqs)

        # Go from a batch of 1, which is `(1, len(sequence), 1900)`, to
        # `(len(sequence), 1900)`
        return numpy.asarray(hidden_states[0])

    @staticmethod
    def reduce_per_protein(embedding: ndarray) -> ndarray:
        """Reduce a per-residue embedding to a single per-protein vector.

        Args:
            embedding: Per-residue embedding as returned by `embed`.

        Returns:
            The mean of ``embedding`` over axis 0.
        """
        # This is `h_avg` in jax-unirep terminology
        # NOTE(review): for the empty `(0, 1900)` array that `embed("")`
        # returns, a numpy mean over axis 0 should produce NaNs — confirm that
        # callers handle or avoid this case.
        return embedding.mean(axis=0)