Constructing Balanced and Robust Splits for Molecular Dataset


A tool to create well-balanced and robust data splits for (sparse) molecular datasets without data leakage between different tasks.

This package is based on the work of Giovanni Tricarico presented in Construction of balanced, chemically dissimilar training, validation and test sets for machine learning on molecular datasets.

The main idea is to first cluster the dataset based on chemistry alone, which usually results in a fairly large number of clusters, and then use linear programming to recombine those initial clusters into final subsets, so that the fraction of molecules and data points per task, in each subset, is as close as possible to the desired one. Additionally, stratification can be used to ensure that labels (for categorical data) and distributions (for continuous data) are homogeneous in all final subsets.
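The recombination step can be illustrated on a toy problem. The sketch below is not the package's implementation: it swaps the linear program for an exhaustive search, handles only a single task, and uses a hypothetical helper `assign_clusters`. It only shows the objective, namely assigning whole clusters to subsets so that the realised fractions are as close as possible to the desired ones.

```python
from itertools import product


def assign_clusters(cluster_sizes, target_fractions):
    """Try every cluster-to-subset assignment and keep the one whose subset
    fractions deviate least (sum of absolute differences) from the targets."""
    total = sum(cluster_sizes)
    n_subsets = len(target_fractions)
    best_assignment, best_cost = None, float("inf")
    for assignment in product(range(n_subsets), repeat=len(cluster_sizes)):
        subset_sizes = [0] * n_subsets
        for size, subset in zip(cluster_sizes, assignment):
            subset_sizes[subset] += size
        cost = sum(abs(s / total - f) for s, f in zip(subset_sizes, target_fractions))
        if cost < best_cost:
            best_assignment, best_cost = assignment, cost
    return best_assignment, best_cost


# Six initial clusters (100 molecules in total) merged into an 80-10-10 split.
cluster_sizes = [40, 25, 15, 10, 5, 5]
assignment, cost = assign_clusters(cluster_sizes, [0.8, 0.1, 0.1])
subset_sizes = [
    sum(s for s, a in zip(cluster_sizes, assignment) if a == k) for k in range(3)
]
print(assignment, subset_sizes, cost)
```

Exhaustive search grows exponentially with the number of clusters, which is why a linear programming formulation is the practical choice for real datasets.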


pip install git+https://git@github.com/sohviluukkonen/BalanceSplit.git@main


The split can be easily created from the command line with

python -m balancesplit.cli -i <dataset.csv>

where <dataset.csv> is a pivoted dataset in which each row corresponds to a unique molecule and each task has its own column. For more options, use -h/--help.
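If the data are stored in long format (one row per molecule-task measurement), they need to be pivoted first. A minimal sketch with pandas follows; the column names `task` and `value` and the output file name are assumptions made for illustration, only the `SMILES` column name comes from the examples in this README.

```python
import pandas as pd

# Long format: one row per (molecule, task) measurement.
long_df = pd.DataFrame(
    {
        "SMILES": ["CCO", "CCO", "c1ccccc1", "CCN"],
        "task": ["A", "B", "A", "B"],
        "value": [1.0, 0.0, 5.2, 1.0],
    }
)

# Wide (pivoted) format: one row per unique molecule, one column per task.
# Molecule-task pairs without a measurement become NaN.
wide_df = long_df.pivot(index="SMILES", columns="task", values="value").reset_index()
wide_df.to_csv("dataset.csv", index=False)
print(wide_df)
```

Note that `pivot` raises an error if a molecule has more than one value for the same task, so duplicate measurements have to be aggregated beforehand.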


sklearn-style splitter

BalanceSplit offers two main modes, either to

  1. create a single split with an arbitrary number of subsets of custom size, by specifying the argument sizes
  2. create a k-fold cross-validation split, by specifying the argument n_splits

By default, the molecules are initially clustered based on their Morgan fingerprints with MaxMinClustering. See clustering for alternative clustering approaches and details.
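To give an intuition for MaxMin-style clustering, the sketch below picks maximally diverse centroids and assigns every molecule to its nearest centroid. It is a pure-Python illustration on toy "fingerprints" (sets of on-bits), not the package's MaxMinClustering class; all function names are hypothetical, and whether the package assigns molecules to centroids in exactly this way is an assumption.

```python
def tanimoto_distance(a, b):
    """1 - |a & b| / |a | b| for two fingerprints given as sets of on-bits."""
    union = len(a | b)
    return 1.0 - (len(a & b) / union if union else 1.0)


def maxmin_centroids(fps, n_centroids, first=0):
    """Greedily pick centroids: each new pick maximises its minimum distance
    to the centroids already chosen."""
    n_centroids = min(n_centroids, len(fps))
    picked = [first]
    while len(picked) < n_centroids:
        candidates = [i for i in range(len(fps)) if i not in picked]
        nxt = max(
            candidates,
            key=lambda i: min(tanimoto_distance(fps[i], fps[j]) for j in picked),
        )
        picked.append(nxt)
    return picked


def assign_to_centroids(fps, centroids):
    """Label each fingerprint with the index of its nearest centroid."""
    return [min(centroids, key=lambda c: tanimoto_distance(fp, fps[c])) for fp in fps]


# Toy fingerprints: two pairs of similar molecules and one outlier.
fps = [{1, 2, 3}, {1, 2, 4}, {7, 8, 9}, {7, 8, 10}, {20, 21}]
centroids = maxmin_centroids(fps, n_centroids=3)
labels = assign_to_centroids(fps, centroids)
print(centroids, labels)
```

In practice the fingerprints would be Morgan fingerprints computed from the SMILES, and the number of initial clusters would be much larger than the number of final subsets.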

Also by default, each input task is stratified into subtasks that are used during the linear programming. For more details, see additional parameters.
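One way to picture the stratification of a continuous task is to split its values into bins and treat each bin as a separate subtask to be balanced. The sketch below assumes quantile-based bins and a hypothetical helper `stratify_continuous`; the binning scheme actually used by the package may differ.

```python
import math


def stratify_continuous(values, n_bins=5):
    """Split one continuous task into n_bins subtask columns using quantile
    bin edges. Missing values (None or NaN) stay missing in every subtask."""
    def is_missing(v):
        return v is None or math.isnan(v)

    subtasks = [[None] * len(values) for _ in range(n_bins)]
    observed = sorted(v for v in values if not is_missing(v))
    if not observed:
        return subtasks
    # Interior bin edges taken at the k/n_bins quantiles of the observed values.
    edges = [observed[int(len(observed) * k / n_bins)] for k in range(1, n_bins)]
    for row, v in enumerate(values):
        if is_missing(v):
            continue
        bin_idx = sum(v >= e for e in edges)  # number of edges at or below v
        subtasks[bin_idx][row] = v
    return subtasks


low, high = stratify_continuous([5.0, 1.0, None, 3.0, 7.0], n_bins=2)
print(low)   # values in the lower half
print(high)  # values in the upper half
```

Balancing each subtask separately pushes the final subsets towards similar value distributions for the original task.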

Example data:

import numpy as np
import pandas as pd

# Load data
data = pd.read_csv('balancesplit/test_data.csv')
smiles_list = data.SMILES.tolist()
y = data.drop(columns=['SMILES']).to_numpy()
X = np.zeros((y.shape[0], 1))  # dummy feature matrix: not used by BalanceSplit, but required by the sklearn-style API

Single Split

from balancesplit.splitter import BalanceSplit

# Single 80-10-10 train-validation-test split
split = BalanceSplit(sizes=[0.8, 0.1, 0.1]).split(X, y, smiles_list)
for (train_idx, val_idx, test_idx) in split:
    print(f"Train ({len(train_idx)}): {train_idx}")
    print(f"Validation ({len(val_idx)}): {val_idx}")
    print(f"Test ({len(test_idx)}): {test_idx}")

Train (697): [1, 2, 3, ..., 872, 873, 874]
Validation (89): [0, 44, 50, ..., 827, 828, 829]
Test (89): [10, 14, 15, ..., 830, 833, 834]
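To check how close a split is to the requested sizes for every task, the fraction of non-missing data points per task can be computed for each subset. The helper `subset_fractions` below is a hypothetical illustration on toy data and is not part of the package; it assumes that missing values in y are encoded as NaN and that every task has at least one data point.

```python
import numpy as np


def subset_fractions(y, subsets):
    """Return an array of shape (n_subsets, n_tasks) holding the fraction of
    each task's non-missing data points that falls into each subset."""
    observed = ~np.isnan(y)
    per_task_total = observed.sum(axis=0)
    return np.array([observed[idx].sum(axis=0) / per_task_total for idx in subsets])


# Toy sparse dataset: 5 molecules, 2 tasks, NaN marks a missing measurement.
y_toy = np.array(
    [[1.0, np.nan], [0.0, 1.0], [1.0, 0.0], [np.nan, 1.0], [1.0, 1.0]]
)
fractions = subset_fractions(y_toy, subsets=[[0, 1, 2], [3, 4]])
print(fractions)
```

With the real data, the same call would take the index arrays yielded by the splitter, for example `subset_fractions(y, [train_idx, val_idx, test_idx])`.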


K-fold Split

from balancesplit.splitter import BalanceSplit

# 5-fold train-test split
split = BalanceSplit(n_splits=5).split(X, y, smiles_list)
for i, (train_idx, test_idx) in enumerate(split):
    print(f"Fold {i}")
    print(f"Train ({len(train_idx)}): {train_idx}")
    print(f"Test ({len(test_idx)}): {test_idx}")

Fold 0
Train (699): [2, 4, 5, ..., 872, 873, 874]
Test (176): [0, 1, 3, ..., 833, 834, 843]
Fold 1
Train (700): [0, 1, 3, ..., 872, 873, 874]
Test (175): [2, 4, 5, ..., 822, 831, 841]
Fold 2
Train (699): [0, 1, 2, ..., 836, 841, 843]
Test (176): [14, 15, 20, ..., 872, 873, 874]
Fold 3
Train (700): [0, 1, 2, ..., 872, 873, 874]
Test (175): [16, 17, 25, ..., 777, 811, 832]
Fold 4
Train (702): [0, 1, 2, ..., 872, 873, 874]
Test (173): [13, 18, 19, ..., 810, 835, 836]
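The fold sizes above are consistent with every molecule appearing in exactly one test fold (176 + 175 + 176 + 175 + 173 = 875). A small sanity check for this property can be sketched as follows; `is_valid_kfold` is a hypothetical helper, not part of the package, shown here on toy folds.

```python
def is_valid_kfold(folds, n_samples):
    """Check that, in every fold, train and test are disjoint and together
    cover all samples, and that each sample is a test sample exactly once."""
    all_indices = set(range(n_samples))
    test_counts = [0] * n_samples
    for train_idx, test_idx in folds:
        train, test = set(train_idx), set(test_idx)
        if train & test or train | test != all_indices:
            return False
        for i in test:
            test_counts[i] += 1
    return all(count == 1 for count in test_counts)


good_folds = [([2, 3, 4, 5], [0, 1]), ([0, 1, 4, 5], [2, 3]), ([0, 1, 2, 3], [4, 5])]
leaky_folds = [([1, 2, 3, 4, 5], [0, 1]), ([0, 1, 4, 5], [2, 3]), ([0, 1, 2, 3], [4, 5])]
print(is_valid_kfold(good_folds, 6), is_valid_kfold(leaky_folds, 6))
```

Since the splitter returns an iterable, it may need to be materialised with `list(split)` before it is both checked and used for training.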

Splitting a dataset

BalanceSplit also provides a wrapper function, split_dataset, which creates splits for data in a .csv file or pd.DataFrame, computes some metrics (balance and chemical dissimilarity scores) and appends the generated splits to the data.

from import split_dataset

data_with_splits = split_dataset(data_path="balancesplit/test_data.csv", splitter=BalanceSplit(sizes=[0.8, 0.1, 0.1]))

A balancesplit.log file is also created, in which some split statistics are saved.

Additional BalanceSplit attributes

Multiple splits

  • n_repeats (int, default=1): multiple splits can be created at once (only works with some of the clustering methods)


Stratification

  • stratify (bool | list[str], default=True):
  • stratify_reg_nbins (int, default=5):

Objective weights

  • equal_weight_perc_compounds_as_tasks (bool, default=False):
  • custom_weights (list[float], default=None):

LP early stopping

  • absolute_gap (float, default=1e-3)
  • time_limit_seconds (int, default=None)


Analysing and visualising a split

