Skip to content

Commit

Permalink
update requirements and docker image
Browse files: browse the repository at this point in the history
  • Loading branch information
YojanaGadiya committed Jun 24, 2024
1 parent 0733831 commit 2b09cd7
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 186 deletions.
3 changes: 1 addition & 2 deletions AMR-KG_Database.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -86,8 +86,7 @@
divider="orange",
help="Stats on the underlying data.",
)
DATA_DIR = "../data"
df = pd.read_csv(f"{DATA_DIR}/processed/combined_bioassay_data.tsv", sep="\t")
df = pd.read_csv("data/processed/combined_bioassay_data.tsv", sep="\t")


def get_base_stats():
Expand Down
15 changes: 7 additions & 8 deletions Dockerfile
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
# Select base image (can be ubuntu, python, shiny etc)
FROM python:3.12-slim
FROM python:3.10-slim

# Create user name and home directory variables.
# The variables are later used as $USER and $HOME.
Expand All @@ -22,17 +22,16 @@ COPY requirements.txt $HOME/kg/requirements.txt
COPY pages $HOME/kg/pages/
COPY AMR-KG_Database.py $HOME/kg/AMR-KG_Database.py
COPY data $HOME/kg/data/
COPY models/ $HOME/kg/models/
COPY start-script.sh $HOME/kg/start-script.sh
RUN mkdir $HOME/kg/models

RUN apt-get install curl -y
RUN curl https://zenodo.org/api/records/3407840/files/chem_phys_rf.pkl -o $HOME/kg/models/chem_phys_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/ecfp_rf.pkl -o $HOME/kg/models/ecfp_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/erg_rf.pkl -o $HOME/kg/models/erg_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/maccs_rf.pkl -o $HOME/kg/models/maccs_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/maccs_xgboost.pickle.dat -o $HOME/kg/models/maccs_xgboost.pickle.dat -L
# RUN curl https://zenodo.org/api/records/3407840/files/chem_phys_rf.pkl -o $HOME/kg/models/chem_phys_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/ecfp_rf.pkl -o $HOME/kg/models/ecfp_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/erg_rf.pkl -o $HOME/kg/models/erg_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/maccs_rf.pkl -o $HOME/kg/models/maccs_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/mhfp6_rf.pkl -o $HOME/kg/models/mhfp6_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/rdkit_rf.pkl -o $HOME/kg/models/rdkit_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/rdkit_rf.pkl -o $HOME/kg/models/rdkit_rf.pkl -L

RUN pip install --no-cache-dir -r requirements.txt \
&& chmod +x start-script.sh \
Expand Down
5 changes: 2 additions & 3 deletions pages/1_Chemical_Space_Exploration.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,7 +53,7 @@
You can zoom in, pan, and hover over the points to view the compound structure."""
)

HtmlFile = open("./amrkg_chemspace.html", "r", encoding="utf-8")
HtmlFile = open("amrkg_chemspace.html", "r", encoding="utf-8")
source_code = HtmlFile.read()
components.html(source_code, height=500, scrolling=True)

Expand All @@ -69,8 +69,7 @@
help="Look for sub-structures in the database based in InChI keys.",
)

DATA_DIR = "../data"
df = pd.read_csv(f"{DATA_DIR}/processed/combined_bioassay_data.tsv", sep="\t")
df = pd.read_csv("data/processed/combined_bioassay_data.tsv", sep="\t")
df["scaffold_inchikey"] = df["compound_inchikey"].str.split("-").str[0]

user_smiles = st.text_input(
Expand Down
199 changes: 28 additions & 171 deletions pages/2_Model_Prediction.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,215 +57,71 @@
)


# Model selection
st.markdown("### Select the model to use for prediction:")

cols = st.columns(2)
with cols[0]:
model = st.radio(
"#### Model",
("Random Forest", "XGBoost"),
index=0,
help="Select the model to use for prediction",
horizontal=True,
)

st.write("#### Select the fingerprint:")
fingerprint = st.radio(
"Fingerprint",
("MHFP6", "ECFP4", "RDKIT", "MACCS", "ErG", "ChemPhys"),
index=0,
help="Select the fingerprint representation to use for prediction",
)

st.write("**The best model combination is Random Forest with MHFP6 fingerprint.**")

with cols[1]:

metric_df = pd.read_csv("../data/test_metrics.tsv", sep="\t").reset_index(drop=True)
metric_df.rename(columns={"Unnamed: 0": "model_name"}, inplace=True)
metric_df["model"] = metric_df["model_name"].apply(
lambda x: (
x.split("_")[1].upper()
if len(x.split("_")) < 3
else x.split("_")[-1].upper()
)
)
metric_df["fingerprints"] = metric_df["model_name"].apply(
lambda x: (
x.split("_")[0].upper()
if len(x.split("_")) < 3
else x.split("_")[0].upper() + "_" + x.split("_")[1].upper()
)
)
metric_df["fingerprints"] = metric_df["fingerprints"].replace(
{
"ERG": "ErG",
"CHEM_PHYS": "ChemPhys",
}
)

colors = {
"MHFP6": "#3a2c20",
"ECFP4": "#b65c11",
"RDKIT": "#e7a504",
"MACCS": "#719842",
"ErG": "#3d8ebf",
"ChemPhys": "#901b1b",
"RF": "#3a2c20",
"XGBOOST": "#719842",
}
st.write(
"**The best model combination is Random Forest with MHFP6 fingerprint. \
By default, this is used for prediction.**"
)

metric_df["accuracy"] = metric_df["accuracy"] * 100
plt.figure(figsize=(5, 5))
sns.violinplot(x="model", y="accuracy", data=metric_df, palette=colors)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel("Model", fontsize=15)
plt.ylabel("Accuracy", fontsize=15)
st.pyplot(plt)
st.header(
"🧠 Prediction results",
divider="orange",
help="Results of the model prediction.",
)

if st.button("Predict"):
if uploaded_file is not None:
smiles_df = pd.read_csv(uploaded_file, header=None)
else:
smiles_df = pd.DataFrame(text_input.split("\n"), columns=["smiles"])

if model == "Random Forest":
model_name = "rf"
else:
model_name = "xgboost"

if fingerprint == "MHFP6":
fingerprint_name = "mhfp6"
elif fingerprint == "ECFP4":
fingerprint_name = "ecfp4"
elif fingerprint == "RDKIT":
fingerprint_name = "rdkit"
elif fingerprint == "MACCS":
fingerprint_name = "maccs"
elif fingerprint == "ErG":
fingerprint_name = "erg"
else:
model = "chem_phys"

logger.info("⏳ Loading models")
model_name = "rf"
fingerprint_name = "mhfp6"

if model_name == "rf":
model = torch.load(f"../models/{fingerprint_name}_{model_name}.pkl")
else:
model = pickle.load(
open(f"../models/{fingerprint_name}_{model_name}.pickle.dat", "rb")
)
model = torch.load(f"./models/{fingerprint_name}_{model_name}.pkl")

logger.info("🔮 Processing SMILES to fingerprints")
# warnings.simplefilter("error", InconsistentVersionWarning)

mfpgen = rdFingerprintGenerator.GetMorganGenerator(
radius=4, fpSize=1024
) # ECFP4 fingerprint
rdkgen = rdFingerprintGenerator.GetRDKitFPGenerator(
fpSize=1024
) # RDKit fingerprint
mhfp_encoder = MHFPEncoder(n_permutations=2048, seed=42) # MHFP6 fingerprint

ecfp_fingerprints = []
rdkit_fingerprints = []
maccs_fingerprints = []
mhfp6_fingerprints = []
erg_fingerprints = []
chem_phys_props = []

if smiles_df.empty:
st.write("No SMILES provided.")
st.stop()

for smiles in smiles_df["smiles"].values:
# Canonicalize the smiles
try:
can_smiles = CanonSmiles(smiles)
except Exception as e:
can_smiles = smiles
can_smiles = mhfp6_fingerprints.append(None)
continue

# Generate the mol object
mol = MolFromSmiles(can_smiles)

if not mol:
ecfp_fingerprints.append(None)
rdkit_fingerprints.append(None)
maccs_fingerprints.append(None)
chem_phys_props.append(None)
mhfp_encoder.append(None)
erg_fingerprints.append(None)
mhfp6_fingerprints.append(None)
continue

ecfp_fingerprints.append(mfpgen.GetFingerprint(mol))
rdkit_fingerprints.append(rdkgen.GetFingerprint(mol))
maccs_fingerprints.append(MACCSkeys.GenMACCSKeys(mol))
mhfp6_fingerprints.append(mhfp_encoder.encode(can_smiles, radius=3))
erg_fingerprints.append(rdReducedGraphs.GetErGFingerprint(mol))

vals = Descriptors.CalcMolDescriptors(mol)

chem_phys_props.append(
{
"slogp": round(vals["MolLogP"], 2),
"smr": round(vals["MolMR"], 2),
"labute_asa": round(vals["LabuteASA"], 2),
"tpsa": round(vals["TPSA"], 2),
"exact_mw": round(vals["ExactMolWt"], 2),
"num_lipinski_hba": rdMolDescriptors.CalcNumLipinskiHBA(mol),
"num_lipinski_hbd": rdMolDescriptors.CalcNumLipinskiHBD(mol),
"num_rotatable_bonds": vals["NumRotatableBonds"],
"num_hba": vals["NumHAcceptors"],
"num_hbd": vals["NumHDonors"],
"num_amide_bonds": rdMolDescriptors.CalcNumAmideBonds(mol),
"num_heteroatoms": vals["NumHeteroatoms"],
"num_heavy_atoms": vals["HeavyAtomCount"],
"num_atoms": rdMolDescriptors.CalcNumAtoms(mol),
"num_stereocenters": rdMolDescriptors.CalcNumAtomStereoCenters(mol),
"num_unspecified_stereocenters": rdMolDescriptors.CalcNumUnspecifiedAtomStereoCenters(
mol
),
"num_rings": vals["RingCount"],
"num_aromatic_rings": vals["NumAromaticRings"],
"num_aliphatic_rings": vals["NumAliphaticRings"],
"num_saturated_rings": vals["NumSaturatedRings"],
"num_aromatic_heterocycles": vals["NumAromaticHeterocycles"],
"num_aliphatic_heterocycles": vals["NumAliphaticHeterocycles"],
"num_saturated_heterocycles": vals["NumSaturatedHeterocycles"],
"num_aromatic_carbocycles": vals["NumAromaticCarbocycles"],
"num_aliphatic_carbocycles": vals["NumAliphaticCarbocycles"],
"num_saturated_carbocycles": vals["NumSaturatedCarbocycles"],
"fraction_csp3": round(vals["FractionCSP3"], 2),
"num_brdigehead_atoms": rdMolDescriptors.CalcNumBridgeheadAtoms(mol),
"bertz_complexity": GraphDescriptors.BertzCT(mol),
}
)

smiles_df["ecfp4"] = ecfp_fingerprints
smiles_df["rdkit"] = rdkit_fingerprints
smiles_df["maccs"] = maccs_fingerprints
smiles_df["chem_phys"] = chem_phys_props
smiles_df["mhfp6"] = mhfp6_fingerprints
smiles_df["erg"] = erg_fingerprints

logger.info("🏃 Running model")

smiles_df_subset = smiles_df.dropna(subset=[fingerprint_name])[
["smiles", fingerprint_name]
]
if model_name == "rf":
predictions = model.predict(smiles_df_subset[fingerprint_name].tolist())
prediction_proba = model.predict_proba(
smiles_df_subset[fingerprint_name].tolist()
)
label_classes = model.classes_.tolist()
else:
predictions = model.predict(smiles_df_subset[fingerprint_name].values)
prediction_proba = model.predict_proba(
smiles_df_subset[fingerprint_name].values
)
label_classes = model.classes_.tolist()

logger.info("✅ Finished task")
if smiles_df_subset.empty:
st.write("No valid SMILES provided.")
st.stop()

predictions = model.predict(smiles_df_subset[fingerprint_name].tolist())
prediction_proba = model.predict_proba(smiles_df_subset[fingerprint_name].tolist())
label_classes = model.classes_.tolist()

st.write("### Predictions")
smiles_df_subset["Prediction"] = predictions
probs = []
for idx, probability in enumerate(prediction_proba):
Expand All @@ -274,3 +130,4 @@
smiles_df_subset["Probability"] = probs

st.dataframe(smiles_df_subset[["smiles", "Prediction", "Probability"]])
st.write("Note: The compounds that could not generate fingerprints are not shown.")
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,5 +2,7 @@ pandas
rdkit
mhfp==1.9.6
seaborn==0.13.1
torch
streamlit==1.36.0
streamlit==1.36.0
scikit-learn==1.2.2
torch==2.1.2
joblib==1.3.2

0 comments on commit 2b09cd7

Please sign in to comment.