-
Notifications
You must be signed in to change notification settings - Fork 0
/
e_1_binary_classif.py
118 lines (93 loc) · 2.91 KB
/
e_1_binary_classif.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from ltn_pytorch import Theory
# Theory object for this example. Below, the domain, variables, predicate A and
# axioms are registered on K. Its parameters are handed to the optimizer, and
# calling K(x_pos=..., x_neg=...) in the training loop returns a (sat, loss) pair.
K = Theory()
class A_gr(torch.nn.Module):
    """Small MLP used as the grounding of predicate ``A`` over 2-D points.

    Architecture: Linear(2, 8) -> ReLU -> Linear(8, 8) -> ReLU ->
    Linear(8, 1, no bias) -> sigmoid, so every output lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # initialisation consumes the global RNG identically.
        self.l1 = torch.nn.Linear(2, 8, bias=True)
        self.a1 = torch.relu
        self.l2 = torch.nn.Linear(8, 8, bias=True)
        self.a2 = torch.relu
        self.l3 = torch.nn.Linear(8, 1, bias=False)
        self.a3 = torch.sigmoid

    def forward(self, x):
        """Map inputs of shape (..., 2) to truth degrees of shape (..., 1)."""
        hidden = self.a1(self.l1(x))
        hidden = self.a2(self.l2(hidden))
        return self.a3(self.l3(hidden))
# NOTE(review): Variable, Predicate, ALL, NOT and Call are used below but are
# never imported in this file. Presumably they come from ltn_pytorch; add the
# matching import once their module is confirmed.

# domains: presumably dtype 'fp32' and per-element shape (2,) -- confirm
# against Theory.domain's signature.
K.domain('point', 'fp32', (2,))

# vars: one variable per class, both over the 'point' domain.
K.variable(Variable('x_pos', 'point'))
K.variable(Variable('x_neg', 'point'))

# constants (none used here): K.constant(name, domain, grounding)
# functions (none used here): K.function(name, domain_in, domain_out, grounding)

# predicates: A over 'point', with an A_gr() instance passed as its grounding.
# NOTE(review): the meaning of the trailing True flag is not visible here.
K.predicate(Predicate('A', ['point']), A_gr(), True)

# axioms
# Call receives the predicate name and its argument-variable list as separate
# arguments (assumed, mirroring Predicate('A', ['point'])). Writing
# 'A'['x_pos'] would index the string 'A' with a string and raise TypeError.
K.axiom(ALL('x_pos', Call('A', ['x_pos'])))  # FORALL(x_pos): A(x_pos)
K.axiom(ALL('x_neg', NOT(Call('A', ['x_neg']))))  # FORALL(x_neg): !A(x_neg)

# optimizer: Adam with default hyper-parameters over K.parameters().
opt = torch.optim.Adam(K.parameters())
# dataset
## generate data
import random

N_pts_total = 1100


def _inside_disc(xy, center=(0.5, 0.5), radius=0.2):
    """Return True if point *xy* lies strictly inside the given disc.

    The defaults (center (0.5, 0.5), radius 0.2) define the positive class
    for this example.
    """
    return ((xy[0] - center[0]) ** 2 + (xy[1] - center[1]) ** 2) ** 0.5 < radius


def _random_point():
    """Draw one point uniformly from the unit square (x first, then y)."""
    return (random.uniform(0., 1.), random.uniform(0., 1.))


# Each sample is a ((x, y), label) pair; label marks membership in the disc.
_points = [_random_point() for _ in range(N_pts_total)]
generated = [(xy, _inside_disc(xy)) for xy in _points]

## split
# test: first 50 samples, train: next 50, validation: remaining 1000.
# NOTE(review): training uses only 50 samples, and the 1000-sample
# `validation` split is never used elsewhere in this script. Confirm the
# split sizes are intended.
test_data = generated[:50]
train_data = generated[50:100]
validation = generated[100:]
## dataset
from torch.utils.data import Dataset


class PointsDataset(Dataset):
    """Map-style dataset over a list of ``(point, label)`` pairs.

    Item ``i`` is returned as ``{'x': point, 'label': label}``.
    """

    def __init__(self, lst):
        # Keeps a reference to the caller's list; no copy is made.
        self.items = lst

    def __getitem__(self, item):
        it = self.items[item]
        return {'x': it[0], 'label': it[1]}

    def __len__(self):
        return len(self.items)
# Wrap the train/test splits as map-style datasets.
train_dataset, test_dataset = PointsDataset(train_data), PointsDataset(test_data)

## dataloader
from torch.utils.data import DataLoader

batchsz = 32
# Training order is shuffled; the test loader keeps the original order.
train_loader = DataLoader(train_dataset, batch_size=batchsz, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batchsz, shuffle=False)
## objective function and loss function
# Each step splits the batch into positive and negative points, evaluates the
# theory on them and takes one optimizer step on the returned loss.
for ep in range(2000):
    for batch in train_loader:
        # NOTE(review): assumes batch['x'] is a pair of 1-D tensors (one per
        # coordinate) and batch['label'] a 1-D tensor of labels, as produced by
        # DataLoader's default collate for ((x, y), bool) samples -- confirm.
        points = torch.stack([batch['x'][0], batch['x'][1]], 1).to(torch.float32)
        is_pos = torch.as_tensor(batch['label']).to(torch.bool)
        x_pos = points[is_pos]
        x_neg = points[~is_pos]
        # Both quantified variables need at least one point to be grounded on.
        if len(x_pos) == 0 or len(x_neg) == 0:
            print('empty var case')
            continue
        opt.zero_grad()
        sat, loss = K(x_pos=x_pos, x_neg=x_neg)
        print('ep', ep, 'sat:', sat.detach().cpu().numpy(),
              '\tloss:', loss.detach().cpu().numpy())
        loss.backward()
        opt.step()
# querying
# Only direct predicate access is used here. Other query styles were sketched
# as ideas and are not known to exist in ltn_pytorch:
#   K.evaluate_formula('A(x)', x=[[]]);  K.A(x=[]);
#   K('A(x)', x=[[]]);  K('A', x=[[]])
# This is inference only, so autograd is disabled and no graph is built.
with torch.no_grad():
    y = K.predicates['A'](x=torch.as_tensor([[0.5, 0.5]]))
print(y)