-
Notifications
You must be signed in to change notification settings - Fork 2
/
th_hand_prior.py
72 lines (61 loc) · 2.54 KB
/
th_hand_prior.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
if code works:
Author: Xianghui Xie
else:
Author: Anonymous
Cite: CHORE: Contact, Human and Object REconstruction from a single RGB image. ECCV'2022
"""
import numpy as np
import torch
import pickle as pkl
from os.path import join
import yaml, sys
# Load project paths from PATHS.yml at import time.
# NOTE(review): assumes PATHS.yml exists in the current working directory —
# importing this module from elsewhere will raise FileNotFoundError.
with open("PATHS.yml", 'r') as stream:
    paths = yaml.safe_load(stream)
# Make the project code importable from anywhere.
sys.path.append(paths['CODE'])
# Directory holding SMPL asset files (including the GRAB hand priors).
SMPL_ASSETS_ROOT = paths["SMPL_ASSETS_ROOT"]
def grab_prior(root_path):
    """Return the GRAB hand-pose prior for both hands.

    Args:
        root_path: directory containing the 'priors' subfolder.

    Returns:
        Tuple of (mean, lhand_precision, rhand_precision), where mean is the
        left- and right-hand means concatenated along axis 0.
    """
    left, right = load_grab_prior(root_path)
    mean_pose = np.concatenate([left['mean'], right['mean']], axis=0)
    return mean_pose, left['precision'], right['precision']
def load_grab_prior(root_path):
    """Load the pickled left/right hand prior dicts from the GRAB dataset.

    Args:
        root_path: directory containing a 'priors' subfolder with
            'lh_prior.pkl' and 'rh_prior.pkl'.

    Returns:
        Tuple (lhand_data, rhand_data) of the unpickled objects; each is
        expected to carry at least 'mean' and 'precision' entries (see callers).

    Raises:
        FileNotFoundError: if either pickle file is missing.
    """
    lhand_path = join(root_path, 'priors', 'lh_prior.pkl')
    rhand_path = join(root_path, 'priors', 'rh_prior.pkl')
    # Use context managers so the file handles are closed deterministically
    # (the original `pkl.load(open(...))` leaked the descriptors).
    with open(lhand_path, 'rb') as f:
        lhand_data = pkl.load(f)
    with open(rhand_path, 'rb') as f:
        rhand_data = pkl.load(f)
    return lhand_data, rhand_data
def mean_hand_pose(root_path):
    """Mean hand pose computed from the GRAB dataset (left hand, then right).

    Args:
        root_path: directory containing the 'priors' subfolder.

    Returns:
        1-D numpy array: concatenation of the left- and right-hand mean poses.
    """
    left, right = load_grab_prior(root_path)
    return np.concatenate([np.array(left['mean']), np.array(right['mean'])])
class HandPrior:
    """Gaussian hand-pose prior built from the GRAB dataset.

    Penalizes deviation of the hand-pose part of a full SMPL-H pose vector
    from the GRAB mean, whitened by the per-hand precision matrices.
    """

    # Pose parameters per hand (axis-angle values), so both hands span 90 dims.
    HAND_POSE_NUM = 45

    def __init__(self, prior_path=None,
                 prefix=66,
                 device='cuda:0',
                 dtype=torch.float,
                 type='grab'):
        """
        Args:
            prior_path: directory containing the 'priors' subfolder. Defaults
                to SMPL_ASSETS_ROOT; resolved lazily at call time so the
                module-level constant is only needed when the default is used.
            prefix: index where hand pose starts in the full pose vector
                (66 for SMPL-H).
            device: torch device for the prior tensors.
            dtype: torch dtype for the prior tensors.
            type: prior flavour; only 'grab' is supported.

        Raises:
            NotImplementedError: if `type` is not 'grab'.
        """
        self.prefix = prefix
        if type == 'grab':
            if prior_path is None:
                prior_path = SMPL_ASSETS_ROOT
            prior, lhand_prec, rhand_prec = grab_prior(prior_path)
            # Mean stored with a leading batch dim so it broadcasts over poses.
            self.mean = torch.tensor(prior, dtype=dtype).unsqueeze(0).to(device)
            self.lhand_prec = torch.tensor(lhand_prec, dtype=dtype).unsqueeze(0).to(device)
            self.rhand_prec = torch.tensor(rhand_prec, dtype=dtype).unsqueeze(0).to(device)
        else:
            # Bug fix: was `raise NotImplemented(...)`, which raises TypeError
            # because NotImplemented is not an exception class.
            raise NotImplementedError("Only grab hand prior is supported!")

    def __call__(self, full_pose):
        """Compute the hand-prior loss for a batch of full pose vectors.

        Args:
            full_pose: (batch, D) tensor; hand pose occupies
                full_pose[:, prefix:prefix+90] (left hand then right hand).

        Returns:
            Per-batch loss tensor: sum of squared (precision-whitened)
            deviations from the GRAB mean.
        """
        temp = full_pose[:, self.prefix:] - self.mean
        if self.lhand_prec is None:
            # No precision available: plain squared distance to the mean.
            return (temp * temp).sum(dim=1)
        lhand = torch.matmul(temp[:, :self.HAND_POSE_NUM], self.lhand_prec)
        rhand = torch.matmul(temp[:, self.HAND_POSE_NUM:], self.rhand_prec)
        whitened = torch.cat([lhand, rhand], dim=1)
        return (whitened * whitened).sum(dim=1)