DMIRLAB-Group/Dassl.pytorch

 
 

Dassl

Dassl is a research toolbox for domain adaptation and semi-supervised learning, written in PyTorch.

It is designed for the following tasks:

  • Single-source domain adaptation
  • Multi-source domain adaptation
  • Domain generalization
  • Semi-supervised learning

Overview

Dassl has implemented the following papers:

Dassl supports the following datasets.

Get started

Installation

Make sure conda is installed properly.

# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/

# Create a conda environment
conda create -n dassl python=3.7

# Activate the environment
conda activate dassl

# Install dependencies
pip install -r requirements.txt

# Install torch and torchvision (select a version that suits your machine)
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Install this library (no need to re-build if the source code is modified)
python setup.py develop

Follow the instructions in DATASETS.md to prepare the datasets.

Training

The main interface is implemented in tools/train.py, which basically does three things:

  1. Initialize the config with cfg = setup_cfg(args) where args contains the command-line input (see tools/train.py for the list of input arguments).
  2. Instantiate a trainer with build_trainer(cfg) which loads the dataset and builds a deep neural network model.
  3. Call trainer.train() for training and evaluating the model.
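The three steps above can be sketched with plain-Python stand-ins. This is only an illustration of the control flow: the registry, the config fields, and the bodies of setup_cfg and build_trainer below are hypothetical and are not Dassl's actual implementations, which live in tools/train.py and the dassl package.

```python
from types import SimpleNamespace

# Hypothetical name -> class registry (an assumption, not Dassl's real registry).
TRAINERS = {}


def register(name):
    """Decorator that records a trainer class under the given name."""
    def wrap(cls):
        TRAINERS[name] = cls
        return cls
    return wrap


@register("SourceOnly")
class ToySourceOnly:
    """Stand-in trainer; a real one would load data and build a model."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.trained = False

    def train(self):
        # A real trainer would run the training loop and evaluation here.
        self.trained = True


def setup_cfg(args):
    # Step 1: turn the command-line input into a config object.
    return SimpleNamespace(
        trainer=args.trainer,
        source_domains=list(args.source_domains),
        target_domains=list(args.target_domains),
        output_dir=args.output_dir,
    )


def build_trainer(cfg):
    # Step 2: look up the trainer class by name and instantiate it.
    return TRAINERS[cfg.trainer](cfg)


def main(args):
    cfg = setup_cfg(args)
    trainer = build_trainer(cfg)
    trainer.train()  # Step 3: train (and, in a real trainer, evaluate).
    return trainer


args = SimpleNamespace(
    trainer="SourceOnly",
    source_domains=["amazon"],
    target_domains=["webcam"],
    output_dir="output/source_only_office31",
)
trainer = main(args)
```

Selecting the trainer by name is what lets a single entry point serve every algorithm: only the value passed to --trainer changes.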

Below is an example of training a source-only baseline on Office-31, a popular domain adaptation dataset:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31

$DATA denotes the path to the dataset folder. --dataset-config-file loads the settings common to the dataset, such as image size and model architecture. --config-file loads the algorithm-specific settings, such as hyper-parameters and optimization parameters.

To use multiple sources (the multi-source domain adaptation task), add more domains to --source-domains. For instance, to train a source-only baseline on miniDomainNet, run:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
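As a rough illustration of how one flag can accept several domains, here is an argparse sketch using the flag names from the commands above. The types, defaults, and use of nargs="+" are assumptions made for this example; see tools/train.py for the actual argument definitions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="")
    parser.add_argument("--trainer", type=str, default="")
    # nargs="+" collects one or more space-separated values into a list.
    parser.add_argument("--source-domains", type=str, nargs="+", default=[])
    parser.add_argument("--target-domains", type=str, nargs="+", default=[])
    parser.add_argument("--dataset-config-file", type=str, default="")
    parser.add_argument("--config-file", type=str, default="")
    parser.add_argument("--output-dir", type=str, default="")
    return parser


argv = [
    "--root", "/path/to/data",
    "--trainer", "SourceOnly",
    "--source-domains", "clipart", "painting", "real",
    "--target-domains", "sketch",
    "--output-dir", "output/source_only_minidn",
]
args = build_parser().parse_args(argv)
print(args.source_domains)
```

With this kind of declaration, the single-source and multi-source cases differ only in how many values follow --source-domains.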

After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.

Write a new trainer

A good practice is to go through dassl/engine/trainer.py to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, subclass TrainerXU. For domain generalization, subclass TrainerX. TrainerXU and TrainerX differ mainly in whether they use a data loader for unlabeled data. With the base classes, a new trainer may only need to implement the forward_backward() method, which performs loss computation and model update. See dassl/engine/da/source_only.py for an example.
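The division of labor described above (the base class owns the loop, the subclass supplies forward_backward()) can be sketched without PyTorch. Everything below is a toy stand-in: ToyTrainerX, ToyTrainerXU, the one-parameter model, and the squared-error loss are assumptions made for illustration and do not reproduce the real classes in dassl/engine/trainer.py.

```python
class ToyTrainerX:
    """Hypothetical base trainer that iterates over labeled batches only."""

    def __init__(self, labeled_batches, lr=0.1):
        self.labeled_batches = labeled_batches
        self.lr = lr
        self.weight = 0.0  # single parameter of a toy model: y = weight * x

    def train(self, epochs=1):
        for _ in range(epochs):
            for batch_x in self.labeled_batches:
                self.forward_backward(batch_x)

    def forward_backward(self, batch_x):
        raise NotImplementedError


class ToyTrainerXU(ToyTrainerX):
    """Hypothetical base trainer that also feeds unlabeled batches."""

    def __init__(self, labeled_batches, unlabeled_batches, lr=0.1):
        super().__init__(labeled_batches, lr)
        self.unlabeled_batches = unlabeled_batches

    def train(self, epochs=1):
        for _ in range(epochs):
            pairs = zip(self.labeled_batches, self.unlabeled_batches)
            for batch_x, batch_u in pairs:
                self.forward_backward(batch_x, batch_u)


class ToySourceOnly(ToyTrainerXU):
    """Source-only baseline: learns from labeled data, ignores batch_u."""

    def forward_backward(self, batch_x, batch_u):
        # Gradient of the mean squared error with respect to the weight.
        grad = sum(2 * (self.weight * x - y) * x for x, y in batch_x)
        grad /= len(batch_x)
        # Model update: one plain gradient-descent step.
        self.weight -= self.lr * grad


labeled = [[(1.0, 2.0), (2.0, 4.0)]]  # (input, label) pairs with y = 2 * x
unlabeled = [[3.0, 4.0]]              # inputs without labels
trainer = ToySourceOnly(labeled, unlabeled)
trainer.train(epochs=50)
print(round(trainer.weight, 4))
```

The subclass contains only the loss and update logic. A domain adaptation method would differ from this baseline mainly in making use of batch_u inside forward_backward().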

Citation

Please cite the following paper if you find Dassl useful to your research.

@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={arXiv preprint arXiv:2003.07325},
  year={2020}
}
