Skip to content

Commit

Permalink
Use sampler specific sample_fast functions in sampler.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Nzteb committed Feb 14, 2020
1 parent 89fa647 commit 088dc63
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 41 deletions.
6 changes: 3 additions & 3 deletions examples/toy-complex-train-sampling.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
job.type: train
dataset.name: fb15k
dataset.name: toy
model: complex

train.num_workers: 4
train:
type: negative_sampling
optimizer: Adagrad
Expand All @@ -22,7 +22,7 @@ negative_sampling:
s: True
o: True
p: True
implementation: python
implementation: fast



2 changes: 1 addition & 1 deletion kge/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ negative_sampling:
s: False # filter s samples until all are true negatives
p: False # see above
o: False # see above
implementation: numba # numba; python
implementation: fast # standard; fast

# Whether to share the s/p/o corruptions for all triples in the batch. This
# can make training more efficient. Cannot be used with together with
Expand Down
4 changes: 2 additions & 2 deletions kge/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def _group_by(keys, values) -> dict:
""" Groups values by keys.
"""Group values by keys.
:param keys: list of keys
:param values: list of values
Expand Down Expand Up @@ -240,7 +240,7 @@ def create_default_index_functions(dataset: "Dataset"):

@njit
def index_where_in(x, y, t_f=True):
"""Retrieves the indices of the elements in x which are also in y.
"""Retrieve the indices of the elements in x which are also in y.
x and y are assumed to be 1 dimensional arrays.
Expand Down
84 changes: 50 additions & 34 deletions kge/util/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
self.vocabulary_size[slot] = (
dataset.num_relations() if slot == P else dataset.num_entities()
)
self.check_option("filtering.implementation", ["standard", "fast"])
self.filter_fast = True
if self.get_option("filtering.implementation") == "standard":
self.filter_fast = False
self.dataset = dataset
self.shared = self.get_option("shared")

self.times = []

# auto config
for slot, copy_from in [(S, O), (P, None), (O, S)]:
if self.num_samples[slot] < 0:
Expand Down Expand Up @@ -69,7 +70,6 @@ def sample(
entities (`slot`=0 or `slot`=2) or relations (`slot`=1).
"""
start = time.time()
if num_samples is None:
num_samples = self.num_samples[slot]
if self.shared:
Expand All @@ -88,7 +88,14 @@ def sample(
negative_samples = self._filter(negative_samples, slot, positive_triples)
elif self.get_option("filtering.implementation") == "numba":
negative_samples = self._filter_fast(negative_samples, slot, positive_triples)
print(f"sample took {time.time() -start}")
if self.filter_fast:
negative_samples = self._filter_fast(
negative_samples, slot, positive_triples
)
else:
negative_samples = self._filter(
negative_samples, slot, positive_triples
)
return negative_samples

def _sample(self, positive_triples: torch.Tensor, slot: int, num_samples: int):
Expand Down Expand Up @@ -145,55 +152,64 @@ def _filter(
negative_samples[i, resample_idx] = new
return negative_samples

def _filter_fast(self, result: torch.Tensor, slot: int, spo: torch.Tensor):
""" Use numba for filtering.
def _filter_fast(
    self, negative_samples: torch.Tensor, slot: int, positive_triples: torch.Tensor
):
    """Sampler-specific fast filtering of false negatives.

    Subclasses that support ``filtering.implementation=fast`` override this
    with an optimized (e.g. numba-based) routine that replaces, in
    `negative_samples`, every sample that actually occurs as a positive for
    the corresponding triple in `positive_triples` at position `slot`.

    :param negative_samples: batch_size x num_samples tensor of sampled ids
    :param slot: which triple position was corrupted (S, P, or O)
    :param positive_triples: batch_size x 3 tensor of positive triples
    :raises NotImplementedError: this base class has no fast implementation;
        use ``filtering.implementation=standard`` instead.
    """
    raise NotImplementedError(
        "Use filtering.implementation=standard for this sampler."
    )


A sampler has to implement a @njit _filter_numba function and registers it
in self.get_numba_sampler().
class KgeUniformSampler(KgeSampler):
def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
    """Initialize a uniform negative sampler; all setup is done by the base class."""
    super().__init__(config, configuration_key, dataset)

"""
spo_char = "spo"
pair = spo_char.replace(spo_char[slot], "")
def _sample(self, positive_triples: torch.Tensor, slot: int, num_samples: int):
    """Draw `num_samples` ids per triple uniformly from the slot's vocabulary."""
    batch_size = positive_triples.size(0)
    upper = self.vocabulary_size[slot]
    return torch.randint(upper, (batch_size, num_samples))

def _filter_fast(
    self, negative_samples: torch.Tensor, slot: int, positive_triples: torch.Tensor
):
    """Replace false negatives using a numba-accelerated filter.

    For each positive triple, looks up all true ids for the non-corrupted
    pair in the training index and resamples, in place, any negative sample
    that collides with a true positive.

    :param negative_samples: batch_size x num_samples tensor of sampled ids
    :param slot: corrupted position (S, P, or O)
    :param positive_triples: batch_size x 3 tensor of positive triples
    :return: filtered batch_size x num_samples LongTensor of true negatives
    """
    pair = ["po", "so", "sp"][slot]
    # index holding the positive ids for the respective (non-corrupted) pair
    index = self.dataset.index(f"train_{pair}_to_{SLOT_STR[slot]}")
    cols = [[P, O], [S, O], [S, P]][slot]
    pairs = positive_triples[:, cols].numpy()
    batch_size = positive_triples.size(0)
    voc_size = self.vocabulary_size[slot]
    # filling a numba-dict here and then call the function was faster than 1. Using
    # numba lists 2. Using a python list and convert it to an np.array and use
    # offsets 3. Growing a np.array with np.append 4. leaving the loop in python and
    # calling a numba function within the loop
    positives = Dict()
    for i in range(batch_size):
        positives[tuple(pairs[i].tolist())] = np.array(
            index[tuple(pairs[i].tolist())]
        )
    negative_samples = negative_samples.numpy()
    KgeUniformSampler._filter_numba(
        positives, pairs, negative_samples, batch_size, int(voc_size),
    )
    return torch.tensor(negative_samples, dtype=torch.int64)

@njit
def _filter_numba(positives, pairs, result, batch_size, voc_size, sample_func):
def _filter_numba(positives, pairs, negative_samples, batch_size, voc_size):
for i in range(batch_size):
pos = positives[(pairs[i][0], pairs[i][1])]
# inlining the idx_wherein function here results in an internal numba
# error which asks to file a bug report
resample_idx = index_where_in(result[i], pos, True)
resample_idx = index_where_in(negative_samples[i], pos, True)
# number of new samples needed
num_new = len(resample_idx)
new = np.empty(num_new)
# number already found of the new samples needed
num_found = 0
num_remaining = num_new - num_found
while True:
if not num_remaining:
break
new_samples = sample_func(voc_size, num_remaining)
while num_remaining:
new_samples = np.random.randint(0, voc_size, num_remaining)
idx = index_where_in(new_samples, pos, False)
# store the correct (true negatives) samples found
if len(idx):
Expand All @@ -203,7 +219,7 @@ def _filter_numba(positives, pairs, result, batch_size, voc_size, sample_func):
ctr = 0
# numba does not support result[i, resample_idx] = new
for j in resample_idx:
result[i, j] = new[ctr]
negative_samples[i, j] = new[ctr]
ctr += 1

def get_numba_sampler(self):
Expand Down Expand Up @@ -248,4 +264,4 @@ def idx_where_in(x, y, t_f=True):
# casting y to a set instead a list was always slower in test scripts
# setting njit(parallel=True) slowed down the function
list_y = list(y)
return np.where(np.array([i in list_y for i in x]) == t_f)[0]
return np.where(np.array([i in list_y for i in x]) == t_f)[0]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"sqlalchemy",
"torchviz",
"dataclasses",
"numba"
"numba==0.47.0"
],
zip_safe=False,
)

0 comments on commit 088dc63

Please sign in to comment.