This repository contains the source code for our paper:
RAFT: Recurrent All-Pairs Field Transforms for Optical Flow
Zachary Teed and Jia Deng
The code has been tested with PyTorch 1.6 and CUDA 10.1.
conda create --name raft
conda activate raft
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
conda install matplotlib
conda install tensorboard
conda install scipy
conda install opencv
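After creating the environment, it can help to confirm that the pinned PyTorch and torchvision versions are the ones actually installed. The following is a minimal sketch using the standard-library importlib.metadata module. It assumes the packages register their usual distribution names (torch, torchvision). The helper check_pins is hypothetical and not part of this repository.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions pinned in the conda command above (distribution names as registered).
PINS = {"torch": "1.6.0", "torchvision": "0.7.0"}


def check_pins(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str], bool]]:
    """Return {name: (expected, found, matches)}; found is None if not installed."""
    report = {}
    for name, expected in pins.items():
        try:
            found: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Drop a local build suffix such as "+cu101" before comparing.
        matches = found is not None and found.split("+")[0] == expected
        report[name] = (expected, found, matches)
    return report


if __name__ == "__main__":
    for name, (expected, found, ok) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {found} -> {'ok' if ok else 'MISMATCH'}")
```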
Pretrained models can be downloaded by running the provided download script, or from Google Drive.
You can demo a trained model on a sequence of frames:
python demo.py --model=models/raft-things.pth --path=demo-frames
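Running flow estimation on a sequence means feeding the model each pair of consecutive frames. The following is a minimal sketch of that pairing step. It assumes the frames are .png or .jpg files whose names sort in temporal order. The helper consecutive_pairs is hypothetical and is not necessarily how demo.py is organized internally.

```python
import glob
import os
from typing import List, Tuple


def consecutive_pairs(frame_dir: str) -> List[Tuple[str, str]]:
    """Pair each frame with the next one, in sorted filename order."""
    frames = glob.glob(os.path.join(frame_dir, "*.png")) + glob.glob(
        os.path.join(frame_dir, "*.jpg")
    )
    frames = sorted(frames)
    # N frames yield N - 1 (frame_t, frame_t+1) pairs; each pair gives one flow field.
    return list(zip(frames[:-1], frames[1:]))
```

Sorting by filename means frames should be zero-padded (frame_0001.png, not frame_1.png) so that lexical order matches temporal order.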
To evaluate/train RAFT, you will need to download the required datasets.
datasets.py will search for the datasets in the locations shown below. You can create symbolic links in the datasets folder that point to wherever the datasets were downloaded:
├── datasets
    ├── Sintel
        ├── test
        ├── training
    ├── KITTI
        ├── testing
        ├── training
        ├── devkit
    ├── FlyingChairs_release
        ├── data
    ├── FlyingThings3D
        ├── frames_cleanpass
        ├── frames_finalpass
        ├── optical_flow
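Before training or evaluating, you can check that the folders (or symbolic links) are in place. The following is a minimal sketch that mirrors the tree above. It assumes that the datasets folder sits in the repository root. The names EXPECTED and missing_dirs are hypothetical. Path.is_dir() follows symbolic links, so linked dataset folders count as present.

```python
from pathlib import Path
from typing import List

# Expected sub-directories, mirroring the tree above (relative to the repository root).
EXPECTED = [
    "datasets/Sintel/test",
    "datasets/Sintel/training",
    "datasets/KITTI/testing",
    "datasets/KITTI/training",
    "datasets/KITTI/devkit",
    "datasets/FlyingChairs_release/data",
    "datasets/FlyingThings3D/frames_cleanpass",
    "datasets/FlyingThings3D/frames_finalpass",
    "datasets/FlyingThings3D/optical_flow",
]


def missing_dirs(root: str = ".") -> List[str]:
    """Return the expected dataset directories that do not exist under root."""
    base = Path(root)
    return [rel for rel in EXPECTED if not (base / rel).is_dir()]


if __name__ == "__main__":
    missing = missing_dirs()
    print("all dataset folders found" if not missing else "missing:\n  " + "\n  ".join(missing))
```

You only need the datasets for the stages you actually run. For example, evaluating on Sintel does not require FlyingThings3D.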
You can evaluate a trained model using:
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
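Sintel results are conventionally reported as average end-point error (EPE): the mean Euclidean distance between predicted and ground-truth flow vectors. The following is a minimal pure-Python sketch of that metric for illustration only. It represents a flow field as a flat list of (u, v) pairs, which is an assumption made for simplicity. evaluate.py works on tensors and is not reproduced here.

```python
import math
from typing import Sequence, Tuple

Flow = Sequence[Tuple[float, float]]  # one (u, v) displacement per pixel


def average_epe(pred: Flow, gt: Flow) -> float:
    """Mean Euclidean distance between predicted and ground-truth flow vectors."""
    if len(pred) != len(gt) or not pred:
        raise ValueError("pred and gt must be non-empty and the same length")
    total = sum(math.hypot(pu - gu, pv - gv) for (pu, pv), (gu, gv) in zip(pred, gt))
    return total / len(pred)


# A 2-pixel example: one vector is off by (3, 4), the other is exact.
print(average_epe([(3.0, 4.0), (0.0, 0.0)], [(0.0, 0.0), (0.0, 0.0)]))  # 2.5
```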
We used the following training schedule in our paper (2 GPUs).
Training logs will be written to the runs directory and can be visualized using TensorBoard.
If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU).
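Mixed precision stores many values in float16, whose smallest positive value is about 6e-8. Very small gradients can therefore underflow to zero. PyTorch's automatic mixed precision (torch.cuda.amp, available since PyTorch 1.6) counters this with loss scaling through GradScaler. The following stdlib-only sketch illustrates the idea with Python's struct half-precision format 'e'. The gradient value and the scale factor 2**16 are illustrative assumptions, not values taken from this repository's training code.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE half precision and back (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8        # a small gradient value (illustrative)
scale = 65536.0    # 2**16, an illustrative loss-scale factor

naive = to_fp16(grad)                   # underflows to 0.0 in float16
scaled = to_fp16(grad * scale) / scale  # scale up, store in fp16, unscale in higher precision

print(naive, scaled)
```

In real training, GradScaler multiplies the loss rather than individual gradients and adjusts the scale dynamically when it detects overflow. The sketch only shows why scaling preserves small values.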
(Optional) Efficient Implementation
You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension:
cd alt_cuda_corr && python setup.py install && cd ..
Then enable it with the --alternate_corr flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
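To see why memory is the bottleneck, the following back-of-envelope sketch estimates what the all-pairs variant stores. It rests on four assumptions: features at 1/8 of the input resolution (input padded up to a multiple of 8), a 4-level pyramid that average-pools the second image's dimensions by 2 per level, float32 values, and one image pair. The function corr_volume_bytes is hypothetical and only estimates the all-pairs side.

```python
import math


def corr_volume_bytes(height: int, width: int, levels: int = 4, bytes_per_value: int = 4) -> int:
    """Estimate the memory of an all-pairs correlation pyramid for one image pair."""
    h, w = math.ceil(height / 8), math.ceil(width / 8)
    n = h * w  # number of feature-map pixels in the first image
    total = 0
    for level in range(levels):
        # The second image's spatial dims shrink by 2 at each level (floor, like 2x2 pooling).
        total += n * (h // 2**level) * (w // 2**level)
    return total * bytes_per_value


# A 1024x436 Sintel frame gives a 128x55 feature map, i.e. 7040 pixels.
print(corr_volume_bytes(436, 1024, levels=1))  # 198246400 (7040 * 7040 * 4)
print(corr_volume_bytes(436, 1024))            # 261324800, about 261 MB per pair
```

The first level grows with the square of the number of feature-map pixels, so doubling both image dimensions multiplies it by 16. This is why avoiding the stored volume matters most at high resolution and with larger batches.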