Source code for drugforge.spectrum.fitness

import json

import numpy as np
import pandas as pd
from drugforge.data.metadata.resources import (
    SARS_CoV_2_fitness_data,
    ZIKV_NS2B_NS3pro_fitness_data,
    ZIKV_RdRppro_fitness_data,
    targets_with_fitness_data,
)
from drugforge.data.services.postera.manifold_data_validation import (
    TargetTags,
    TargetVirusMap,
    VirusTags,
)

# Maps a target to the gene label used in a genome-wide fitness result, so the
# rows belonging to that target can be subselected from the full table.
_TARGET_TO_GENE = {
    TargetTags("SARS-CoV-2-Mpro").value: "nsp5 (Mpro)",
    TargetTags("SARS-CoV-2-Mac1").value: "nsp3",
    TargetTags("SARS-CoV-2-N-protein").value: "N",
}

# Maps a target to the vendored fitness data resource that holds its scores.
# Several targets may share one resource (e.g. all SARS-CoV-2 targets).
_TARGET_TO_FITNESS_DATA = {
    TargetTags("SARS-CoV-2-Mpro").value: SARS_CoV_2_fitness_data,
    TargetTags("SARS-CoV-2-Mac1").value: SARS_CoV_2_fitness_data,
    TargetTags("SARS-CoV-2-N-protein").value: SARS_CoV_2_fitness_data,
    TargetTags("ZIKV-NS2B-NS3pro").value: ZIKV_NS2B_NS3pro_fitness_data,
    TargetTags("ZIKV-RdRppro").value: ZIKV_RdRppro_fitness_data,
}

# Records, per virus, whether the fitness data we have covers the whole genome
# (True: rows must be filtered down to the target's gene) or a single target
# (False: no gene subselection is needed).
_FITNESS_DATA_IS_CROSSGENOME = {
    VirusTags("SARS-CoV-2").value: True,
    VirusTags("ZIKV").value: False,
}

# Per-virus fitness threshold at or above which a mutant is considered 'fit'
# for the specific fitness experiment. Directed by Bloom et al.
_FITNESS_DATA_FIT_THRESHOLD = {
    VirusTags("SARS-CoV-2").value: -1.0,
    VirusTags("ZIKV").value: -1.0,  # this is OK for both NS2B-NS3pro and RdRppro
}


def target_has_fitness_data(target: TargetTags) -> bool:
    """
    Report whether ``target`` is listed in ``targets_with_fitness_data``.

    Parameters
    ----------
    target: TargetTags
        Target to look up.

    Returns
    -------
    bool
        Result of the membership test against ``targets_with_fitness_data``.
    """
    is_listed = target in targets_with_fitness_data
    return is_listed
def bloom_abstraction(fitness_scores_this_site: dict, threshold: float) -> int:
    """
    Count the mutations at one site whose fitness reaches ``threshold``.

    This is the abstraction of "how mutable is this residue" prescribed by
    Bloom et al (as of 2023.08.08), replacing the mean fitness used earlier:

    > something like "what is the number of mutations at a site that are
    reasonably well tolerated." You could do this as something like number
    (or fraction) of mutations at a site that have a score >= -1 (that is
    probably a reasonable cutoff), using -1 as a cutoff where mutations start
    to cross from "highly deleterious" to "conceivably tolerated."

    Parameters
    ----------
    fitness_scores_this_site: dict
        Mapping (or any object indexable by ``"fitness"``) holding the fitness
        scores for a single site under the ``"fitness"`` key.
    threshold: float
        Minimum fitness value for a mutation to count as acceptably fit.

    Returns
    -------
    num_tolerated_mutations: int
        Number of scores that are greater than or equal to ``threshold``.
    """
    site_scores = fitness_scores_this_site["fitness"]
    return sum(1 for score in site_scores if score >= threshold)
def apply_bloom_abstraction(fitness_dataframe: pd.DataFrame, threshold: float) -> dict:
    """
    Collapse per-mutation fitness data into one summary entry per residue.

    Takes the per-mutation fitness table for a target and reduces it to a
    single record per (site, chain) using the currently recommended Bloom
    abstraction (count of tolerated mutations, see ``bloom_abstraction``).
    This function can be extended when the recommendation changes.

    Parameters
    ----------
    fitness_dataframe: pd.DataFrame
        DataFrame containing columns
        [site, chain, mutant, fitness, wildtype] and optionally
        ``expected_count``. The caller's DataFrame is not modified.
    threshold: float
        Fitness value to use as minimum value threshold to treat a mutation
        as acceptably fit.

    Returns
    -------
    fitness_dict : dict
        Dictionary keyed by ``"<site>_<chain>"``. Each value is a list of:
        [
            number of tolerated mutations,
            wildtype residue,
            most fit mutation,
            least fit mutation,
            total expected count (~confidence)
        ]
        Sites that have no rows left after dropping zero-fitness rows are
        omitted.
    """
    # Add this column in case we're pulling in an experiment that has different
    # data. We need to find a good way of dealing with all this data coming from
    # different labs. See Issue #649.
    if "expected_count" not in fitness_dataframe.columns:
        # assign() returns a new frame, so the caller's DataFrame stays untouched.
        fitness_dataframe = fitness_dataframe.assign(expected_count=0)

    fitness_dict = {}
    for (idx, chain), site_df in fitness_dataframe.groupby(by=["site", "chain"]):
        # Remove the wild type fitness score (this is always 0).
        mutant_rows = site_df[site_df["fitness"] != 0]
        if mutant_rows.empty:
            # Nothing but zero-fitness rows at this site: there is no mutation to
            # summarise, so skip it rather than fail on an empty lookup below.
            continue

        # Sort once; the extremes give the least and most fit mutations.
        mutants_by_fitness = mutant_rows.sort_values(by="fitness")["mutant"].values

        fitness_dict[f"{idx}_{chain}"] = [
            bloom_abstraction(mutant_rows, threshold),
            mutant_rows["wildtype"].values[0],  # wildtype residue
            mutants_by_fitness[-1],  # most fit mutation
            mutants_by_fitness[0],  # least fit mutation
            np.sum(mutant_rows["expected_count"].values),  # total count
        ]
    return fitness_dict
def _min_max_scale(column: pd.Series) -> pd.Series:
    """Min-max scale ``column`` to 0-1; a constant column maps to 0 instead of dividing by zero."""
    lowest = column.min()
    span = column.max() - lowest
    shifted = column - lowest
    if not span > 0:
        # Constant (or empty / all-NaN) column: there is no range to scale by.
        return shifted * 0.0
    return shifted / span


def normalize_fitness(
    fitness_df_abstract: pd.DataFrame, *, normalize: bool = False
) -> pd.DataFrame:
    """
    Optionally min-max normalize per-residue fitness and confidence values.

    By default this is a pass-through: the input is returned unchanged, which
    keeps fitness categorical from 1 to n, where n = number of fit mutants.
    Pass ``normalize=True`` to reactivate MinMax scaling:
     - fitness: 0-100 ranges from non-fit to most fit.
     - confidence: 0-1 ranges from least confident to most confident.

    Parameters
    ----------
    fitness_df_abstract: pd.DataFrame
        Dataframe containing per-residue fitness data. Must have ``fitness``
        and ``confidence`` columns when ``normalize`` is True.
    normalize: bool, default False
        When False, return ``fitness_df_abstract`` itself, unmodified.
        When True, return a scaled copy; the input is left untouched.

    Returns
    -------
    fitness_df_abstract: pd.DataFrame
        The input DataFrame (default) or a normalized copy of it.
    """
    if not normalize:
        return fitness_df_abstract

    scaled = fitness_df_abstract.copy()
    scaled["fitness"] = _min_max_scale(scaled["fitness"]) * 100
    scaled["confidence"] = _min_max_scale(scaled["confidence"])
    return scaled
def parse_fitness_json(target: TargetTags) -> dict:
    """
    Read the per-aa fitness JSON for the specified target and reduce it to a
    single fitness value per residue.

    Parameters
    ----------
    target: str
        Specifies the target and virus, conforming to
        drugforge.data.postera.manifold_data_validation.TargetTags

    Returns
    -------
    fitness_by_residue : dict
        Dictionary mapping each residue key (``"<site>_<chain>"``) to its
        fitness value, i.e. the number of mutations at that residue whose
        fitness reaches the virus-specific threshold.

    Raises
    ------
    ValueError
        If ``target`` is not one of ``TargetTags.get_values()``.
    NotImplementedError
        If no fitness data is available for ``target``.
    """
    valid_targets = TargetTags.get_values()
    if target not in valid_targets:
        raise ValueError(
            f"Specified target is not valid, must be one of: {valid_targets}"
        )
    if not target_has_fitness_data(target):
        raise NotImplementedError(
            f"Fitness data not yet available for {target}. Add to metadata if/when available."
        )

    fitness_scores_bloom = get_fitness_scores_bloom_by_target(target)
    threshold = _FITNESS_DATA_FIT_THRESHOLD[TargetVirusMap[target]]

    # Apply the abstraction currently recommended by Bloom et al to get to a
    # single value per residue.
    fitness_dict_abstract = apply_bloom_abstraction(fitness_scores_bloom, threshold)

    fitness_df_abstract = pd.DataFrame.from_dict(
        fitness_dict_abstract,
        orient="index",
        columns=[
            "fitness",
            "wildtype_residue",
            "most_fit_mutation",
            "least_fit_mutation",
            "confidence",
        ],
    )
    fitness_df_abstract.index.name = "residue"

    # Hook for normalizing fitness/confidence for visualizers downstream.
    fitness_df_abstract = normalize_fitness(fitness_df_abstract)

    # Only the fitness column is returned. Return the DataFrame instead if we
    # ever need to provide more info (top/worst mutation, confidence etc).
    return dict(zip(fitness_df_abstract.index, fitness_df_abstract["fitness"]))
def _select_zikv_subunit(
    fitness_scores: pd.DataFrame,
    label: str,
    first_site: int,
    last_site: int,
    chain: str,
) -> pd.DataFrame:
    """Extract one subunit's rows from the ZIKV NS2B-NS3 table, renumber its sites and tag its chain."""
    # Sites look like "<number>(<label>)"; keep the rows of this subunit only.
    subunit = fitness_scores[fitness_scores["site"].str.contains(label)].copy()
    subunit["site"] = [
        int(site.replace(f"({label})", "")) for site in subunit["site"].values
    ]
    subunit = subunit[subunit["site"].between(first_site, last_site)].copy()
    subunit["chain"] = chain
    return subunit


def get_fitness_scores_bloom_by_target(target: TargetTags) -> pd.DataFrame:
    """
    Load the per-mutation fitness table for ``target`` and tag it with chain IDs.

    Reads the fitness JSON registered for the target, subselects the target's
    gene when the data is cross-genome, and applies target-specific
    post-processing (site range selection, renumbering, chain assignment).

    Parameters
    ----------
    target: str
        Specifies the target and virus, conforming to
        drugforge.data.postera.manifold_data_validation.TargetTags

    Returns
    -------
    fitness_scores_bloom : pd.DataFrame
        Per-mutation fitness rows for the target. For the targets handled
        explicitly below a ``chain`` column is added.

    Raises
    ------
    ValueError
        If the fitness JSON has none of the recognised top-level keys.
    """
    # find the virus that corresponds to the target
    virus = TargetVirusMap[target]
    # find the fitness data that corresponds to the target
    fitness_data = _TARGET_TO_FITNESS_DATA[target]

    # read the fitness data into a dataframe
    with open(fitness_data, encoding="utf-8") as f:
        data = json.load(f)

    if "data" in data:
        # this is SARS-CoV-2 cross-genome phylo data - can directly grab 'data' key from json.
        fitness_scores_bloom = pd.DataFrame(data["data"])
    elif "ZIKV NS2B-NS3 (Closed)" in data:
        # this is ZIKV NS2B-NS3 DMS data - need to grab data differently from json.
        fitness_scores_bloom = pd.DataFrame(
            data["ZIKV NS2B-NS3 (Closed)"]["mut_metric_df"]
        ).rename(columns={"reference_site": "site", "Log2(Effect)": "fitness"})
    else:
        # Fail with an explicit message instead of an unbound-variable error below.
        raise ValueError(
            f"Unrecognised fitness data layout in {fitness_data}: expected a "
            "'data' or 'ZIKV NS2B-NS3 (Closed)' top-level key."
        )

    if _FITNESS_DATA_IS_CROSSGENOME[virus]:
        # now get the target-specific entries. Need to do because the phylo data is cross-genome.
        # copy() so the column edits below act on an independent frame, not a slice.
        fitness_scores_bloom = fitness_scores_bloom[
            fitness_scores_bloom["gene"] == _TARGET_TO_GENE[target]
        ].copy()

    # post-processing for specific targets
    # TODO: replace all the below by a more intelligent + robust alignment algorithm.
    if target == "SARS-CoV-2-Mac1":
        # need to subselect from nsp3 multidomain to get just Mac1.
        # See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7113668/
        fitness_scores_bloom = fitness_scores_bloom[
            fitness_scores_bloom["site"].between(209, 372)
        ].copy()
        fitness_scores_bloom["site"] -= 204  # PDB starts at resindex 5
        fitness_scores_bloom["chain"] = "A"
    elif target == "SARS-CoV-2-Mpro":
        fitness_scores_bloom = fitness_scores_bloom.assign(chain="A")
    elif target == "SARS-CoV-2-N-protein":
        # For N-protein, we want to show both monomers in the dimer because they are
        # inter-locked and ligands may bind the interface, so can't get away with
        # showing one of the monomers as blue. We make separate rows in the data for
        # each chain: one full copy tagged "A" followed by one full copy tagged "C".
        # Tagging each copy before concatenating guarantees equal chain lengths.
        fitness_scores_bloom = pd.concat(
            [
                fitness_scores_bloom.assign(chain="A"),
                fitness_scores_bloom.assign(chain="C"),
            ]
        ).reset_index()
    elif target == "ZIKV-NS2B-NS3pro":
        # TODO: replace this with an auto-align.
        # tag NS2B as chain A and NS3 as chain B, then add back together and
        # treat as normal downstream.
        ns2b_section = _select_zikv_subunit(fitness_scores_bloom, "NS2B", 46, 89, "A")
        ns3_section = _select_zikv_subunit(fitness_scores_bloom, "NS3", 10, 177, "B")
        fitness_scores_bloom = pd.concat([ns2b_section, ns3_section])

    return fitness_scores_bloom