This is an official PyTorch implementation of our paper [Transformer for Non-rigid Tracking and Reconstruction]. In this repository, we provide PyTorch code for training and testing our proposed TR4TR model.
First, clone the repository locally:
git clone https://github.com/xfliu1998/tr4tr-main.git
After cloning this repo, cd into it and create a conda environment for the project:
cd resources
conda env create --file env.yaml
cd ..
Then, activate the environment:
conda activate tr4tr
We train and evaluate our network on the DeepDeform dataset; the original data can be obtained from the DeepDeform repository. After downloading the data, change the dataset path in config.yaml and in utils/data_utils.py to the path where you downloaded the data.
Then generate the JSON files used for training and evaluation with the following commands:
cd utils
sh data_utils.sh
cd ..
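The name and location of the generated JSON files depend on data_utils.sh; as a quick sanity check, here is a minimal sketch (not part of the repository) for inspecting one of them. The path below is a placeholder, and the per-pair fields are assumed to match the val_.json example shown further below:

# Minimal sketch: inspect a generated split file.
# Replace the path with the file actually produced by data_utils.sh.
import json

with open('train.json') as f:   # placeholder path
    pairs = json.load(f)

print(len(pairs), 'frame pairs')
print(sorted(pairs[0].keys()))  # e.g. source_color, source_depth, target_color, ...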
You can set customized model parameters by modifying config.yaml, including the input format of the data, the network architecture parameters, and the training hyperparameters.
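Since config.yaml is plain YAML, you can quickly confirm that your edits still parse before launching a run; a minimal sketch, assuming PyYAML is available in the tr4tr environment:

# Minimal sketch: parse config.yaml to make sure the edited file is valid YAML.
import yaml

with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

for key, value in cfg.items():
    print(key, '=', value)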
Then you need to modify the GPU parameters in main.sh.
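Which GPU parameters make sense depends on your machine; a minimal sketch for checking what PyTorch can see before editing main.sh:

# Minimal sketch: list the GPUs visible to PyTorch.
import torch

print('CUDA available:', torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))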
You can train and evaluate the model with the following command:
sh main.sh
If you want to visualize the results, specify the path of the pre-trained model and your own file paths in config.yaml, then write the following JSON file val_.json and place it in the same directory as the data you want to visualize (a sketch for writing this file programmatically follows the example):
[
    {
        "source_color": "val/color/shirt_000000.jpg",
        "source_depth": "val/depth/shirt_000000.png",
        "target_color": "val/color/shirt_000100.jpg",
        "target_depth": "val/depth/shirt_000100.png",
        "object_id": "shirt",
        "source_id": "000000",
        "target_id": "000100",
        "optical_flow": "val/optical_flow/shirt_000000_000100.oflow",
        "scene_flow": "val/scene_flow/shirt_000000_000100.sflow"
    }
]
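If you want to visualize a different sequence or frame pair, the same file can also be written programmatically; a minimal sketch that only reuses the example values above (replace the object and frame ids with your own):

# Minimal sketch: write val_.json for a chosen frame pair.
# The ids below are the example values shown above; replace them with your own.
import json

obj, src, tgt = 'shirt', '000000', '000100'
pair = {
    'source_color': f'val/color/{obj}_{src}.jpg',
    'source_depth': f'val/depth/{obj}_{src}.png',
    'target_color': f'val/color/{obj}_{tgt}.jpg',
    'target_depth': f'val/depth/{obj}_{tgt}.png',
    'object_id': obj,
    'source_id': src,
    'target_id': tgt,
    'optical_flow': f'val/optical_flow/{obj}_{src}_{tgt}.oflow',
    'scene_flow': f'val/scene_flow/{obj}_{src}_{tgt}.sflow',
}
with open('val_.json', 'w') as f:
    json.dump([pair], f, indent=4)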
Then you need to modify the parameter experiment_mode='predict' in main.sh and run the command:
sh main.sh
You can refer to utils/visual_utils.py for the related visualization instructions.
TR4TR is released under the Apache 2.0 license. Please see the LICENSE file for more information.