{ "cells": [ { "cell_type": "markdown", "id": "b688f8fb-e6b4-495a-8eda-4d82ed2b0422", "metadata": {}, "source": [ "# Inference" ] }, { "cell_type": "markdown", "id": "e49c994b-1b95-4f4e-acf7-91f810daa20d", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "Before starting this tutorial, you should have already worked through the tutorial on [Running ML inference](https://drugforge.readthedocs.io/en/latest/tutorials/index.html#running-ml-inference).\n", "\n", "We will use publicly available test files for the model weights and inference files for this tutorial, but feel free to substitute your own data." ] }, { "cell_type": "code", "execution_count": 1, "id": "681dcee8-3af1-4c6f-9311-16c5fd0b57b2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CC(=O)NNC(=O)c1cc2CCCCc2s1\n" ] } ], "source": [ "# First download the needed files\n", "from drugforge.data.testing.test_resources import fetch_test_file\n", "\n", "docked_file = fetch_test_file(\n", " \"ml_testing/docked/AAR-POS-5507155c-1_Mpro-P0018_0A_0_bound_best.pdb\"\n", ")\n", "smi_file = fetch_test_file(\"Mpro_combined_labeled.smi\")\n", "pred_smiles = smi_file.read_text().split()[0]\n", "print(pred_smiles)" ] }, { "cell_type": "markdown", "id": "257eb033-8485-4b6d-b0f5-32a7a1871fe7", "metadata": {}, "source": [ "## Intro\n", "\n", "In this guide, we will start with an existing model weights file and model config file, both of which should have been generated in the previous tutorial (or can be downloaded as above).\n", "\n", "We will first build the model spec object, and then the `Inference` model object. We will show examples of this using either a local model weights file or one of our available pre-trained models. For the inference, we will pass the model either a PDB file of the complex structure or a SMILES string, depending on the model type." 
] }, { "cell_type": "markdown", "id": "e6ffcf4a-8b8b-47b7-8f40-e35147e4aa46", "metadata": {}, "source": [ "## Building the `Inference` model from a local weights file\n", "\n", "The first way to build an `Inference` model is from local weights and config files. We first define a `drugforge.ml.models.LocalMLModelSpec`, and then use it to construct the `Inference` model." ] }, { "cell_type": "code", "execution_count": 2, "id": "cbf3f6f2-ea84-4946-892b-72bd91c97325", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GAT prediction: 2.397318124771118\n", "SchNet prediction: -0.17303919792175293\n", "e3nn prediction: 0.01794694922864437\n" ] } ], "source": [ "from drugforge.ml.inference import E3nnInference, GATInference, SchnetInference\n", "from drugforge.ml.models import LocalMLModelSpec\n", "from datetime import datetime\n", "from pathlib import Path\n", "\n", "# Define weights and config files (generated in the previous tutorial)\n", "gat_wts_path = next(iter(Path(\"gat_training/\").glob(\"*/final.th\")))\n", "gat_config_path = Path(\"gat_model_config.json\")\n", "schnet_wts_path = next(iter(Path(\"schnet_training/\").glob(\"*/final.th\")))\n", "schnet_config_path = Path(\"schnet_model_config.json\")\n", "e3nn_wts_path = next(iter(Path(\"e3nn_training/\").glob(\"*/final.th\")))\n", "e3nn_config_path = Path(\"e3nn_model_config.json\")\n", "\n", "# Build LocalMLModelSpecs and Inference models\n", "gat_model_spec = LocalMLModelSpec(\n", " name=\"GAT\",\n", " type=\"GAT\",\n", " last_updated=datetime.today(),\n", " targets={\"SARS-CoV-2-Mpro\"},\n", " weights_file=gat_wts_path,\n", " config_file=gat_config_path,\n", ")\n", "gat_inf_model = GATInference.from_local_model_spec(gat_model_spec)\n", "schnet_model_spec = LocalMLModelSpec(\n", " name=\"SchNet\",\n", " type=\"schnet\",\n", " last_updated=datetime.today(),\n", " targets={\"SARS-CoV-2-Mpro\"},\n", " weights_file=schnet_wts_path,\n", " config_file=schnet_config_path,\n", ")\n", 
"schnet_inf_model = SchnetInference.from_local_model_spec(schnet_model_spec)\n", "e3nn_model_spec = LocalMLModelSpec(\n", " name=\"e3nn\",\n", " type=\"e3nn\",\n", " last_updated=datetime.today(),\n", " targets={\"SARS-CoV-2-Mpro\"},\n", " weights_file=e3nn_wts_path,\n", " config_file=e3nn_config_path,\n", ")\n", "e3nn_inf_model = E3nnInference.from_local_model_spec(e3nn_model_spec)\n", "\n", "# Make predictions\n", "gat_pred = gat_inf_model.predict_from_smiles(pred_smiles)\n", "schnet_pred = schnet_inf_model.predict_from_structure_file(docked_file)\n", "e3nn_pred = e3nn_inf_model.predict_from_structure_file(docked_file)\n", "\n", "print(\"GAT prediction:\", gat_pred)\n", "print(\"SchNet prediction:\", schnet_pred)\n", "print(\"e3nn prediction:\", e3nn_pred)" ] }, { "cell_type": "markdown", "id": "539c68d2-8aba-4d3d-8aa1-a288f572b061", "metadata": {}, "source": [ "## Building the `Inference` model from ASAP-trained models\n", "\n", "We also offer several models that we have pre-trained on various Moonshot/ASAP data. In this section we simply specify the ASAP target that we want to pull the model for, and `drugforge` takes care of pulling the necessary files, and building the `Inference` model." 
] }, { "cell_type": "code", "execution_count": 3, "id": "327c7144-a618-4ff7-8a94-2f06c677f6dc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GAT prediction: 4.059902191162109\n", "SchNet prediction: 1.3631529808044434\n", "e3nn prediction: 3.001865863800049\n" ] } ], "source": [ "from drugforge.ml.inference import E3nnInference, GATInference, SchnetInference\n", "\n", "# Get latest model trained on SARS-CoV-2 Mpro data for each architecture\n", "gat_inf_model = GATInference.from_latest_by_target(\"SARS-CoV-2-Mpro\")\n", "schnet_inf_model = SchnetInference.from_latest_by_target(\"SARS-CoV-2-Mpro\")\n", "e3nn_inf_model = E3nnInference.from_latest_by_target(\"SARS-CoV-2-Mpro\")\n", "\n", "# Make predictions\n", "gat_pred = gat_inf_model.predict_from_smiles(pred_smiles)\n", "schnet_pred = schnet_inf_model.predict_from_structure_file(docked_file)\n", "e3nn_pred = e3nn_inf_model.predict_from_structure_file(docked_file)\n", "\n", "print(\"GAT prediction:\", gat_pred)\n", "print(\"SchNet prediction:\", schnet_pred)\n", "print(\"e3nn prediction:\", e3nn_pred)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }