6.0 Finetuning

What is finetuning?

Finetuning is the process of taking a foundation model (a model that has already been trained on a very large dataset) and training it further on a smaller dataset for a specific task or domain. Rather than going through the arduous and time-consuming process of training a model from scratch on your own dataset, you can leverage the useful general representations the pretrained model has already learned (e.g. chemical structure, descriptors, etc.) and focus the model on your specific task.

Why finetune?

Finetuning can be very useful, particularly when you have too little empirical data (or none at all) to train a model from scratch. This is usually the case, since empirical data is limited by resources, time, manpower, and throughput.

Consider this scenario: you have just started a new drug program and you want to use a model like ChemProp or CheMeleon to help you predict certain ADMET endpoints so you know which compounds to prioritize in assays. Since you have no empirical data of your own yet, you start by finetuning a CheMeleon model on data curated from ChEMBL, as shown in 01_Curate_ChEMBL_Data.ipynb. This model’s predictions are good enough that you’re able to narrow the pool of compounds to run. Now that you have your own empirical data, you want to finetune your CheMeleon model further on the specific subset of molecules that you think show the most promise for your disease target, ideally to get more accurate predictions.

This demo

This notebook will show you how to finetune a CheMeleon model to predict microsomal clearance in human liver microsomes (LOG_CLint_HLM) on:

  1. a curated ChEMBL dataset

  2. ChEMBL + the train dataset from the OpenADMET ExpansionRx Therapeutics Blind Challenge

Model training with our ANVIL infrastructure is exactly the same as outlined in 02_Model_Training, so we won’t go into the same level of detail here. We will show how the performance of the finetuned model on the ExpansionRx test set improves when finetuning on ChEMBL + ExpansionRx train data vs. finetuning on ChEMBL alone. Inference is also done with ANVIL, as shown in 05_Ensemble_Model_Inference.

We have already curated a ChEMBL dataset for microsomal clearance values in data/chembl.csv. The raw dataset is available here.

The ExpansionRx train dataset is available on our HuggingFace and is downloaded as data/exprx.csv.

The ExpansionRx test dataset is also available on HuggingFace and is downloaded as data/exprx_test.csv.

[1]:
import pandas as pd

chembl = pd.read_csv("data/chembl.csv")

exprx = pd.read_csv("data/exprx.csv")

We’ve already done some preprocessing of both these datasets for ease of work.

[2]:
chembl.head(10)
[2]:
OPENADMET_INCHIKEY OPENADMET_CANONICAL_SMILES LOG_CLint_HLM CLint_HLM LOG_CLint_RLM CLint_RLM LOG_CLint_MLM CLint_MLM
0 MOPSJMVERRFXAX-HNNXBMFYSA-N COCCOC1=CN=C2C=CC([C@H](C)N3N=NC4=C3C(=O)N(C3=... 2.147367 140.40 2.276462 189.00 2.583958 383.670
1 NNBYANZTLQEDDY-UHFFFAOYSA-N CC1=NC=C(C2=NC3=C(NC4(C)CC4)N=C(NC4=CN(C5CC6(C... 1.109579 12.87 NaN NaN NaN NaN
2 UAXPPQPDGOMVPD-LDFOQBSTSA-N CCC(NC(=O)C1=CC=C(Cl)C=C1)[C@H]1[C@@H]2C[C@H](... 2.499687 316.00 2.980458 956.00 NaN NaN
3 YTTYYCAUZAPRNY-UHFFFAOYSA-N CC(C)S(=O)(=O)C1=CC=CC=C1NC1=NC(NC2=CC3=NNC(NC... NaN NaN 2.184180 152.82 NaN NaN
4 AWLUMUGHLSXLNA-UHFFFAOYSA-N CC1=NC(=O)C2=C(C)N=C(C3=C(C4=CC=C(C(F)(F)F)C=C... 1.204120 16.00 NaN NaN NaN NaN
5 DKQOAOQZQUBSIZ-UHFFFAOYSA-N CC(C)OC(=O)NC1CCCOC2=C(C3=CC=C(Cl)C=C3)N(C3=CC... 1.176091 15.00 NaN NaN NaN NaN
6 VRXFDHAGFYWGHT-UHFFFAOYSA-N CCNCC1=CC(NC2=CC=NC3=CC(Cl)=CC=C23)=CC=C1O NaN NaN 0.534026 3.42 1.634124 43.065
7 YSZNYCUUEJLLGR-UHFFFAOYSA-N COC(=O)C1C(CN2CCC(F)(F)C2)=NC(C2=NC=CS2)=NC1(C... NaN NaN NaN NaN 1.949390 89.000
8 HAKUBDUVAGUCGC-UHFFFAOYSA-N CCCN(CC1CC1)C1=CC(C(=O)NC2=CC=C3NN=CC3=C2)=NC=N1 1.298635 19.89 2.011147 102.60 NaN NaN
9 IPSMAKOQUQMHED-CZIZESTLSA-N CCOC(=O)N1CC/C(=C2/C3=CC=C(Cl)C=C3CCC3=CC=CN=C... 1.775756 59.67 NaN NaN NaN NaN
[3]:
exprx.head(10)
[3]:
OPENADMET_INCHIKEY OPENADMET_CANONICAL_SMILES LOG_CLint_HLM CLint_HLM LOG_CLint_RLM CLint_RLM LOG_CLint_MLM CLint_MLM
0 XXHPNGNGSXDBMK-UHFFFAOYSA-N CN1N=CC2=CC(C3=CC=C(OCCN4CCOCC4)C=C3)=CC=C21 1.475671 29.9 NaN NaN 2.460296 288.6
1 SMQZYWBKOHASSN-UHFFFAOYSA-N CN(CCOC1=CC=C(C2=CC=C3C=CN=CC3=C2)C=C1)C1CCC1 1.409933 25.7 NaN NaN 3.361803 2300.4
2 DAXGQXFVQUYVRY-UHFFFAOYSA-N CC1=C(N)C(C2=CC=C(OCCN3CCCC3=O)C=C2)=CC2=CN=CC... 1.385606 24.3 NaN NaN 2.591287 390.2
3 UEZJJVAJDJNXMQ-CQSZACIVSA-N CC1=CN(C)C(=O)C2=C1C1=CC=C(OC[C@H]3CN(C)CCO3)C... 1.369216 23.4 NaN NaN 2.534280 342.2
4 JEUZRQMZLPUKSV-NSHDSACASA-N C[C@H](N)COC1=CC=C2C(=C1)C1=NC3=CN=CC=C3C=C1N2C 0.949390 8.9 NaN NaN 2.477555 300.3
5 RZMOYHIJQUQVQI-UHFFFAOYSA-N CN1CCC(C2=CC=CC(NC3=CC=CC(C4=CC=CC(F)=C4)=N3)=... 0.908485 8.1 NaN NaN 2.476397 299.5
6 MXTHNLAVZWLDKB-UHFFFAOYSA-N CN1C=NC(C2=CC3=CC(C(=O)NC4CCN(C5=CC=NC=C5F)CC4... 1.587711 38.7 NaN NaN 1.439333 27.5
7 JLZFYAJPIXDVEW-UHFFFAOYSA-N CC1=NN=C(C2=CN=C3C(=C2)C=C(C(=O)NC2CCN(C4=CC=C... 1.509203 32.3 NaN NaN 1.969882 93.3
8 SNZZVPGTGGBZQK-UHFFFAOYSA-N O=C1C2=C(C=CN1CCCN1CCCCC1)C1=CC=CC(F)=C1N2 0.662758 4.6 NaN NaN 1.984527 96.5
9 PMJZFVGMJMIKQE-UHFFFAOYSA-N CN1C=C(C2=CN=CN=C2)C2=CC(OCC3CNC3)=CC=C21 0.740363 5.5 NaN NaN 2.505286 320.1
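
One preprocessing step is visible in the two tables above: each LOG_ column appears to be the base-10 logarithm of its raw counterpart. The following minimal sketch checks this against three values copied from the chembl.head(10) output. That the same relationship holds for every row and every column is an assumption.

```python
import math

# (CLint_HLM, LOG_CLint_HLM) pairs copied from rows 0, 1 and 2 of chembl.head(10)
pairs = [(140.40, 2.147367), (12.87, 1.109579), (316.00, 2.499687)]

# Recompute the log-transformed value from the raw value
recomputed = [math.log10(raw) for raw, _ in pairs]

# Largest absolute difference between recomputed and tabulated log values
max_abs_diff = max(abs(r - logged) for r, (_, logged) in zip(recomputed, pairs))

# Back-transform a log-scale value (e.g. a model prediction) to the raw scale
back_transformed = 10 ** pairs[0][1]

print(f"max |log10(raw) - tabulated| = {max_abs_diff:.2e}")
print(f"10 ** {pairs[0][1]} = {back_transformed:.2f}")
```

Because the models in this demo are trained on the log-scale column, their predictions would need the same back-transform before being compared with raw clearance values.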

We can simply concatenate these two datasets and save the result as data/chembl+exprx.csv:

[4]:
combined = pd.concat([chembl, exprx])

combined.to_csv("data/chembl+exprx.csv", index=False)
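
Before training on a concatenated file, it can be worth checking whether the same compound appears in both sources. The sketch below uses small made-up frames rather than the real data, and the choice to keep the later (in-house) measurement is an assumption for illustration, not something this demo does. Whether the real ChEMBL and ExpRx files overlap at all is not stated here.

```python
import pandas as pd

# Toy stand-ins for the two datasets; keys and values are invented
chembl_toy = pd.DataFrame({
    "OPENADMET_INCHIKEY": ["AAA", "BBB", "CCC"],
    "LOG_CLint_HLM": [2.1, 1.1, None],
})
exprx_toy = pd.DataFrame({
    "OPENADMET_INCHIKEY": ["CCC", "DDD"],
    "LOG_CLint_HLM": [1.5, 0.9],
})

# Same concatenation as above, with a fresh 0..n-1 index
combined_toy = pd.concat([chembl_toy, exprx_toy], ignore_index=True)

# Count compounds that occur more than once, using the InChIKey as identifier
n_dupes = int(combined_toy["OPENADMET_INCHIKEY"].duplicated().sum())

# One possible policy: keep the last occurrence, i.e. the ExpRx row
deduped = combined_toy.drop_duplicates(subset="OPENADMET_INCHIKEY", keep="last")

print(f"{len(combined_toy)} rows, {n_dupes} duplicated key(s), {len(deduped)} after dedup")
```

The same duplicated-key check can be pointed at the train and test files to look for compounds shared between them.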

We will now finetune 2 models:

  1. a CheMeleon model on ChEMBL only

  2. a CheMeleon model on ChEMBL and the ExpRx train datasets combined

Since deep learning models can take a while to train, we’ve already trained both models for you. They are saved in chemeleon_chembl/ and chemeleon_chembl+exprx/, so we can go straight to inference:

openadmet predict \
    --input-path data/exprx_test.csv \
    --input-col OPENADMET_CANONICAL_SMILES \
    --model-dir ../06_Finetuning/chemeleon_chembl/ \
    --output-csv chembl_predictions.csv \
    --accelerator cpu
openadmet predict \
    --input-path data/exprx_test.csv \
    --input-col OPENADMET_CANONICAL_SMILES \
    --model-dir ../06_Finetuning/chemeleon_chembl+exprx/ \
    --output-csv chembl+exprx_predictions.csv \
    --accelerator cpu

Again, we have run these commands for you to keep this demo simple. The files generated are chembl_predictions.csv and chembl+exprx_predictions.csv.

NOTE: The error shown in the two cells below is due to a keyboard interrupt. When predicting with a deep learning model on CPU, our inference CLI leaves the terminal hanging even after the predictions CSV file has been generated, so we have to stop the cell manually.
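
If you would rather launch these commands from Python than stop a hanging cell by hand, one option is to wrap the call in a timeout. The sketch below is a workaround under stated assumptions, not part of the openadmet tooling: the helper names and the 30-minute limit are hypothetical, and it assumes the predictions CSV is fully written before the process is stopped. Only the CLI flags shown above are used.

```python
import subprocess
from pathlib import Path


def build_predict_cmd(model_dir, output_csv, input_path="data/exprx_test.csv"):
    """Assemble the same `openadmet predict` call shown above as an argument list."""
    return [
        "openadmet", "predict",
        "--input-path", input_path,
        "--input-col", "OPENADMET_CANONICAL_SMILES",
        "--model-dir", model_dir,
        "--output-csv", output_csv,
        "--accelerator", "cpu",
    ]


def run_with_timeout(cmd, output_csv, timeout_s=1800):
    """Run cmd, stop it after timeout_s seconds, and report whether the CSV exists."""
    try:
        subprocess.run(cmd, check=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # The child process is killed on timeout; the output may already be on disk
        pass
    return Path(output_csv).exists()


cmd = build_predict_cmd("../06_Finetuning/chemeleon_chembl/", "chembl_predictions.csv")
print(" ".join(cmd))
# To actually run it: run_with_timeout(cmd, "chembl_predictions.csv")
```

A timeout that is too short would stop the run before any predictions are written, so the returned flag should be checked before reading the file.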

Comparing finetuned model performance

Now, we can compare the performance of each of these models when predicting on the ExpRx test set. Recall that in this hypothetical use-case, the region of chemical space we’re interested in exploring more is represented by the ExpRx test dataset.

[5]:
from openadmet.models.eval.regression import RegressionMetrics, RegressionPlots

# Read in the predictions
chembl_preds = pd.read_csv("chembl_predictions.csv")

chembl_exprx_preds = pd.read_csv("chembl+exprx_predictions.csv")
[6]:
chembl_y_preds = chembl_preds['OADMET_PRED_chemprop_LOG_CLint_HLM'].to_numpy()
chembl_y_true = chembl_preds['LOG_CLint_HLM'].to_numpy()

chembl_exprx_y_preds = chembl_exprx_preds['OADMET_PRED_chemprop_LOG_CLint_HLM'].to_numpy()
chembl_exprx_y_true = chembl_exprx_preds['LOG_CLint_HLM'].to_numpy()
[7]:
evaluator = RegressionMetrics()
data = RegressionMetrics.evaluate(
    evaluator,
    y_true=chembl_y_true,
    y_pred=chembl_y_preds,
    tag="CLint",
    target_label="Log10_CLint"
    )

stat_dict = RegressionMetrics.get_stat_dict(evaluator, "task_0")

RegressionPlots.regplot(
    y_true=chembl_y_true,
    y_pred=chembl_y_preds,
    xlabel=r"Measured $\log_{10}(\mathrm{CL}_{\mathrm{int}})$",
    ylabel=r"Predicted $\log_{10}(\mathrm{CL}_{\mathrm{int}})$",
    title="Model: ChEMBL-Finetuned CheMeleon Multitask\n"
    "Predicting on: ExpRx Test Data \n"
    "Target: HLM",
    stat_dict=stat_dict,
    pXC50=True,
    min_val=-1,
    max_val=4,
    fit_reg=False
)
[7]:
<seaborn.axisgrid.JointGrid at 0x15f5f3c80>
../../_images/demos_06_Finetuning_06_Finetuning_12_1.png

Oof! The ChEMBL-only finetuned CheMeleon model isn’t great at predicting on the ExpRx test data.

Now, let’s take a look at the CheMeleon model finetuned on both ChEMBL and ExpRx train data:

[8]:
evaluator = RegressionMetrics()
# Measured values go in y_true, model predictions in y_pred
data = RegressionMetrics.evaluate(
    evaluator,
    y_true=chembl_exprx_y_true,
    y_pred=chembl_exprx_y_preds,
    tag="CLint",
    target_label="Log10_CLint"
    )

stat_dict = RegressionMetrics.get_stat_dict(evaluator, "task_0")

RegressionPlots.regplot(
    y_true=chembl_exprx_y_true,
    y_pred=chembl_exprx_y_preds,
    xlabel=r"Measured $\log_{10}(\mathrm{CL}_{\mathrm{int}})$",
    ylabel=r"Predicted $\log_{10}(\mathrm{CL}_{\mathrm{int}})$",
    title="Model: ChEMBL & ExpRx Train-Finetuned CheMeleon Multitask\n"
    "Predicting on: ExpRx Test Data \n"
    "Target: HLM",
    stat_dict=stat_dict,
    pXC50=True,
    min_val=-1,
    max_val=4,
    fit_reg=False
)
[8]:
<seaborn.axisgrid.JointGrid at 0x15fdff800>
../../_images/demos_06_Finetuning_06_Finetuning_14_1.png

Wow! That looks much better.

This demo showcases the potential gains in model performance from finetuning, even with just a small amount of task-specific data. Try it out with your own data!
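
To put numbers on a comparison like this outside of the plotting helpers, a few standard regression metrics can be computed directly with NumPy. This is a minimal sketch on made-up arrays, not the RegressionMetrics implementation and not the results of this demo; the function name is hypothetical. It also shows that R² depends on which array is passed as the measured values, so the argument order matters.

```python
import numpy as np


def regression_summary(y_true, y_pred):
    """MAE, RMSE and R^2 over the rows where both values are present."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Drop rows with a missing measurement or prediction
    mask = ~(np.isnan(y_true) | np.isnan(y_pred))
    y_true, y_pred = y_true[mask], y_pred[mask]
    err = y_pred - y_true
    ss_res = float(np.sum(err ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    return {
        "n": int(mask.sum()),
        "mae": float(np.mean(np.abs(err))),
        "rmse": float(np.sqrt(np.mean(err ** 2))),
        "r2": 1.0 - ss_res / ss_tot,
    }


# Invented values purely for illustration
measured = np.array([1.0, 2.0, 3.0, np.nan])
predicted = np.array([1.5, 2.0, 2.0, 1.0])

summary = regression_summary(measured, predicted)
swapped = regression_summary(predicted, measured)  # arguments in the wrong order

print(summary)
print("R2 with swapped arguments:", swapped["r2"])
```

MAE and RMSE are unchanged when the two arrays are swapped, but R² is not, because its denominator is the variance of whichever array is treated as the measurement.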

An alternative method of finetuning with ANVIL

In the above example, we showed finetuning by training two models, one on ChEMBL data and one on ChEMBL+ExpRx train data. Alternatively, you can first finetune the model on ChEMBL data, then load the resulting model checkpoint and finetune it further on the ExpRx data. This removes the need to build a concatenated dataset.

You can do this very easily in the ANVIL recipe yaml file, under the model section:

model:
  # Indicate model type
  # See openadmet.models.architecture
  type: ChemPropModel
  # load Pre-trained model
  serial_path: chemeleon_chembl/model.pth
  param_path: chemeleon_chembl/model.json
  freeze_weights:
    message_passing: true
    batch_norm: false
  # Specify model parameters
  params:
    ffn_hidden_dim: 512
    ffn_hidden_num_layers: 3
    mpnn_lr: 1e-4
    ffn_lr: 1e-3
    mpnn_weight_decay: 0
    ffn_weight_decay: 1e-4
    dropout: 0.25
    batch_norm: False
    scheduler: plateau
    reduce_lr_patience: 5
    reduce_lr_factor: 0.5
    n_tasks: 3 # Number of tasks should match the number of target columns
    from_chemeleon: false # use the pre-trained Chemeleon model, will overwrite the model parameters except FFN
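
The freeze_weights block above keeps the message-passing weights fixed while the rest of the model continues to train. How ANVIL implements this internally is not shown here, so the following is only a rough illustration of the idea in plain PyTorch: the two modules and their layer sizes are hypothetical stand-ins, and only the learning rate, weight decay, and number of tasks are borrowed from the recipe.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: "encoder" plays the role of the message-passing layers,
# "head" plays the role of the feed-forward network with 3 outputs (n_tasks: 3)
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 3))

# Freeze the encoder: its parameters receive no gradients and are never updated
for param in encoder.parameters():
    param.requires_grad_(False)

# Optimize only the head, using ffn_lr and ffn_weight_decay from the recipe
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3, weight_decay=1e-4)

# Random data in place of featurized molecules and clearance labels
x = torch.randn(32, 8)
y = torch.randn(32, 3)

encoder_before = [p.detach().clone() for p in encoder.parameters()]
head_before = [p.detach().clone() for p in head.parameters()]

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(head(encoder(x)), y)
    loss.backward()
    optimizer.step()

encoder_unchanged = all(
    torch.equal(a, b) for a, b in zip(encoder_before, encoder.parameters())
)
head_changed = any(
    not torch.equal(a, b) for a, b in zip(head_before, head.parameters())
)
print("encoder unchanged:", encoder_unchanged, "| head updated:", head_changed)
```

Freezing the pretrained layers limits how far a small dataset can pull the learned representation, at the cost of less flexibility than updating every layer.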
[10]:
checkpoint_preds = pd.read_csv("checkpoint_training/predictions.csv")

ckpt_preds = checkpoint_preds['OADMET_PRED_chemprop_LOG_CLint_HLM'].to_numpy()
ckpt_true = checkpoint_preds['LOG_CLint_HLM'].to_numpy()

evaluator = RegressionMetrics()
# Measured values go in y_true, model predictions in y_pred
data = RegressionMetrics.evaluate(
    evaluator,
    y_true=ckpt_true,
    y_pred=ckpt_preds,
    tag="CLint",
    target_label="Log10_CLint"
    )

stat_dict = RegressionMetrics.get_stat_dict(evaluator, "task_0")

RegressionPlots.regplot(
    y_true=ckpt_true,
    y_pred=ckpt_preds,
    xlabel=r"Measured $\log_{10}(\mathrm{CL}_{\mathrm{int}})$",
    ylabel=r"Predicted $\log_{10}(\mathrm{CL}_{\mathrm{int}})$",
    title="Model: ChEMBL & ExpRx Train-Finetuned CheMeleon Multitask\n"
    "Predicting on: ExpRx Test Data \n"
    "Target: HLM",
    stat_dict=stat_dict,
    pXC50=True,
    min_val=-1,
    max_val=4,
    fit_reg=False
)
[10]:
<seaborn.axisgrid.JointGrid at 0x15f831af0>
../../_images/demos_06_Finetuning_06_Finetuning_17_1.png

~ End of 06_Finetuning ~