
Commit

Merge branch 'master' into master
kunmukh authored Feb 20, 2023
2 parents ef74239 + a566b60, commit db9348c
Showing 109 changed files with 7,283 additions and 5,037 deletions.
dglgo/dglgo/apply_pipeline/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,3 +1,3 @@
-from .graphpred import ApplyGraphpredPipeline
 from .nodepred import ApplyNodepredPipeline
 from .nodepred_sample import ApplyNodepredNsPipeline
+from .graphpred import ApplyGraphpredPipeline
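The __init__.py diff only reorders the pipeline imports, but the imports are load-bearing: importing each module runs its @ApplyPipelineFactory.register(...) decorator, which is what makes the pipeline discoverable by name. A minimal sketch of that register-on-import pattern, assuming the factory keeps a simple name-to-class dict (the real implementation lives in dglgo/dglgo/utils/factory.py and may differ):

class ApplyPipelineFactory:
    registry = {}

    @classmethod
    def register(cls, name):
        # Class decorator: record the pipeline class under `name`.
        def wrapper(pipeline_cls):
            cls.registry[name] = pipeline_cls
            return pipeline_cls

        return wrapper


@ApplyPipelineFactory.register("nodepred")
class ApplyNodepredPipeline:
    pass


print(ApplyPipelineFactory.registry)  # {'nodepred': <class '...ApplyNodepredPipeline'>}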
dglgo/dglgo/apply_pipeline/graphpred/gen.py (76 changes: 53 additions & 23 deletions)
@@ -1,41 +1,50 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Optional
+
 import ruamel.yaml
 import torch
 import typer
-
-from copy import deepcopy
 from jinja2 import Template
-from pathlib import Path
 from pydantic import BaseModel, Field
-from typing import Optional
 
-from ...utils.factory import ApplyPipelineFactory, PipelineBase, DataFactory, GraphModelFactory
+from ...utils.factory import (
+    ApplyPipelineFactory,
+    DataFactory,
+    GraphModelFactory,
+    PipelineBase,
+)
 from ...utils.yaml_dump import deep_convert_dict, merge_comment
 
 pipeline_comments = {
     "batch_size": "Graph batch size",
     "num_workers": "Number of workers for data loading",
-    "save_path": "Directory to save the inference results"
+    "save_path": "Directory to save the inference results",
 }
 
 
 class ApplyGraphpredPipelineCfg(BaseModel):
     batch_size: int = 32
     num_workers: int = 4
     save_path: str = "apply_results"
 
 
 @ApplyPipelineFactory.register("graphpred")
 class ApplyGraphpredPipeline(PipelineBase):
     def __init__(self):
-        self.pipeline = {
-            "name": "graphpred",
-            "mode": "apply"
-        }
+        self.pipeline = {"name": "graphpred", "mode": "apply"}
 
     @classmethod
     def setup_user_cfg_cls(cls):
         from ...utils.enter_config import UserConfig
 
         class ApplyGraphPredUserConfig(UserConfig):
-            data: DataFactory.filter("graphpred").get_pydantic_config() = Field(..., discriminator="name")
-            general_pipeline: ApplyGraphpredPipelineCfg = ApplyGraphpredPipelineCfg()
+            data: DataFactory.filter("graphpred").get_pydantic_config() = Field(
+                ..., discriminator="name"
+            )
+            general_pipeline: ApplyGraphpredPipelineCfg = (
+                ApplyGraphpredPipelineCfg()
+            )
 
         cls.user_cfg_cls = ApplyGraphPredUserConfig

@@ -45,9 +54,13 @@ def user_cfg_cls(self):

     def get_cfg_func(self):
         def config(
-            data: DataFactory.filter("graphpred").get_dataset_enum() = typer.Option(None, help="input data name"),
-            cfg: Optional[str] = typer.Option(None, help="output configuration file path"),
-            cpt: str = typer.Option(..., help="input checkpoint file path")
+            data: DataFactory.filter(
+                "graphpred"
+            ).get_dataset_enum() = typer.Option(None, help="input data name"),
+            cfg: Optional[str] = typer.Option(
+                None, help="output configuration file path"
+            ),
+            cpt: str = typer.Option(..., help="input checkpoint file path"),
         ):
             # Training configuration
             train_cfg = torch.load(cpt)["cfg"]
@@ -57,7 +70,12 @@ def config(
             else:
                 data = data.name
             if cfg is None:
-                cfg = "_".join(["apply", "graphpred", data, train_cfg["model_name"]]) + ".yaml"
+                cfg = (
+                    "_".join(
+                        ["apply", "graphpred", data, train_cfg["model_name"]]
+                    )
+                    + ".yaml"
+                )
 
             self.__class__.setup_user_cfg_cls()
             generated_cfg = {
@@ -66,23 +84,31 @@
"device": train_cfg["device"],
"data": {"name": data},
"cpt_path": cpt,
"general_pipeline": {"batch_size": train_cfg["general_pipeline"]["eval_batch_size"],
"num_workers": train_cfg["general_pipeline"]["num_workers"]}
"general_pipeline": {
"batch_size": train_cfg["general_pipeline"][
"eval_batch_size"
],
"num_workers": train_cfg["general_pipeline"]["num_workers"],
},
}
output_cfg = self.user_cfg_cls(**generated_cfg).dict()
output_cfg = deep_convert_dict(output_cfg)
# Not applicable for inference
output_cfg['data'].pop('split_ratio')
output_cfg["data"].pop("split_ratio")
comment_dict = {
"device": "Torch device name, e.g., cpu or cuda or cuda:0",
"cpt_path": "Path to the checkpoint file",
"general_pipeline": pipeline_comments
"general_pipeline": pipeline_comments,
}
comment_dict = merge_comment(output_cfg, comment_dict)

yaml = ruamel.yaml.YAML()
yaml.dump(comment_dict, Path(cfg).open("w"))
print("Configuration file is generated at {}".format(Path(cfg).absolute()))
print(
"Configuration file is generated at {}".format(
Path(cfg).absolute()
)
)

return config

@@ -100,8 +126,12 @@ def gen_script(cls, user_cfg_dict):
         model_name = train_cfg["model_name"]
         model_code = GraphModelFactory.get_source_code(model_name)
         render_cfg["model_code"] = model_code
-        render_cfg["model_class_name"] = GraphModelFactory.get_model_class_name(model_name)
-        render_cfg.update(DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"]))
+        render_cfg["model_class_name"] = GraphModelFactory.get_model_class_name(
+            model_name
+        )
+        render_cfg.update(
+            DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"])
+        )
 
         # Dict for defining cfg in the rendered code
         generated_user_cfg = deepcopy(user_cfg_dict)
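For context on what config() above produces: it assembles generated_cfg from the training checkpoint and dumps it as YAML with ruamel. A standalone sketch of that flow, with an assumed checkpoint dict standing in for torch.load(cpt)["cfg"] (the model name "gin", dataset "ogbg-molhiv", and paths are illustrative, not from a real run):

from pathlib import Path

import ruamel.yaml

train_cfg = {  # stand-in for torch.load(cpt)["cfg"]
    "model_name": "gin",
    "device": "cuda:0",
    "general_pipeline": {"eval_batch_size": 32, "num_workers": 4},
}
data = "ogbg-molhiv"
cfg = "_".join(["apply", "graphpred", data, train_cfg["model_name"]]) + ".yaml"
# -> apply_graphpred_ogbg-molhiv_gin.yaml

generated_cfg = {
    "device": train_cfg["device"],
    "data": {"name": data},
    "cpt_path": "results/model.pth",
    "general_pipeline": {
        "batch_size": train_cfg["general_pipeline"]["eval_batch_size"],
        "num_workers": train_cfg["general_pipeline"]["num_workers"],
    },
}

yaml = ruamel.yaml.YAML()
yaml.dump(generated_cfg, Path(cfg).open("w"))
print("Configuration file is generated at {}".format(Path(cfg).absolute()))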
dglgo/dglgo/apply_pipeline/nodepred/gen.py (66 changes: 45 additions & 21 deletions)
@@ -1,30 +1,35 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Optional
+
 import ruamel.yaml
 import torch
 import typer
-
-from copy import deepcopy
 from jinja2 import Template
-from pathlib import Path
 from pydantic import Field
-from typing import Optional
 
-from ...utils.factory import ApplyPipelineFactory, PipelineBase, DataFactory, NodeModelFactory
+from ...utils.factory import (
+    ApplyPipelineFactory,
+    DataFactory,
+    NodeModelFactory,
+    PipelineBase,
+)
 from ...utils.yaml_dump import deep_convert_dict, merge_comment
 
 
 @ApplyPipelineFactory.register("nodepred")
 class ApplyNodepredPipeline(PipelineBase):
-
     def __init__(self):
-        self.pipeline = {
-            "name": "nodepred",
-            "mode": "apply"
-        }
+        self.pipeline = {"name": "nodepred", "mode": "apply"}
 
     @classmethod
     def setup_user_cfg_cls(cls):
         from ...utils.enter_config import UserConfig
 
         class ApplyNodePredUserConfig(UserConfig):
-            data: DataFactory.filter("nodepred").get_pydantic_config() = Field(..., discriminator="name")
+            data: DataFactory.filter("nodepred").get_pydantic_config() = Field(
+                ..., discriminator="name"
+            )
 
         cls.user_cfg_cls = ApplyNodePredUserConfig

@@ -34,9 +39,13 @@ def user_cfg_cls(self):

     def get_cfg_func(self):
         def config(
-            data: DataFactory.filter("nodepred").get_dataset_enum() = typer.Option(None, help="input data name"),
-            cfg: Optional[str] = typer.Option(None, help="output configuration file path"),
-            cpt: str = typer.Option(..., help="input checkpoint file path")
+            data: DataFactory.filter(
+                "nodepred"
+            ).get_dataset_enum() = typer.Option(None, help="input data name"),
+            cfg: Optional[str] = typer.Option(
+                None, help="output configuration file path"
+            ),
+            cpt: str = typer.Option(..., help="input checkpoint file path"),
         ):
             # Training configuration
             train_cfg = torch.load(cpt)["cfg"]
@@ -46,7 +55,12 @@ def config(
             else:
                 data = data.name
             if cfg is None:
-                cfg = "_".join(["apply", "nodepred", data, train_cfg["model_name"]]) + ".yaml"
+                cfg = (
+                    "_".join(
+                        ["apply", "nodepred", data, train_cfg["model_name"]]
+                    )
+                    + ".yaml"
+                )
 
             self.__class__.setup_user_cfg_cls()
             generated_cfg = {
@@ -55,22 +69,28 @@
"device": train_cfg["device"],
"data": {"name": data},
"cpt_path": cpt,
"general_pipeline": {"save_path": "apply_results"}
"general_pipeline": {"save_path": "apply_results"},
}
output_cfg = self.user_cfg_cls(**generated_cfg).dict()
output_cfg = deep_convert_dict(output_cfg)
# Not applicable for inference
output_cfg['data'].pop('split_ratio')
output_cfg["data"].pop("split_ratio")
comment_dict = {
"device": "Torch device name, e.g., cpu or cuda or cuda:0",
"cpt_path": "Path to the checkpoint file",
"general_pipeline": {"save_path": "Directory to save the inference results"}
"general_pipeline": {
"save_path": "Directory to save the inference results"
},
}
comment_dict = merge_comment(output_cfg, comment_dict)

yaml = ruamel.yaml.YAML()
yaml.dump(comment_dict, Path(cfg).open("w"))
print("Configuration file is generated at {}".format(Path(cfg).absolute()))
print(
"Configuration file is generated at {}".format(
Path(cfg).absolute()
)
)

return config

@@ -88,8 +108,12 @@ def gen_script(cls, user_cfg_dict):
         model_name = train_cfg["model_name"]
         model_code = NodeModelFactory.get_source_code(model_name)
         render_cfg["model_code"] = model_code
-        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(model_name)
-        render_cfg.update(DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"]))
+        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
+            model_name
+        )
+        render_cfg.update(
+            DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"])
+        )
 
         # Dict for defining cfg in the rendered code
         generated_user_cfg = deepcopy(user_cfg_dict)
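The config(...) signatures being reformatted in both files double as CLI definitions: typer turns each keyword parameter into a command-line option, with typer.Option(default, help=...) supplying the default value and help text, and typer.Option(...) (Ellipsis) marking the checkpoint path as required. A self-contained sketch of that mechanism, assuming a plain typer app (the wiring and the "cora"/"gcn" defaults are illustrative, not dglgo's actual CLI setup):

from typing import Optional

import typer

app = typer.Typer()


@app.command()
def config(
    data: str = typer.Option("cora", help="input data name"),
    cfg: Optional[str] = typer.Option(
        None, help="output configuration file path"
    ),
    cpt: str = typer.Option(..., help="input checkpoint file path"),
):
    # Mirrors the naming scheme above: apply_<pipeline>_<data>_<model>.yaml
    if cfg is None:
        cfg = "_".join(["apply", "nodepred", data, "gcn"]) + ".yaml"
    typer.echo("would write {} from checkpoint {}".format(cfg, cpt))


if __name__ == "__main__":
    app()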
… (diffs for the remaining 106 changed files not shown)
