Skip to content

Commit

Permalink
update docker file and push all models
Browse files: browse the repository at this point in the history
  • Loading branch information
YojanaGadiya committed Jun 25, 2024
1 parent 022dad3 commit a304fca
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 27 deletions.
10 changes: 5 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ COPY start-script.sh $HOME/kg/start-script.sh
RUN mkdir $HOME/kg/models

RUN apt-get install curl -y
# RUN curl https://zenodo.org/api/records/3407840/files/chem_phys_rf.pkl -o $HOME/kg/models/chem_phys_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/ecfp_rf.pkl -o $HOME/kg/models/ecfp_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/erg_rf.pkl -o $HOME/kg/models/erg_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/maccs_rf.pkl -o $HOME/kg/models/maccs_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/chem_phys_rf.pkl -o $HOME/kg/models/chem_phys_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/ecfp_rf.pkl -o $HOME/kg/models/ecfp_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/erg_rf.pkl -o $HOME/kg/models/erg_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/maccs_rf.pkl -o $HOME/kg/models/maccs_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/mhfp6_rf.pkl -o $HOME/kg/models/mhfp6_rf.pkl -L
# RUN curl https://zenodo.org/api/records/3407840/files/rdkit_rf.pkl -o $HOME/kg/models/rdkit_rf.pkl -L
RUN curl https://zenodo.org/api/records/3407840/files/rdkit_rf.pkl -o $HOME/kg/models/rdkit_rf.pkl -L

RUN pip install --no-cache-dir -r requirements.txt \
&& chmod +x start-script.sh \
Expand Down
124 changes: 102 additions & 22 deletions pages/2_Model_Prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pandas as pd
import streamlit as st
import torch
import pickle

from rdkit.Chem import (
CanonSmiles,
Expand Down Expand Up @@ -57,33 +56,103 @@
)


st.write(
"**The best model combination is Random Forest with MHFP6 fingerprint. \
By default, this is used for prediction.**"
)
# Model selection
st.markdown("### Select the model to use for prediction:")

st.header(
"🧠 Prediction results",
divider="orange",
help="Results of the model prediction.",
)
cols = st.columns(2)
with cols[0]:
st.write("#### Select the fingerprint:")
fingerprint = st.radio(
"Fingerprint",
("MHFP6", "ECFP4", "RDKIT", "MACCS", "ErG"),
index=0,
help="Select the fingerprint representation to use for prediction",
)

st.write("**The best model combination is Random Forest with MHFP6 fingerprint.**")

with cols[1]:

metric_df = pd.read_csv("data/test_metrics.tsv", sep="\t").reset_index(drop=True)
metric_df.rename(columns={"Unnamed: 0": "model_name"}, inplace=True)
metric_df["model"] = metric_df["model_name"].apply(
lambda x: (
x.split("_")[1].upper()
if len(x.split("_")) < 3
else x.split("_")[-1].upper()
)
)
metric_df["fingerprints"] = metric_df["model_name"].apply(
lambda x: (
x.split("_")[0].upper()
if len(x.split("_")) < 3
else x.split("_")[0].upper() + "_" + x.split("_")[1].upper()
)
)
metric_df["fingerprints"] = metric_df["fingerprints"].replace(
{
"ERG": "ErG",
"CHEM_PHYS": "ChemPhys",
}
)

colors = {
"MHFP6": "#3a2c20",
"ECFP4": "#b65c11",
"RDKIT": "#e7a504",
"MACCS": "#719842",
"ErG": "#3d8ebf",
"ChemPhys": "#901b1b",
"RF": "#3a2c20",
"XGBOOST": "#719842",
}

metric_df["accuracy"] = metric_df["accuracy"] * 100
plt.figure(figsize=(5, 5))
sns.violinplot(x="model", y="accuracy", data=metric_df, palette=colors, hue="model")
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel("Model", fontsize=15)
plt.ylabel("Accuracy", fontsize=15)
st.pyplot(plt)

if st.button("Predict"):
if uploaded_file is not None:
smiles_df = pd.read_csv(uploaded_file, header=None)
else:
smiles_df = pd.DataFrame(text_input.split("\n"), columns=["smiles"])

model_name = "rf"
fingerprint_name = "mhfp6"
if fingerprint == "MHFP6":
fingerprint_name = "mhfp6"
elif fingerprint == "ECFP4":
fingerprint_name = "ecfp4"
elif fingerprint == "RDKIT":
fingerprint_name = "rdkit"
elif fingerprint == "MACCS":
fingerprint_name = "maccs"
else:
fingerprint_name = "erg"

logger.info("⏳ Loading models")

model = torch.load(f"./models/{fingerprint_name}_{model_name}.pkl")
model_name = "rf"
model = torch.load(f"models/{fingerprint_name}_{model_name}.pkl")

# warnings.simplefilter("error", InconsistentVersionWarning)
logger.info("🔮 Processing SMILES to fingerprints")

mfpgen = rdFingerprintGenerator.GetMorganGenerator(
radius=4, fpSize=1024
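) # ECFP4 fingerprint -- NOTE(review): ECFP4 conventionally corresponds to Morgan radius=2 (diameter 4); confirm radius=4 matches how the ecfp model was trained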
rdkgen = rdFingerprintGenerator.GetRDKitFPGenerator(
fpSize=1024
) # RDKit fingerprint
mhfp_encoder = MHFPEncoder(n_permutations=2048, seed=42) # MHFP6 fingerprint

ecfp_fingerprints = []
rdkit_fingerprints = []
maccs_fingerprints = []
mhfp6_fingerprints = []
erg_fingerprints = []

if smiles_df.empty:
st.write("No SMILES provided.")
Expand All @@ -94,34 +163,46 @@
try:
can_smiles = CanonSmiles(smiles)
except Exception as e:
can_smiles = mhfp6_fingerprints.append(None)
continue
can_smiles = smiles

# Generate the mol object
mol = MolFromSmiles(can_smiles)

if not mol:
ecfp_fingerprints.append(None)
rdkit_fingerprints.append(None)
maccs_fingerprints.append(None)
mhfp6_fingerprints.append(None)
erg_fingerprints.append(None)
continue

ecfp_fingerprints.append(mfpgen.GetFingerprint(mol))
rdkit_fingerprints.append(rdkgen.GetFingerprint(mol))
maccs_fingerprints.append(MACCSkeys.GenMACCSKeys(mol))
mhfp6_fingerprints.append(mhfp_encoder.encode(can_smiles, radius=3))
erg_fingerprints.append(rdReducedGraphs.GetErGFingerprint(mol))

vals = Descriptors.CalcMolDescriptors(mol)
smiles_df["ecfp4"] = ecfp_fingerprints
smiles_df["rdkit"] = rdkit_fingerprints
smiles_df["maccs"] = maccs_fingerprints
smiles_df["mhfp6"] = mhfp6_fingerprints
smiles_df["erg"] = erg_fingerprints

smiles_df["mhfp6"] = mhfp6_fingerprints

logger.info("🏃 Running model")

smiles_df_subset = smiles_df.dropna(subset=[fingerprint_name])[
["smiles", fingerprint_name]
]

if smiles_df_subset.empty:
st.write("No valid SMILES provided.")
st.stop()

predictions = model.predict(smiles_df_subset[fingerprint_name].tolist())
prediction_proba = model.predict_proba(smiles_df_subset[fingerprint_name].tolist())
label_classes = model.classes_.tolist()

logger.info("✅ Finished task")

st.write("### Predictions")
smiles_df_subset["Prediction"] = predictions
probs = []
for idx, probability in enumerate(prediction_proba):
Expand All @@ -130,4 +211,3 @@
smiles_df_subset["Probability"] = probs

st.dataframe(smiles_df_subset[["smiles", "Prediction", "Probability"]])
st.write("Note: The compounds that could not generate fingerprints are not shown.")

0 comments on commit a304fca

Please sign in to comment.