-
Notifications
You must be signed in to change notification settings - Fork 86
/
sample_unidiffuser_v0.py
54 lines (44 loc) · 1.22 KB
/
sample_unidiffuser_v0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import ml_collections
def d(**kwargs):
    """Wrap keyword arguments in a new ``ml_collections.ConfigDict``.

    Shorthand for building nested config sections, e.g.
    ``d(pretrained_path='models/autoencoder_kl.pth')``.

    Returns:
        A ``ConfigDict`` whose fields are exactly the given keyword arguments.
    """
    section = ml_collections.ConfigDict(initial_dictionary=kwargs)
    return section
def get_config():
    """Return the UniDiffuser v0 sampling configuration.

    The result is an ``ml_collections.ConfigDict`` with top-level settings
    (``seed``, ``pred``, ``z_shape``, ``clip_img_dim``, ``clip_text_dim``,
    ``text_dim``) and four sub-configs: ``autoencoder``, ``caption_decoder``,
    ``nnet`` and ``sample``.

    ``caption_decoder.hidden_dim`` and ``nnet.text_dim`` are obtained with
    ``get_ref('text_dim')``, and ``nnet.clip_img_dim`` with
    ``get_ref('clip_img_dim')``. They are therefore FieldReferences tied to
    the top-level values rather than copied numbers.
    """
    cfg = ml_collections.ConfigDict()

    # -- Top-level settings ------------------------------------------------
    cfg.seed = 1234
    cfg.pred = 'noise_pred'
    # NOTE(review): presumably the autoencoder latent shape (C, H, W) -- confirm.
    cfg.z_shape = (4, 64, 64)
    cfg.clip_img_dim = 512
    cfg.clip_text_dim = 768
    cfg.text_dim = 64  # reduce dimension

    # -- Pretrained components ---------------------------------------------
    cfg.autoencoder = d(pretrained_path='models/autoencoder_kl.pth')
    cfg.caption_decoder = d(
        pretrained_path="models/caption_decoder.pth",
        hidden_dim=cfg.get_ref('text_dim'),
    )

    # -- Backbone network --------------------------------------------------
    # Keyword order is kept identical to the original so the resulting
    # ConfigDict has the same field order.
    network_kwargs = dict(
        name='uvit_multi_post_ln',
        img_size=64,
        in_chans=4,
        patch_size=2,
        embed_dim=1536,
        depth=30,
        num_heads=24,
        mlp_ratio=4,
        qkv_bias=False,
        pos_drop_rate=0.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        mlp_time_embed=False,
        text_dim=cfg.get_ref('text_dim'),
        num_text_tokens=77,
        clip_img_dim=cfg.get_ref('clip_img_dim'),
        use_checkpoint=True,
    )
    cfg.nnet = d(**network_kwargs)

    # -- Sampler settings --------------------------------------------------
    cfg.sample = d(
        sample_steps=50,
        scale=7.0,
        t2i_cfg_mode='true_uncond',
    )

    return cfg