diff --git a/examples/toy-complex-train-sampling.yaml b/examples/toy-complex-train-sampling.yaml
index 2e7c27546..242f23173 100644
--- a/examples/toy-complex-train-sampling.yaml
+++ b/examples/toy-complex-train-sampling.yaml
@@ -1,7 +1,7 @@
 job.type: train
-dataset.name: fb15k
+dataset.name: toy
 model: complex
-
+train.num_workers: 4
 train:
   type: negative_sampling
   optimizer: Adagrad
@@ -22,4 +22,4 @@ negative_sampling:
     s: True
     o: True
     p: True
-    implementation: python
+    implementation: fast
diff --git a/kge/config-default.yaml b/kge/config-default.yaml
index 1716f1a22..275428a08 100644
--- a/kge/config-default.yaml
+++ b/kge/config-default.yaml
@@ -192,7 +192,7 @@ negative_sampling:
     s: False  # filter s samples until all are true negatives
     p: False  # see above
     o: False  # see above
-    implementation: numba  # numba; python
+    implementation: fast   # standard; fast
 
 # Whether to share the s/p/o corruptions for all triples in the batch. This
 # can make training more efficient. Cannot be used together with
diff --git a/kge/indexing.py b/kge/indexing.py
index 7e578b191..8e3685549 100644
--- a/kge/indexing.py
+++ b/kge/indexing.py
@@ -5,7 +5,7 @@
 
 
 def _group_by(keys, values) -> dict:
-    """ Groups values by keys.
+    """Group values by keys.
 
     :param keys: list of keys
     :param values: list of values
@@ -240,7 +240,7 @@ def create_default_index_functions(dataset: "Dataset"):
 
 
 @njit
 def index_where_in(x, y, t_f=True):
-    """Retrieves the indices of the elements in x which are also in y.
+    """Retrieve the indices of the elements in x which are also in y.
 
     x and y are assumed to be 1 dimensional arrays.
diff --git a/kge/util/sampler.py b/kge/util/sampler.py
index 2979c7cfd..cda45ec3b 100644
--- a/kge/util/sampler.py
+++ b/kge/util/sampler.py
@@ -29,11 +29,10 @@ def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
             self.vocabulary_size[slot] = (
                 dataset.num_relations() if slot == P else dataset.num_entities()
             )
+        self.check_option("filtering.implementation", ["standard", "fast"])
+        self.filter_fast = self.get_option("filtering.implementation") == "fast"
         self.dataset = dataset
         self.shared = self.get_option("shared")
-
-        self.times = []
-
         # auto config
         for slot, copy_from in [(S, O), (P, None), (O, S)]:
             if self.num_samples[slot] < 0:
@@ -69,7 +68,6 @@ def sample(
         entities (`slot`=0 or `slot`=2) or relations (`slot`=1).
 
         """
-        start = time.time()
         if num_samples is None:
            num_samples = self.num_samples[slot]
        if self.shared:
@@ -88,7 +86,10 @@ def sample(
         if self.filtering[slot]:
-            if self.get_option("filtering.implementation") == "python":
-                negative_samples = self._filter(negative_samples, slot, positive_triples)
-            elif self.get_option("filtering.implementation") == "numba":
-                negative_samples = self._filter_fast(negative_samples, slot, positive_triples)
-        print(f"sample took {time.time() -start}")
+            if self.filter_fast:
+                negative_samples = self._filter_fast(
+                    negative_samples, slot, positive_triples
+                )
+            else:
+                negative_samples = self._filter(
+                    negative_samples, slot, positive_triples
+                )
         return negative_samples
@@ -145,55 +146,64 @@
             negative_samples[i, resample_idx] = new
         return negative_samples
 
-    def _filter_fast(self, result: torch.Tensor, slot: int, spo: torch.Tensor):
-        """ Use numba for filtering.
+    def _filter_fast(
+        self, negative_samples: torch.Tensor, slot: int, positive_triples: torch.Tensor
+    ):
+        """Sampler-specific filtering."""
+        raise NotImplementedError(
+            "Use filtering.implementation=standard for this sampler."
+        )
+
 
-        A sampler has to implement a @njit _filter_numba function and registers it
-        in self.get_numba_sampler().
+class KgeUniformSampler(KgeSampler):
+    def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
+        super().__init__(config, configuration_key, dataset)
 
-        """
-        spo_char = "spo"
-        pair = spo_char.replace(spo_char[slot], "")
+    def _sample(self, positive_triples: torch.Tensor, slot: int, num_samples: int):
+        return torch.randint(
+            self.vocabulary_size[slot], (positive_triples.size(0), num_samples)
+        )
+
+    def _filter_fast(
+        self, negative_samples: torch.Tensor, slot: int, positive_triples: torch.Tensor
+    ):
+        pair = ["po", "so", "sp"][slot]
         # holding the positive indices for the respective pair
-        index = self.dataset.index(f"train_{pair}_to_{spo_char[slot]}")
-        cols = [0, 1, 2]
-        cols.remove(slot)
-        pairs = (spo[:, cols]).numpy()
-        batch_size = spo.size(0)
+        index = self.dataset.index(f"train_{pair}_to_{SLOT_STR[slot]}")
+        cols = [[P, O], [S, O], [S, P]][slot]
+        pairs = positive_triples[:, cols].numpy()
+        batch_size = positive_triples.size(0)
         voc_size = self.vocabulary_size[slot]
+        # Filling a numba Dict here and then calling the njit function was faster
+        # than (1) using numba lists, (2) using a python list, converting it to an
+        # np.array, and using offsets, (3) growing an np.array with np.append, and
+        # (4) leaving the loop in python and calling a numba function within it.
         positives = Dict()
         for i in range(batch_size):
             positives[tuple(pairs[i].tolist())] = np.array(
                 index[tuple(pairs[i].tolist())]
             )
-        result = np.array(result)
-        KgeSampler._filter_numba(
-            positives,
-            pairs,
-            result,
-            batch_size,
-            int(voc_size),
-            self.get_numba_sampler(),
+        negative_samples = negative_samples.numpy()
+        KgeUniformSampler._filter_numba(
+            positives, pairs, negative_samples, batch_size, int(voc_size),
         )
-        return torch.tensor(result, dtype=torch.int64)
+        return torch.tensor(negative_samples, dtype=torch.int64)
 
     @njit
-    def _filter_numba(positives, pairs, result, batch_size, voc_size, sample_func):
+    def _filter_numba(positives, pairs, negative_samples, batch_size, voc_size):
         for i in range(batch_size):
             pos = positives[(pairs[i][0], pairs[i][1])]
             # inlining the idx_wherein function here results in an internal numba
             # error which asks to file a bug report
-            resample_idx = index_where_in(result[i], pos, True)
+            resample_idx = index_where_in(negative_samples[i], pos, True)
             # number of new samples needed
             num_new = len(resample_idx)
             new = np.empty(num_new)
             # number already found of the new samples needed
             num_found = 0
             num_remaining = num_new - num_found
-            while True:
-                if not num_remaining:
-                    break
-                new_samples = sample_func(voc_size, num_remaining)
+            while num_remaining:
+                new_samples = np.random.randint(0, voc_size, num_remaining)
                 idx = index_where_in(new_samples, pos, False)
                 # store the correct (true negatives) samples found
                 if len(idx):
@@ -203,7 +213,7 @@ def _filter_numba(positives, pairs, result, batch_size, voc_size, sample_func):
             ctr = 0
-            # numba does not support result[i, resample_idx] = new
+            # numba does not support negative_samples[i, resample_idx] = new
             for j in resample_idx:
-                result[i, j] = new[ctr]
+                negative_samples[i, j] = new[ctr]
                 ctr += 1
 
     def get_numba_sampler(self):
@@ -248,4 +258,4 @@ def idx_where_in(x, y, t_f=True):
     # casting y to a set instead of a list was always slower in test scripts
     # setting njit(parallel=True) slowed down the function
     list_y = list(y)
-    return np.where(np.array([i in list_y for i in x]) == t_f)[0]
+    return np.where(np.array([i in list_y for i in x]) == t_f)[0]
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 49829e347..1217f62f8 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
         "sqlalchemy",
         "torchviz",
         "dataclasses",
-        "numba"
+        "numba==0.47.0"
     ],
     zip_safe=False,
 )
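
Note (not part of the patch): the fast path above boils down to filling a `numba.typed.Dict` of known positives per pair key in plain Python, then doing all collision checks and resampling inside an `@njit` kernel, which is what the comment in `_filter_fast` is benchmarking against. The standalone sketch below illustrates that structure under stated assumptions: `demo_filter`, `_resample_kernel`, and `_contains` are hypothetical names, only numpy and numba are assumed, and the kernel re-draws one sample at a time where the patch resamples whole batches via `index_where_in` and `np.random.randint(0, voc_size, num_remaining)`.

```python
# Standalone sketch of the "fast" filtering idea (assumed names, not LibKGE API).
import numpy as np
from numba import njit, types
from numba.typed import Dict


@njit
def _contains(arr, v):
    # linear membership test; numba compiles this to a tight loop
    for a in arr:
        if a == v:
            return True
    return False


@njit
def _resample_kernel(positives, pairs, negatives, voc_size):
    # for each row, re-draw every "negative" that is actually a known positive
    for i in range(pairs.shape[0]):
        pos = positives[(pairs[i, 0], pairs[i, 1])]
        for j in range(negatives.shape[1]):
            while _contains(pos, negatives[i, j]):
                negatives[i, j] = np.random.randint(0, voc_size)


def demo_filter(pairs, index, negatives, voc_size):
    # the typed Dict is filled in python (cheap) so that the hot loop
    # can run entirely inside numba, mirroring the comment in the patch
    positives = Dict.empty(
        key_type=types.UniTuple(types.int64, 2), value_type=types.int64[:]
    )
    for i in range(pairs.shape[0]):
        key = (pairs[i, 0], pairs[i, 1])
        positives[key] = np.asarray(index[key], dtype=np.int64)
    _resample_kernel(positives, pairs, negatives, voc_size)
    return negatives


# toy usage: two (s, p) pairs, three candidate objects each, vocabulary of 5
pairs = np.array([[0, 0], [1, 0]], dtype=np.int64)
index = {(0, 0): [2, 3], (1, 0): [0]}  # known true objects per (s, p) pair
negatives = np.random.randint(0, 5, (2, 3)).astype(np.int64)
print(demo_filter(pairs, index, negatives, voc_size=5))
```

The typed Dict is the load-bearing piece: a plain Python dict cannot cross the `@njit` boundary, so the per-pair positive lists must be materialized as typed int64 arrays before the kernel runs, exactly as `_filter_fast` does before calling `_filter_numba`.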