diff --git a/Dockerfile b/Dockerfile index f153b46..3954796 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,12 +27,12 @@ COPY start-script.sh $HOME/kg/start-script.sh RUN mkdir $HOME/kg/models RUN apt-get install curl -y -# RUN curl https://zenodo.org/api/records/3407840/files/chem_phys_rf.pkl -o $HOME/kg/models/chem_phys_rf.pkl -L -# RUN curl https://zenodo.org/api/records/3407840/files/ecfp_rf.pkl -o $HOME/kg/models/ecfp_rf.pkl -L -# RUN curl https://zenodo.org/api/records/3407840/files/erg_rf.pkl -o $HOME/kg/models/erg_rf.pkl -L -# RUN curl https://zenodo.org/api/records/3407840/files/maccs_rf.pkl -o $HOME/kg/models/maccs_rf.pkl -L +RUN curl https://zenodo.org/api/records/3407840/files/chem_phys_rf.pkl -o $HOME/kg/models/chem_phys_rf.pkl -L +RUN curl https://zenodo.org/api/records/3407840/files/ecfp_rf.pkl -o $HOME/kg/models/ecfp_rf.pkl -L +RUN curl https://zenodo.org/api/records/3407840/files/erg_rf.pkl -o $HOME/kg/models/erg_rf.pkl -L +RUN curl https://zenodo.org/api/records/3407840/files/maccs_rf.pkl -o $HOME/kg/models/maccs_rf.pkl -L RUN curl https://zenodo.org/api/records/3407840/files/mhfp6_rf.pkl -o $HOME/kg/models/mhfp6_rf.pkl -L -# RUN curl https://zenodo.org/api/records/3407840/files/rdkit_rf.pkl -o $HOME/kg/models/rdkit_rf.pkl -L +RUN curl https://zenodo.org/api/records/3407840/files/rdkit_rf.pkl -o $HOME/kg/models/rdkit_rf.pkl -L RUN pip install --no-cache-dir -r requirements.txt \ && chmod +x start-script.sh \ diff --git a/pages/2_Model_Prediction.py b/pages/2_Model_Prediction.py index b7c68b4..baf3d50 100644 --- a/pages/2_Model_Prediction.py +++ b/pages/2_Model_Prediction.py @@ -4,7 +4,6 @@ import pandas as pd import streamlit as st import torch -import pickle from rdkit.Chem import ( CanonSmiles, @@ -57,16 +56,65 @@ ) -st.write( - "**The best model combination is Random Forest with MHFP6 fingerprint. \ - By default, this is used for prediction.**" -) +# Model selection +st.markdown("### Select the model to use for prediction:") -st.header( - "🧠 Prediction results", - divider="orange", - help="Results of the model prediction.", -) +cols = st.columns(2) +with cols[0]: + st.write("#### Select the fingerprint:") + fingerprint = st.radio( + "Fingerprint", + ("MHFP6", "ECFP4", "RDKIT", "MACCS", "ErG"), + index=0, + help="Select the fingerprint representation to use for prediction", + ) + + st.write("**The best model combination is Random Forest with MHFP6 fingerprint.**") + +with cols[1]: + + metric_df = pd.read_csv("data/test_metrics.tsv", sep="\t").reset_index(drop=True) + metric_df.rename(columns={"Unnamed: 0": "model_name"}, inplace=True) + metric_df["model"] = metric_df["model_name"].apply( + lambda x: ( + x.split("_")[1].upper() + if len(x.split("_")) < 3 + else x.split("_")[-1].upper() + ) + ) + metric_df["fingerprints"] = metric_df["model_name"].apply( + lambda x: ( + x.split("_")[0].upper() + if len(x.split("_")) < 3 + else x.split("_")[0].upper() + "_" + x.split("_")[1].upper() + ) + ) + metric_df["fingerprints"] = metric_df["fingerprints"].replace( + { + "ERG": "ErG", + "CHEM_PHYS": "ChemPhys", + } + ) + + colors = { + "MHFP6": "#3a2c20", + "ECFP4": "#b65c11", + "RDKIT": "#e7a504", + "MACCS": "#719842", + "ErG": "#3d8ebf", + "ChemPhys": "#901b1b", + "RF": "#3a2c20", + "XGBOOST": "#719842", + } + + metric_df["accuracy"] = metric_df["accuracy"] * 100 + plt.figure(figsize=(5, 5)) + sns.violinplot(x="model", y="accuracy", data=metric_df, palette=colors, hue="model") + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + plt.xlabel("Model", fontsize=15) + plt.ylabel("Accuracy", fontsize=15) + st.pyplot(plt) if st.button("Predict"): if uploaded_file is not None: @@ -74,16 +122,37 @@ else: smiles_df = pd.DataFrame(text_input.split("\n"), columns=["smiles"]) - model_name = "rf" - fingerprint_name = "mhfp6" + if fingerprint == "MHFP6": + fingerprint_name = "mhfp6" + elif fingerprint == "ECFP4": + fingerprint_name = "ecfp4" + elif fingerprint == "RDKIT": + fingerprint_name = "rdkit" + elif fingerprint == "MACCS": + fingerprint_name = "maccs" + else: + fingerprint_name = "erg" + + logger.info("⏳ Loading models") - model = torch.load(f"./models/{fingerprint_name}_{model_name}.pkl") + model_name = "rf" + model = torch.load(f"models/{fingerprint_name}_{model_name}.pkl") - # warnings.simplefilter("error", InconsistentVersionWarning) + logger.info("🔮 Processing SMILES to fingerprints") + mfpgen = rdFingerprintGenerator.GetMorganGenerator( + radius=4, fpSize=1024 + ) # ECFP4 fingerprint + rdkgen = rdFingerprintGenerator.GetRDKitFPGenerator( + fpSize=1024 + ) # RDKit fingerprint mhfp_encoder = MHFPEncoder(n_permutations=2048, seed=42) # MHFP6 fingerprint + ecfp_fingerprints = [] + rdkit_fingerprints = [] + maccs_fingerprints = [] mhfp6_fingerprints = [] + erg_fingerprints = [] if smiles_df.empty: st.write("No SMILES provided.") @@ -94,34 +163,46 @@ try: can_smiles = CanonSmiles(smiles) except Exception as e: - can_smiles = mhfp6_fingerprints.append(None) - continue + can_smiles = smiles # Generate the mol object mol = MolFromSmiles(can_smiles) if not mol: + ecfp_fingerprints.append(None) + rdkit_fingerprints.append(None) + maccs_fingerprints.append(None) mhfp6_fingerprints.append(None) + erg_fingerprints.append(None) continue + ecfp_fingerprints.append(mfpgen.GetFingerprint(mol)) + rdkit_fingerprints.append(rdkgen.GetFingerprint(mol)) + maccs_fingerprints.append(MACCSkeys.GenMACCSKeys(mol)) mhfp6_fingerprints.append(mhfp_encoder.encode(can_smiles, radius=3)) + erg_fingerprints.append(rdReducedGraphs.GetErGFingerprint(mol)) - vals = Descriptors.CalcMolDescriptors(mol) + smiles_df["ecfp4"] = ecfp_fingerprints + smiles_df["rdkit"] = rdkit_fingerprints + smiles_df["maccs"] = maccs_fingerprints + smiles_df["mhfp6"] = mhfp6_fingerprints + smiles_df["erg"] = erg_fingerprints smiles_df["mhfp6"] = mhfp6_fingerprints + logger.info("🏃 Running model") + smiles_df_subset = smiles_df.dropna(subset=[fingerprint_name])[ ["smiles", fingerprint_name] ] - if smiles_df_subset.empty: - st.write("No valid SMILES provided.") - st.stop() - predictions = model.predict(smiles_df_subset[fingerprint_name].tolist()) prediction_proba = model.predict_proba(smiles_df_subset[fingerprint_name].tolist()) label_classes = model.classes_.tolist() + logger.info("✅ Finished task") + + st.write("### Predictions") smiles_df_subset["Prediction"] = predictions probs = [] for idx, probability in enumerate(prediction_proba): @@ -130,4 +211,3 @@ smiles_df_subset["Probability"] = probs st.dataframe(smiles_df_subset[["smiles", "Prediction", "Probability"]]) - st.write("Note: The compounds that could not generate fingerprints are not shown.")