Source code for bio_embeddings.mutagenesis.pipeline
import math
from copy import deepcopy
from typing import List, Dict, Type

import torch
from Bio import SeqIO
from pandas import DataFrame
from tqdm import tqdm

from bio_embeddings.mutagenesis.constants import PROBABILITIES_COLUMNS

try:
    from bio_embeddings.mutagenesis.protbert_bfd import ProtTransBertBFDMutagenesis
except ImportError as e:
    # "e" is cleared when the except block ends, so keep a reference for the
    # stub's delayed error chaining.
    _IMPORT_ERROR = e

    class ProtTransBertBFDMutagenesis:
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "The 'transformers' extra required for protbert_bfd_mutagenesis is missing. "
                "See https://docs.bioembeddings.com/#installation on how to install all extras"
            ) from _IMPORT_ERROR

from bio_embeddings.utilities import (
    check_required,
    get_device,
    get_file_manager,
    read_mapping_file,
)

# list of available mutagenesis protocols
_PROTOCOLS = {
    "protbert_bfd_mutagenesis": ProtTransBertBFDMutagenesis,
}
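# Illustrative sketch (not part of the original module): protocols are resolved
# by name through the _PROTOCOLS registry above, mirroring the lookup that
# run() performs below. The helper name is hypothetical.
def _resolve_protocol(name: str) -> Type[ProtTransBertBFDMutagenesis]:
    if name not in _PROTOCOLS:
        raise RuntimeError(f"Unknown mutagenesis protocol: {name}")
    return _PROTOCOLS[name]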
def probabilities_as_dataframe(
    mapping_file: DataFrame,
    probabilities_all: Dict[str, List[Dict[str, float]]],
    sequences: List[str],
) -> DataFrame:
    """Flatten the per-sequence, per-position probabilities into a single DataFrame for CSV export."""
    records = []
    for sequence, (sequence_id, probabilities) in zip(sequences, probabilities_all.items()):
        for wild_type_amino_acid, position_probabilities in zip(sequence, probabilities):
            records.append(
                {
                    "id": sequence_id,
                    "original_id": mapping_file.loc[sequence_id]["original_id"],
                    "wild_type_amino_acid": wild_type_amino_acid,
                    **position_probabilities,
                }
            )
    return DataFrame(records, columns=PROBABILITIES_COLUMNS)
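# Illustrative sketch (not part of the original module): toy inputs showing the
# shapes probabilities_as_dataframe expects. The per-position dictionaries here
# (a "position" key plus per-amino-acid probabilities), the toy ids, and the
# function name below are assumptions for demonstration; real values come from
# the model inside run().
def _example_probabilities_as_dataframe() -> DataFrame:
    mapping_file = DataFrame(
        {"original_id": ["my_protein"], "sequence_length": [2]},
        index=["0"],  # remapped sequence id used by the pipeline
    )
    probabilities_all = {
        "0": [
            {"position": 1, "A": 0.7, "C": 0.3},
            {"position": 2, "A": 0.2, "C": 0.8},
        ]
    }
    sequences = ["AC"]
    # Columns listed in PROBABILITIES_COLUMNS but absent from the toy
    # dictionaries end up as NaN in the resulting DataFrame.
    return probabilities_as_dataframe(mapping_file, probabilities_all, sequences)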
def run(**kwargs):
    """BETA: in-silico mutagenesis using BertForMaskedLM

    Optional parameters (see the extract stage for details):

    * model_directory
    * device
    * half_precision
    * half_precision_model
    * temperature: temperature for the softmax over amino acid probabilities
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)
    result_kwargs = deepcopy(kwargs)

    if result_kwargs["protocol"] not in _PROTOCOLS:
        raise RuntimeError(
            f"Passed protocol {result_kwargs['protocol']}, "
            f"but allowed are: {', '.join(_PROTOCOLS)}"
        )

    temperature = result_kwargs.setdefault("temperature", 1)
    device = get_device(result_kwargs.get("device"))
    model_class: Type[ProtTransBertBFDMutagenesis] = _PROTOCOLS[result_kwargs["protocol"]]
    model = model_class(
        device,
        result_kwargs.get("model_directory"),
        result_kwargs.get("half_precision_model"),
    )

    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    # The mapping file contains the corresponding ids in the same order
    sequences = [
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    ]
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    probabilities_all = dict()
    with tqdm(total=int(mapping_file["sequence_length"].sum())) as progress_bar:
        for sequence_id, original_id, sequence in zip(
            mapping_file.index, mapping_file["original_id"], sequences
        ):
            with torch.no_grad():
                probabilities = model.get_sequence_probabilities(
                    sequence, temperature, progress_bar=progress_bar
                )
            # Sanity check: per position, the amino acid probabilities (everything
            # except the "position" entry) must form a valid softmax distribution.
            for p in probabilities:
                assert math.isclose(
                    1, sum(p.values()) - p["position"], rel_tol=1e-6
                ), "softmax values should add up to 1"
            probabilities_all[sequence_id] = probabilities

    residue_probabilities = probabilities_as_dataframe(
        mapping_file, probabilities_all, sequences
    )

    probabilities_file = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "residue_probabilities_file",
        extension=".csv",
    )
    residue_probabilities.to_csv(probabilities_file, index=False)
    result_kwargs["residue_probabilities_file"] = probabilities_file

    return result_kwargs
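# Illustrative sketch (not part of the original module): invoking this stage
# directly. The prefix, stage name, and file paths are hypothetical; in normal
# use the bio_embeddings pipeline runner supplies these kwargs from the YAML
# config and the preceding stages.
if __name__ == "__main__":
    stage_result = run(
        protocol="protbert_bfd_mutagenesis",
        prefix="mutagenesis_run",                # hypothetical output prefix
        stage_name="protbert_bfd_mutagenesis",   # hypothetical stage name
        remapped_sequences_file="mutagenesis_run/remapped_sequences_file.fasta",  # hypothetical path
        mapping_file="mutagenesis_run/mapping_file.csv",                          # hypothetical path
        temperature=1,
    )
    print("Residue probabilities written to", stage_result["residue_probabilities_file"])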