A PyTorch implementation of the Well-Tempered Adjoint Schrödinger Bridge Sampler (WT-ASBS), which enhances diffusion-based sampling of chemical systems by introducing a sequential bias along low-dimensional projections of atomic coordinates (collective variables).
📄 Preprint: [Enhancing Diffusion-Based Sampling with Molecular Collective Variables](https://arxiv.org/abs/2510.11923)
Note: This codebase started as a fork of the adjoint_samplers repository, with modifications to support the fairchem data format and the well-tempered bias mechanism.
We recommend using uv to install the dependencies. Clone the repository and run:
```bash
git clone https://github.com/facebookresearch/wt-asbs.git && cd wt-asbs
uv sync
```

Alternatively, you can install with pip:
```bash
pip install -e .
# Optional: install the development dependencies
pip install -e ".[dev]"
prek install
```

For the reactive sampling experiments, you need to apply for model access to the UMA model repository and log in to the Hugging Face CLI using an access token:
```bash
huggingface-cli login
```

MD simulation scripts to generate pretraining data and run metadynamics simulations are provided in the `md_data` directory.
Each subdirectory corresponds to:

- `ala2/`: Alanine dipeptide
- `ala4/`: Alanine tetrapeptide
- `sn2/`: Nucleophilic substitution
- `ptsb/`: Post-transition-state bifurcation
We also provide initial frames and pretraining data for each system in the `md_data` directory.
For metadynamics simulations, PLUMED and OpenMM-PLUMED must additionally be installed; please refer to their documentation for installation instructions.
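For reference, here is a minimal sketch of how a PLUMED-driven metadynamics bias can be attached to an existing OpenMM system via OpenMM-PLUMED. The PLUMED script below (dihedral atom indices and hill parameters) is illustrative only, not the exact settings used in `md_data`:

```python
from openmmplumed import PlumedForce

# Illustrative PLUMED input: well-tempered metadynamics on the phi/psi
# backbone dihedrals of alanine dipeptide (PLUMED atom indices are 1-based).
PLUMED_SCRIPT = """
phi: TORSION ATOMS=5,7,9,15
psi: TORSION ATOMS=7,9,15,17
METAD ARG=phi,psi SIGMA=0.35,0.35 HEIGHT=1.2 PACE=500 BIASFACTOR=10 TEMP=300
"""

def add_metadynamics_bias(system):
    """Attach the PLUMED bias as a force on an existing openmm.System."""
    system.addForce(PlumedForce(PLUMED_SCRIPT))
    return system
```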
Logging is handled by Weights & Biases, so you need to create an account and obtain an API key.
You can either set the `WANDB_API_KEY` environment variable or run `wandb login` to save the key.
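Either of the following makes the key available to the training scripts:

```bash
export WANDB_API_KEY=<your-api-key>
# or, interactively:
wandb login
```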
We provide an example script for alanine dipeptide training and sampling in `scripts/ala2.sh`.
We assume the script is run from the root directory of the repository.
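For example (assuming a POSIX shell):

```bash
bash scripts/ala2.sh
```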
Configuration files for all systems are provided in the `src/wt_asbs/config/experiment` directory; change the `experiment` input argument to run the training for a different system.
Currently, all configurations are set for 8 GPUs.
To use a different number of GPUs, override the `fabric.devices` input argument, as in the sketch below.
Also, note that `{train,inference}_batch_size`, `{adjoint,corrector}_buffer.max_size`, `initial_buffer_samples`, and `buffer_samples_per_epoch` are defined as per-GPU values, so you need to multiply them by the number of GPUs to get the total values.
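For instance, a minimal sketch of pretraining on 2 GPUs instead of 8 (the run name is illustrative; the per-GPU batch and buffer sizes stay at their configured values, so the totals shrink by 4x unless you scale them up accordingly):

```bash
python -u -m wt_asbs.experiment.pretrain_bm \
    root=ckpts \
    experiment=ala2 \
    name=ala2_pretrain_2gpu \
    fabric.devices=2
```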
Pretraining saves the checkpoints in `{root}/{name}/checkpoints/` and the intermediate samples in `{root}/{name}/samples/`.
```bash
python -u -m wt_asbs.experiment.pretrain_bm \
    root=ckpts \
    experiment=ala2 \
    name=ala2_pretrain
```

Similarly, WT-ASBS training saves the checkpoints in `{root}/{name}/checkpoints/` and the intermediate samples in `{root}/{name}/samples/`.
Provide the controller checkpoint from the pretraining step via the `pretrained_controller_checkpoint` input argument.
```bash
python -u -m wt_asbs.experiment.train_asbs \
    root=ckpts \
    experiment=ala2 \
    name=ala2_wt_asbs \
    pretrained_controller_checkpoint=ckpts/ala2_pretrain/checkpoints/epoch_1000.pt
```

For example, intermediate bias values (for on-the-fly PMF calculation) can be retrieved from the checkpoint as follows:
```python
import torch

ckpt = torch.load("ckpts/ala2_wt_asbs/checkpoints/epoch_1000.pt", weights_only=False)
bias_values = ckpt["metadynamics_bias"]["bias_grid.bias_values"]
meshgrid = ckpt["metadynamics_bias"]["bias_grid.meshgrid"]
```
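Continuing from the snippet above, a hedged sketch of turning the bias into a PMF estimate via the standard well-tempered relation F(s) ≈ -γ/(γ-1) · V(s), where the bias factor γ below is an assumed value (take it from your configuration):

```python
gamma = 10.0  # well-tempered bias factor (illustrative; use the configured value)
pmf = -(gamma / (gamma - 1.0)) * bias_values  # PMF up to an additive constant
pmf = pmf - pmf.min()  # shift so the global minimum sits at zero
```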
Sampling from the final checkpoint saves the samples to the `{root}/{name}/samples.pt` file, which contains a dictionary with the following keys:

- `pos`: (num_samples, num_atoms, 3) array of atomic positions, in units of Å
- `cv`: (num_samples, num_cvs) array of collective variables
- `bias`: (num_samples,) array of bias values for each sample for reweighting, in units of eV
```bash
python -u -m wt_asbs.experiment.inference \
    checkpoint_directory=ckpts/ala2_wt_asbs \
    num_samples=1000000
```
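As a rough sketch of how the stored `bias` values could be used for reweighting (the temperature and the sign convention are assumptions; check the paper for the exact estimator):

```python
import torch

data = torch.load("ckpts/ala2_wt_asbs/samples.pt")
kT = 8.617333e-5 * 300.0  # k_B in eV/K times an assumed temperature of 300 K

# Samples carry a bias V(s); weighting by exp(+V/kT) (assumed sign convention)
# recovers unbiased Boltzmann averages. Work in log space for stability.
log_w = data["bias"] / kT
log_w = log_w - log_w.max()
weights = torch.exp(log_w)
weights = weights / weights.sum()
```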
If you find this code useful in your research or projects, please consider citing our paper:

```bibtex
@misc{nam2025enhancing,
      title={Enhancing Diffusion-Based Sampling with Molecular Collective Variables},
      author={Juno Nam and Bálint Máté and Artur P. Toshev and Manasa Kaniselvan and Rafael Gómez-Bombarelli and Ricky T. Q. Chen and Brandon Wood and Guan-Horng Liu and Benjamin Kurt Miller},
      year={2025},
      eprint={2510.11923},
      archivePrefix={arXiv},
      primaryClass={physics.chem-ph},
      url={https://arxiv.org/abs/2510.11923},
}
```

This repository is licensed under the MIT License.
