Merge pull request #125 from uma-pi1/refactor_job_run
Streamline job preparation across jobs; add pre- and post-run hooks
samuelbroscheit committed Jun 24, 2020
2 parents ee58734 + b403b48 commit db88bc2
Showing 8 changed files with 46 additions and 35 deletions.
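
At its core, the refactor moves job preparation and the run entry point into the base Job class: subclasses no longer override run() or track their own is_prepared flag, but instead implement _prepare() and _run(), while Job.run() wires everything together. A simplified sketch of the control flow this commit adds to kge/job/job.py (comments mine; see the actual diff below for the exact code):

    def run(self):
        if not self._is_prepared:
            self._prepare()            # subclass hook: build loaders, indexes, ...
        for f in self.pre_run_hooks:
            f(self)                    # each pre-run hook receives the job
        result = self._run()           # subclass hook: the actual work
        for f in self.post_run_hooks:
            f(self, result)            # post-run hooks also receive the result
        return result
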
7 changes: 1 addition & 6 deletions kge/job/auto_search.py
@@ -58,10 +58,6 @@ def _load(self, checkpoint):

     # -- Abstract methods --------------------------------------------------------------

-    def init_search(self):
-        """Initialize to start a new search experiment."""
-        raise NotImplementedError
-
     def register_trial(self, parameters=None):
         """Start a new trial.

@@ -95,8 +91,7 @@ def get_best_parameters(self):

     # -- Main --------------------------------------------------------------------------

-    def run(self):
-        self.init_search()
+    def _run(self):

         # let's go
         trial_no = 0
3 changes: 2 additions & 1 deletion kge/job/ax_search.py
@@ -29,7 +29,8 @@ def __getstate__(self):
         del state["ax_client"]
         return state

-    def init_search(self):
+    def _prepare(self):
+        super()._prepare()
         if self.num_sobol_trials > 0:
             # BEGIN: from /ax/service/utils/dispatch.py
             generation_strategy = GenerationStrategy(
7 changes: 2 additions & 5 deletions kge/job/entity_ranking.py
@@ -25,11 +25,9 @@ def __init__(self, config: Config, dataset: Dataset, parent_job, model):
             f(self)

     def _prepare(self):
+        super()._prepare()
         """Construct all indexes needed to run."""

-        if self.is_prepared:
-            return
-
         # create data and precompute indexes
         self.triples = self.dataset.split(self.config.get("eval.split"))
         for split in self.filter_splits:

@@ -79,8 +77,7 @@ def _collate(self, batch):
         return batch, label_coords, test_label_coords

     @torch.no_grad()
-    def run(self) -> dict:
-        self._prepare()
+    def _run(self) -> dict:

         was_training = self.model.training
         self.model.eval()
2 changes: 1 addition & 1 deletion kge/job/eval.py
@@ -84,7 +84,7 @@ def create(config, dataset, parent_job=None, model=None):
         else:
             raise ValueError("eval.type")

-    def run(self) -> dict:
+    def _run(self) -> dict:
        """ Compute evaluation metrics, output results to trace file """
        raise NotImplementedError
2 changes: 1 addition & 1 deletion kge/job/grid_search.py
@@ -19,7 +19,7 @@ def __init__(self, config, dataset, parent_job=None):
         for f in Job.job_created_hooks:
             f(self)

-    def run(self):
+    def _run(self):
         # read grid search options range
         all_keys = []
         all_keys_short = []
32 changes: 32 additions & 0 deletions kge/job/job.py
@@ -49,6 +49,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job: "Job" = None):
         self.parent_job = parent_job
         self.resumed_from_job_id: Optional[str] = None
         self.trace_entry: Dict[str, Any] = {}
+        self._is_prepared = False

         # prepend log entries with the job id. Since we use random job IDs but
         # want short log entries, we only output the first 8 bytes here

@@ -58,6 +59,15 @@ def __init__(self, config: Config, dataset: Dataset, parent_job: "Job" = None):
         for f in Job.job_created_hooks:
             f(self)

+        #: Hooks before running a job
+        #: Signature: job
+        self.pre_run_hooks: List[Callable[[Job], Any]] = []
+
+        #: Hooks after running a job
+        #: Signature: job, dict returned by the run method
+        self.post_run_hooks: List[Callable[[Job, Dict], Any]] = []
+
     @staticmethod
     def create(
         config: Config, dataset: Optional[Dataset] = None, parent_job=None, model=None

@@ -131,7 +141,29 @@ def _load(self, checkpoint: Dict):
         """
         pass

+    def _prepare(self):
+        pass
+
     def run(self):
+        """
+        Run the job: first prepare it and run the pre-run hooks, then execute
+        the job, run the post-run hooks, and return the result.
+        :return: Output of the job, if any.
+        """
+        if not self._is_prepared:
+            self._prepare()
+
+        for f in self.pre_run_hooks:
+            f(self)
+
+        result = self._run()
+
+        for f in self.post_run_hooks:
+            f(self, result)
+
+        return result
+
+    def _run(self):
+        raise NotImplementedError
+
     def trace(self, **kwargs) -> Dict[str, Any]:
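
With these hook lists in place, behavior can be attached around any job without subclassing. A minimal usage sketch, assuming config and dataset objects have been created as usual; the hook bodies here are illustrative, not part of this commit:

    from kge.job import Job

    job = Job.create(config, dataset)  # config/dataset assumed to exist

    # pre-run hooks receive the job; post-run hooks receive the job plus
    # whatever _run() returned
    job.pre_run_hooks.append(lambda j: j.config.log("about to run"))
    job.post_run_hooks.append(lambda j, result: j.config.log("finished"))

    result = job.run()  # prepares once, fires hooks around _run()
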
2 changes: 1 addition & 1 deletion kge/job/manual_search.py
@@ -32,7 +32,7 @@ def __init__(self, config: Config, dataset: Dataset, parent_job=None):
         for f in Job.job_created_hooks:
             f(self)

-    def run(self):
+    def _run(self):
         # read search configurations and expand them to full configs
         search_configs = copy.deepcopy(self.config.get("manual_search.configurations"))
         all_keys = set()
26 changes: 6 additions & 20 deletions kge/job/train.py
@@ -81,7 +81,6 @@ def __init__(
         self.valid_job = EvaluationJob.create(
             valid_conf, dataset, parent_job=self, model=self.model
         )
-        self.is_prepared = False

         # attributes filled in by implementing classes
         self.loader = None

@@ -133,7 +132,7 @@ def create(
             # perhaps TODO: try class with specified name -> extensibility
             raise ValueError("train.type")

-    def run(self) -> None:
+    def _run(self) -> None:
         """Start/resume the training job and run to completion."""
         self.config.log("Starting training...")
         checkpoint_every = self.config.get("train.checkpoint.every")

@@ -295,12 +294,6 @@ def _load(self, checkpoint: Dict) -> str:
     def run_epoch(self) -> Dict[str, Any]:
         "Runs an epoch and returns a trace entry."

-        # prepare the job is not done already
-        if not self.is_prepared:
-            self._prepare()
-            self.model.prepare_job(self)  # let the model add some hooks
-            self.is_prepared = True
-
         # variables that record various statistics
         sum_loss = 0.0
         sum_penalty = 0.0

@@ -476,7 +469,8 @@ def _prepare(self):
         Guaranteed to be called exactly once before running the first epoch.
         """
-        raise NotImplementedError
+        super()._prepare()
+        self.model.prepare_job(self)  # let the model add some hooks

     @dataclass
     class _ProcessBatchResult:

@@ -553,6 +547,7 @@ def __init__(self, config, dataset, parent_job=None, model=None):
             f(self)

     def _prepare(self):
+        super()._prepare()
         # determine enabled query types
         self.query_types = [
             key

@@ -764,7 +759,6 @@ class TrainingJobNegativeSampling(TrainingJob):
     def __init__(self, config, dataset, parent_job=None, model=None):
         super().__init__(config, dataset, parent_job, model=model)
         self._sampler = KgeSampler.create(config, "negative_sampling", dataset)
-        self.is_prepared = False
         self._implementation = self.config.check(
             "negative_sampling.implementation", ["triple", "all", "batch", "auto"],
         )

@@ -790,9 +784,7 @@ def __init__(self, config, dataset, parent_job=None, model=None):

     def _prepare(self):
         """Construct dataloader"""
-
-        if self.is_prepared:
-            return
+        super()._prepare()

         self.num_examples = self.dataset.split(self.train_split).size(0)
         self.loader = torch.utils.data.DataLoader(

@@ -805,8 +797,6 @@ def _prepare(self):
             pin_memory=self.config.get("train.pin_memory"),
         )

-        self.is_prepared = True
-
     def _get_collate_fun(self):
         # create the collate function
         def collate(batch):

@@ -1013,7 +1003,6 @@ class TrainingJob1vsAll(TrainingJob):

     def __init__(self, config, dataset, parent_job=None, model=None):
         super().__init__(config, dataset, parent_job, model=model)
-        self.is_prepared = False
         config.log("Initializing spo training job...")
         self.type_str = "1vsAll"

@@ -1023,9 +1012,7 @@ def __init__(self, config, dataset, parent_job=None, model=None):

     def _prepare(self):
         """Construct dataloader"""
-
-        if self.is_prepared:
-            return
+        super()._prepare()

         self.num_examples = self.dataset.split(self.train_split).size(0)
         self.loader = torch.utils.data.DataLoader(

@@ -1040,7 +1027,6 @@ def _prepare(self):
             pin_memory=self.config.get("train.pin_memory"),
        )

-        self.is_prepared = True

     def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
         # prepare
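
After this change, a training-job subclass's _prepare() only chains to super() (which now also lets the model register its hooks) and builds its data loader; the base class decides when _prepare() is called. A sketch of the resulting shape for a hypothetical new subclass (class name and loader details are illustrative, assuming the module's existing imports such as torch and TrainingJob):

    class MyTrainingJob(TrainingJob):  # hypothetical subclass, for illustration
        def _prepare(self):
            super()._prepare()  # chains up; also registers the model's hooks
            self.num_examples = self.dataset.split(self.train_split).size(0)
            self.loader = torch.utils.data.DataLoader(
                range(self.num_examples),
                collate_fn=self._get_collate_fun(),  # assumed helper, as in the jobs above
                shuffle=True,
                batch_size=self.batch_size,
                num_workers=self.config.get("train.num_workers"),
            )
            # note: no is_prepared bookkeeping here; the base class guards preparation
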
