release zero123++ fine-tuning code

bluestyle97 committed May 7, 2024
1 parent 55c8561 commit 34c193c
Showing 4 changed files with 454 additions and 3 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -17,12 +17,13 @@ This repo is the official implementation of InstantMesh, a feed-forward framework

https://github.com/TencentARC/InstantMesh/assets/20635237/dab3511e-e7c6-4c0b-bab7-15772045c47d

# 🚩 Todo List

# 🚩 Features and Todo List
- [x] 🔥🔥 Release Zero123++ fine-tuning code.
- [x] 🔥🔥 Support for running gradio demo on two GPUs to save memory.
- [x] 🔥🔥 Support for running demo with docker. Please refer to the [docker](docker/) directory.
- [x] Release inference and training code.
- [x] Release model weights.
- [x] Release Hugging Face Gradio demo. Please try it at the [demo](https://huggingface.co/spaces/TencentARC/InstantMesh) link.
- [x] Add support for running gradio demo on two GPUs to save memory.
- [ ] Add support for more multi-view diffusion models.

# ⚙️ Dependencies and Installation
@@ -76,6 +77,8 @@ If you have multiple GPUs in your machine, the demo app will run on two GPUs automatically
CUDA_VISIBLE_DEVICES=0 python app.py
```

Alternatively, you can run the demo with Docker. Please follow the instructions in the [docker](docker/) directory.

## Running with command line

To generate 3D meshes from images via command line, simply run:
@@ -112,6 +115,11 @@ python train.py --base configs/instant-nerf-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
```

We also provide our Zero123++ fine-tuning code, as it has been frequently requested. Launch it with:
```bash
python train.py --base configs/zero123plus-finetune.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
```
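
Judging from the dataset code added in this commit (`src/data/objaverse_zero123plus.py`) and the paths in `configs/zero123plus-finetune.yaml`, the renderings are expected to follow the layout below, where `000.png` is the condition view and `001.png`–`006.png` are the six target views (all RGBA); the object directory names are placeholders:

```
data/objaverse/
├── lvis-annotations.json          # maps category names to lists of object paths
└── rendering_zero123plus/
    └── <object_uid>/
        ├── 000.png                # condition view
        ├── 001.png
        ├── ...
        └── 006.png                # target views 001-006
```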

# :books: Citation

If you find our work useful for your research or applications, please cite using this BibTeX:
47 changes: 47 additions & 0 deletions configs/zero123plus-finetune.yaml
@@ -0,0 +1,47 @@
model:
  base_learning_rate: 1.0e-05
  target: zero123plus.model.MVDiffusion
  params:
    drop_cond_prob: 0.1

    stable_diffusion_config:
      pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2
      custom_pipeline: ./zero123plus

data:
  target: src.data.objaverse_zero123plus.DataModuleFromConfig
  params:
    batch_size: 6
    num_workers: 8
    train:
      target: src.data.objaverse_zero123plus.ObjaverseData
      params:
        root_dir: data/objaverse
        meta_fname: lvis-annotations.json
        image_dir: rendering_zero123plus
        validation: false
    validation:
      target: src.data.objaverse_zero123plus.ObjaverseData
      params:
        root_dir: data/objaverse
        meta_fname: lvis-annotations.json
        image_dir: rendering_zero123plus
        validation: true


lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1000
      save_top_k: -1
      save_last: true
  callbacks: {}

  trainer:
    benchmark: true
    max_epochs: -1
    gradient_clip_val: 1.0
    val_check_interval: 1000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    check_val_every_n_epoch: null  # if this is not set, validation does not run
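
Each `target`/`params` pair in this config names a class and the keyword arguments used to construct it; `src.utils.train_util.instantiate_from_config` turns such a block into a live object. A minimal sketch of how this convention is commonly implemented (it follows the latent-diffusion codebase convention; the repo's actual helper may differ in detail):

```python
import importlib


def get_obj_from_str(string):
    # 'src.data.objaverse_zero123plus.ObjaverseData' -> the imported class
    module, cls = string.rsplit('.', 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config):
    # Build the object named by `target`, passing `params` as keyword arguments.
    if 'target' not in config:
        raise KeyError('Expected key `target` to instantiate.')
    return get_obj_from_str(config['target'])(**config.get('params', dict()))
```

Under this convention, the `data` block above instantiates `DataModuleFromConfig` with `batch_size=6`, `num_workers=8`, and the nested `train`/`validation` blocks as dataset configs.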
124 changes: 124 additions & 0 deletions src/data/objaverse_zero123plus.py
@@ -0,0 +1,124 @@
import os
import json
import numpy as np
import webdataset as wds
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from PIL import Image
from pathlib import Path

from src.utils.train_util import instantiate_from_config


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        # Keep only the dataset configs that were actually provided.
        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs['train'] = train
        if validation is not None:
            self.dataset_configs['validation'] = validation
        if test is not None:
            self.dataset_configs['test'] = test

    def setup(self, stage):
        if stage in ['fit']:
            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        # DistributedSampler shards the dataset across GPUs/ranks.
        sampler = DistributedSampler(self.datasets['train'])
        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        sampler = DistributedSampler(self.datasets['validation'])
        return wds.WebLoader(self.datasets['validation'], batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def test_dataloader(self):
        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)


class ObjaverseData(Dataset):
    def __init__(self,
        root_dir='objaverse/',
        meta_fname='valid_paths.json',
        image_dir='rendering_zero123plus',
        validation=False,
    ):
        self.root_dir = Path(root_dir)
        self.image_dir = image_dir

        # The metadata file maps category names to lists of object paths.
        with open(os.path.join(root_dir, meta_fname)) as f:
            lvis_dict = json.load(f)
        paths = []
        for k in lvis_dict.keys():
            paths.extend(lvis_dict[k])
        self.paths = paths

        total_objects = len(self.paths)
        if validation:
            self.paths = self.paths[-16:]  # use the last 16 objects for validation
        else:
            self.paths = self.paths[:-16]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        # Load an RGBA rendering and composite it onto the given background color.
        pil_img = Image.open(path)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        # Retry with a random index if an object fails to load.
        while True:
            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])

            '''background color, default: white'''
            bkg_color = [1., 1., 1.]

            img_list = []
            try:
                # 000.png is the condition view; 001.png-006.png are the six target views.
                for idx in range(7):
                    img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)
                    img_list.append(img)
            except Exception as e:
                print(e)
                index = np.random.randint(0, len(self.paths))
                continue

            break

        imgs = torch.stack(img_list, dim=0).float()

        data = {
            'cond_imgs': imgs[0],       # (3, H, W)
            'target_imgs': imgs[1:],    # (6, 3, H, W)
        }
        return data
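
As a quick sanity check, the dataset can be exercised on its own before launching a full fine-tuning run; a minimal sketch, assuming renderings already exist under the paths used in `configs/zero123plus-finetune.yaml`:

```python
from src.data.objaverse_zero123plus import ObjaverseData

# Paths mirror configs/zero123plus-finetune.yaml; adjust to your local setup.
dataset = ObjaverseData(
    root_dir='data/objaverse',
    meta_fname='lvis-annotations.json',
    image_dir='rendering_zero123plus',
    validation=False,
)

sample = dataset[0]
print(sample['cond_imgs'].shape)    # (3, H, W): the single condition view
print(sample['target_imgs'].shape)  # (6, 3, H, W): the six target views
```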