import warnings
from typing import List, Generator, Union, Iterable, Optional, Any, Tuple
from itertools import tee

import torch
from esm.pretrained import load_model_and_alphabet_core
from numpy import ndarray

from bio_embeddings.embed import EmbedderInterface
from bio_embeddings.utilities import get_model_file


def load_model_and_alphabet_local(model_location: str) -> Tuple[Any, Any]:
    """Load an ESM model and its alphabet from a local checkpoint file.

    Custom bio_embeddings version because we change file names and don't
    ship the regression (contact-prediction) weights.

    :param model_location: path to the torch checkpoint with the model weights
    :return: a ``(model, alphabet)`` tuple as produced by
        ``esm.pretrained.load_model_and_alphabet_core``
    """
    # We don't predict contacts, so silence ESM's warning about the
    # missing regression weights.
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="Regression weights not found, predicting contacts will not produce correct results.",
    )
    model_data = torch.load(model_location, map_location="cpu")
    return load_model_and_alphabet_core(model_data, None)


class ESMEmbedderBase(EmbedderInterface):
    # The only thing subclasses need to overwrite is the name and _picked_layer
    embedding_dimension = 1280
    number_of_layers = 1  # Following ESM, we only consider layer 34 (ESM) or 33 (ESM1b)
    necessary_files = ["model_file"]
    # https://github.com/facebookresearch/esm/issues/49#issuecomment-803110092
    max_len = 1022

    # Index of the transformer layer whose representations are returned.
    _picked_layer: int

    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
        """Load the checkpoint given by the ``model_file`` option and move it to ``device``."""
        super().__init__(device, **kwargs)
        model, alphabet = load_model_and_alphabet_local(self._options["model_file"])
        self._model = model.to(self._device)
        self._batch_converter = alphabet.get_batch_converter()

    def embed(self, sequence: str) -> ndarray:
        """Embed a single sequence, returning its per-residue embedding matrix."""
        [embedding] = self.embed_batch([sequence])
        return embedding

    def embed_batch(self, batch: List[str]) -> Generator[ndarray, None, None]:
        """Embed a batch of sequences in one forward pass, yielding per-residue embeddings.

        https://github.com/facebookresearch/esm/blob/dfa524df54f91ef45b3919a00aaa9c33f3356085/README.md#quick-start-
        """
        # tee so we can validate lengths without consuming the caller's iterable
        batch, batch_copy = tee(batch)
        self._assert_max_len(batch_copy)

        data = [(str(pos), sequence) for pos, sequence in enumerate(batch)]
        batch_labels, batch_strs, batch_tokens = self._batch_converter(data)
        with torch.no_grad():
            results = self._model(
                batch_tokens.to(self._device), repr_layers=[self._picked_layer]
            )
        token_embeddings = results["representations"][self._picked_layer]

        # NOTE: token 0 is always a beginning-of-sequence token, so the first
        # residue is token 1; slice out exactly the residue tokens (padding
        # beyond len(seq) is dropped).
        for i, (_, seq) in enumerate(data):
            yield token_embeddings[i, 1 : len(seq) + 1].cpu().numpy()

    def embed_many(
        self, sequences: Iterable[str], batch_size: Optional[int] = None
    ) -> Generator[ndarray, None, None]:
        """Embed many sequences, validating all lengths up front."""
        sequences, sequences_copy = tee(sequences)
        self._assert_max_len(sequences_copy)
        yield from super().embed_many(sequences, batch_size)

    def _assert_max_len(self, sequences: Iterable[str]):
        """Raise a ValueError if any sequence exceeds the model's maximum length.

        :raises ValueError: when the longest sequence is longer than ``max_len``
        """
        max_len = max((len(i) for i in sequences), default=0)
        if max_len > self.max_len:
            raise ValueError(
                f"{self.name} only allows sequences up to {self.max_len} residues, "
                # Bug fix: report the actual offending length; the message
                # previously hard-coded "8,924 residues" with no placeholder.
                f"but your longest sequence is {max_len} residues long"
            )

    @staticmethod
    def reduce_per_protein(embedding: ndarray) -> ndarray:
        """Reduce a per-residue embedding matrix to one per-protein vector by mean-pooling."""
        return embedding.mean(0)
[docs]classESMEmbedder(ESMEmbedderBase):"""ESM Embedder (Note: This is not ESM-1b) Rives, Alexander, et al. "Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences." Proceedings of the National Academy of Sciences 118.15 (2021). https://doi.org/10.1073/pnas.2016239118 """name="esm"_picked_layer=34
[docs]classESM1bEmbedder(ESMEmbedderBase):"""ESM-1b Embedder (Note: This is not the original ESM) Rives, Alexander, et al. "Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences." Proceedings of the National Academy of Sciences 118.15 (2021). https://doi.org/10.1073/pnas.2016239118 """name="esm1b"_picked_layer=33
class ESM1vEmbedder(ESMEmbedderBase):
    """ESM-1v Embedder (one of five)

    ESM1v uses an ensemble of five models, called `esm1v_t33_650M_UR90S_[1-5]`.
    An instance of this class is one of the five, specified by `ensemble_id`.

    Meier, Joshua, et al. "Language models enable zero-shot prediction of the
    effects of mutations on protein function." bioRxiv (2021).
    https://doi.org/10.1101/2021.07.09.450648
    """

    name = "esm1v"
    # Which of the five ensemble members (1-5) this instance represents.
    ensemble_id: int
    _picked_layer = 33

    def __init__(
        self,
        ensemble_id: int,
        device: Union[None, str, torch.device] = None,
        **kwargs,
    ):
        """You must pass the number of the model (1-5) as first parameter,
        though you can override the weights file with model_file.

        :param ensemble_id: which of the five ESM-1v models to load (1-5)
        :param device: torch device to load the model onto
        :raises ValueError: if ``ensemble_id`` is not in 1-5
        """
        # Validate with a real exception: the original ``assert`` would be
        # silently stripped when running under ``python -O``.
        if ensemble_id not in range(1, 6):
            raise ValueError("The model number must be in 1-5")
        self.ensemble_id = ensemble_id
        # EmbedderInterface assumes static model files, but we need to
        # dynamically select one of the five
        if "model_file" not in kwargs:
            kwargs["model_file"] = get_model_file(
                model=self.name, file=f"model_{ensemble_id}_file"
            )
        super().__init__(device, **kwargs)