Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update setup files #104

Merged
merged 2 commits into from
Sep 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,27 +1,14 @@
STAGED := $(shell git diff --cached --name-only --diff-filter=ACMR -- 'src/***.py' | sed 's| |\\ |g')

all: format lint
echo 'Makefile for meta-learning-for-everyone repository'
all: init format lint

format:
black .
black . --line-length 104
isort .
nbqa black .
nbqa isort .

lint:
pytest src/ --pylint --flake8 --ignore=src/meta_rl/envs

lint-all:
pytest src/ --pylint --flake8 --ignore=src/meta_rl/envs --cache-clear

lint-staged:
ifdef STAGED
pytest $(STAGED) --pylint --flake8 --ignore=src/meta_rl/envs --cache-clear
else
@echo "No Staged Python File in the src folder"
endif

init:
pip install -U pip
pip install -e .
Expand Down
28 changes: 0 additions & 28 deletions pyproject.toml

This file was deleted.

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ jupyter-contrib-nbextensions==0.5.1
jupyter-nbextensions-configurator==0.4.1
matplotlib>=3.5.2
mujoco==2.2.0
psutil==5.9.1
setuptools==59.5.0
torchmeta>=1.8.0
tqdm>=4.62.3
6 changes: 3 additions & 3 deletions scripts/download-torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
def main() -> None:
if sys.platform == "win32" or sys.platform == "linux":
if GPUtil.getAvailable():
cli = "pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html"
cli = "pip install torch==1.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html"
else:
cli = "pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html"
cli = "pip install torch==1.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html"
elif sys.platform == "darwin":
cli = "pip install torch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html"
cli = "pip install torch==1.9.1 -f https://download.pytorch.org/whl/torch_stable.html"
print(cli)
os.system(cli)

Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[isort]
line_length = 104
skip_gitignore = true
extend_skip_glob = ""
20 changes: 15 additions & 5 deletions src/meta_sl/load_dataset/load_sinusoid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
" num_tasks=config[\"num_batches_train\"] * config[\"batch_size\"],\n",
" noise_std=None,\n",
" )\n",
" train_dataloader = BatchMetaDataLoader(train_dataset, batch_size=config[\"batch_size\"])\n",
" train_dataloader = BatchMetaDataLoader(\n",
" train_dataset, batch_size=config[\"batch_size\"]\n",
" )\n",
"\n",
" val_dataset = Sinusoid(\n",
" num_samples_per_task=config[\"num_shots\"] * 2,\n",
Expand Down Expand Up @@ -94,10 +96,18 @@
"source": [
"for batch_idx, batch in enumerate(val_dataloader):\n",
" xs, ys = batch\n",
" support_xs = xs[:, : config[\"num_shots\"], :].to(device=config[\"device\"]).type(torch.float)\n",
" query_xs = xs[:, config[\"num_shots\"] :, :].to(device=config[\"device\"]).type(torch.float)\n",
" support_ys = ys[:, : config[\"num_shots\"], :].to(device=config[\"device\"]).type(torch.float)\n",
" query_ys = ys[:, config[\"num_shots\"] :, :].to(device=config[\"device\"]).type(torch.float)\n",
" support_xs = (\n",
" xs[:, : config[\"num_shots\"], :].to(device=config[\"device\"]).type(torch.float)\n",
" )\n",
" query_xs = (\n",
" xs[:, config[\"num_shots\"] :, :].to(device=config[\"device\"]).type(torch.float)\n",
" )\n",
" support_ys = (\n",
" ys[:, : config[\"num_shots\"], :].to(device=config[\"device\"]).type(torch.float)\n",
" )\n",
" query_ys = (\n",
" ys[:, config[\"num_shots\"] :, :].to(device=config[\"device\"]).type(torch.float)\n",
" )\n",
"\n",
" print(\n",
" f\"support_x shape : {support_xs.shape}\\n\",\n",
Expand Down
42 changes: 30 additions & 12 deletions src/meta_sl/metric-based/matching_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,14 @@
" batch_first=True,\n",
" )\n",
"\n",
" self.lstm_cell = nn.LSTMCell(input_size=self.emb_size, hidden_size=self.emb_size)\n",
" self.lstm_cell = nn.LSTMCell(\n",
" input_size=self.emb_size, hidden_size=self.emb_size\n",
" )\n",
"\n",
" @classmethod\n",
" def convBlock(cls, in_channels: int, out_channels: int, kernel_size: int) -> nn.Sequential:\n",
" def convBlock(\n",
" cls, in_channels: int, out_channels: int, kernel_size: int\n",
" ) -> nn.Sequential:\n",
" return nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\n",
" nn.BatchNorm2d(out_channels, momentum=1.0, track_running_stats=False),\n",
Expand All @@ -125,7 +129,9 @@
" dim=0, sizes=[batch_size, self.num_support]\n",
" )\n",
" support_bilstm, _ = self.bilstm(support_cnn)\n",
" support_bilstm_for, support_bilstm_rev = torch.tensor_split(support_bilstm, 2, dim=-1)\n",
" support_bilstm_for, support_bilstm_rev = torch.tensor_split(\n",
" support_bilstm, 2, dim=-1\n",
" )\n",
" return support_bilstm_for + support_bilstm_rev + support_cnn\n",
"\n",
" def read_out(self, hidden: torch.Tensor, support_emb: torch.Tensor) -> torch.Tensor:\n",
Expand All @@ -150,11 +156,15 @@
" cell_state = query_cnn.new_zeros(query_cnn.shape)\n",
"\n",
" for _ in range(10):\n",
" hidden_state, cell_state = self.lstm_cell(query_cnn, (hidden_state + read_out, cell_state))\n",
" hidden_state, cell_state = self.lstm_cell(\n",
" query_cnn, (hidden_state + read_out, cell_state)\n",
" )\n",
" hidden_state = hidden_state + query_cnn\n",
" read_out = self.read_out(hidden_state, support_emb)\n",
"\n",
" query_emb = hidden_state.unflatten(dim=0, sizes=[query_x.shape[0], query_x.shape[1]])\n",
" query_emb = hidden_state.unflatten(\n",
" dim=0, sizes=[query_x.shape[0], query_x.shape[1]]\n",
" )\n",
" return query_emb\n",
"\n",
" def forward(\n",
Expand Down Expand Up @@ -191,7 +201,9 @@
" support_emb_repeat, query_emb_repeat, dim=-1, eps=1e-8\n",
" ) # batch_size, num_query, num_support\n",
" attention = F.softmax(similarity, dim=-1) # batch_size, num_query, num_support\n",
" indices = support_y.unsqueeze(1).expand(-1, num_query, -1) # batch_size, num_query, num_support\n",
" indices = support_y.unsqueeze(1).expand(\n",
" -1, num_query, -1\n",
" ) # batch_size, num_query, num_support\n",
"\n",
" prob = attention.new_zeros((batch_size, num_query, num_ways))\n",
" prob.scatter_add_(-1, indices, attention) # batch_size, num_query, num_ways\n",
Expand Down Expand Up @@ -392,9 +404,9 @@
"\n",
"train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)\n",
"\n",
"model = MatchingNet(in_channels=1, num_ways=config[\"num_ways\"], num_shots=config[\"num_shots\"]).to(\n",
" device=config[\"device\"]\n",
")\n",
"model = MatchingNet(\n",
" in_channels=1, num_ways=config[\"num_ways\"], num_shots=config[\"num_shots\"]\n",
").to(device=config[\"device\"])\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
Expand Down Expand Up @@ -452,7 +464,9 @@
],
"source": [
"# 메타-트레이닝\n",
"with tqdm(zip(train_dataloader, val_dataloader), total=config[\"num_task_batch_train\"]) as pbar:\n",
"with tqdm(\n",
" zip(train_dataloader, val_dataloader), total=config[\"num_task_batch_train\"]\n",
") as pbar:\n",
" train_accuracies, val_accuracies = [], []\n",
" train_losses, val_losses = [], []\n",
"\n",
Expand Down Expand Up @@ -489,7 +503,9 @@
" )\n",
"\n",
" # 모델 저장하기\n",
" save_model(output_folder=config[\"output_folder\"], model=model, title=\"matching_network.th\")\n",
" save_model(\n",
" output_folder=config[\"output_folder\"], model=model, title=\"matching_network.th\"\n",
" )\n",
"\n",
" print_graph(\n",
" train_accuracies=train_accuracies,\n",
Expand Down Expand Up @@ -538,7 +554,9 @@
],
"source": [
"# 모델 불러오기\n",
"load_model(output_folder=config[\"output_folder\"], model=model, title=\"matching_network.th\")\n",
"load_model(\n",
" output_folder=config[\"output_folder\"], model=model, title=\"matching_network.th\"\n",
")\n",
"\n",
"# 메타-테스팅\n",
"with tqdm(test_dataloader, total=config[\"num_task_batch_test\"]) as pbar:\n",
Expand Down
58 changes: 40 additions & 18 deletions src/meta_sl/metric-based/prototypical_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,44 @@
" )\n",
"\n",
" @classmethod\n",
" def convBlock(cls, in_channels: int, out_channels: int, kernel_size: int) -> nn.Sequential:\n",
" def convBlock(\n",
" cls, in_channels: int, out_channels: int, kernel_size: int\n",
" ) -> nn.Sequential:\n",
" return nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\n",
" nn.BatchNorm2d(out_channels, momentum=1.0, track_running_stats=False),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2),\n",
" )\n",
"\n",
" def get_prototypes(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\n",
" def get_prototypes(\n",
" self, embeddings: torch.Tensor, targets: torch.Tensor\n",
" ) -> torch.Tensor:\n",
" batch_size = embeddings.shape[0]\n",
" indices = targets.unsqueeze(-1).expand_as(embeddings)\n",
"\n",
" prototypes = embeddings.new_zeros((batch_size, self.num_ways, self.emb_size))\n",
" prototypes.scatter_add_(1, indices, embeddings).div_(float(self.num_support) / self.num_ways)\n",
" prototypes.scatter_add_(1, indices, embeddings).div_(\n",
" float(self.num_support) / self.num_ways\n",
" )\n",
" return prototypes\n",
"\n",
" def forward(\n",
" self, support_x: torch.Tensor, support_y: torch.Tensor, query_x: torch.Tensor\n",
" ) -> torch.Tensor:\n",
" batch_size = support_x.shape[0]\n",
"\n",
" support_emb = self.embedding_net(support_x.flatten(start_dim=0, end_dim=1)).unflatten(\n",
" dim=0, sizes=[batch_size, self.num_support]\n",
" )\n",
" query_emb = self.embedding_net(query_x.flatten(start_dim=0, end_dim=1)).unflatten(\n",
" dim=0, sizes=[batch_size, self.num_query]\n",
" )\n",
" support_emb = self.embedding_net(\n",
" support_x.flatten(start_dim=0, end_dim=1)\n",
" ).unflatten(dim=0, sizes=[batch_size, self.num_support])\n",
" query_emb = self.embedding_net(\n",
" query_x.flatten(start_dim=0, end_dim=1)\n",
" ).unflatten(dim=0, sizes=[batch_size, self.num_query])\n",
" proto_emb = self.get_prototypes(support_emb, support_y)\n",
"\n",
" distance = torch.sum((query_emb.unsqueeze(2) - proto_emb.unsqueeze(1)) ** 2, dim=-1)\n",
" distance = torch.sum(\n",
" (query_emb.unsqueeze(2) - proto_emb.unsqueeze(1)) ** 2, dim=-1\n",
" )\n",
" return distance"
]
},
Expand Down Expand Up @@ -278,9 +286,9 @@
"\n",
"train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)\n",
"\n",
"model = PrototypicalNet(in_channels=1, num_ways=config[\"num_ways\"], num_shots=config[\"num_shots\"]).to(\n",
" device=config[\"device\"]\n",
")\n",
"model = PrototypicalNet(\n",
" in_channels=1, num_ways=config[\"num_ways\"], num_shots=config[\"num_shots\"]\n",
").to(device=config[\"device\"])\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)"
Expand Down Expand Up @@ -338,7 +346,9 @@
],
"source": [
"# 메타-트레이닝\n",
"with tqdm(zip(train_dataloader, val_dataloader), total=config[\"num_task_batch_train\"]) as pbar:\n",
"with tqdm(\n",
" zip(train_dataloader, val_dataloader), total=config[\"num_task_batch_train\"]\n",
") as pbar:\n",
" train_accuracies, val_accuracies = [], []\n",
" train_losses, val_losses = [], []\n",
"\n",
Expand All @@ -354,7 +364,10 @@
" optimizer=optimizer,\n",
" )\n",
" val_accuracy, val_loss = test_proto(\n",
" device=config[\"device\"], task_batch=val_batch, model=model, criterion=criterion\n",
" device=config[\"device\"],\n",
" task_batch=val_batch,\n",
" model=model,\n",
" criterion=criterion,\n",
" )\n",
"\n",
" train_accuracies.append(train_accuracy)\n",
Expand All @@ -370,7 +383,11 @@
" )\n",
"\n",
" # 모델 저장하기\n",
" save_model(output_folder=config[\"output_folder\"], model=model, title=\"prototypical_network.th\")\n",
" save_model(\n",
" output_folder=config[\"output_folder\"],\n",
" model=model,\n",
" title=\"prototypical_network.th\",\n",
" )\n",
"\n",
" print_graph(\n",
" train_accuracies=train_accuracies,\n",
Expand Down Expand Up @@ -427,7 +444,9 @@
],
"source": [
"# 모델 불러오기\n",
"load_model(output_folder=config[\"output_folder\"], model=model, title=\"prototypical_network.th\")\n",
"load_model(\n",
" output_folder=config[\"output_folder\"], model=model, title=\"prototypical_network.th\"\n",
")\n",
"\n",
"# 메타-테스팅\n",
"with tqdm(test_dataloader, total=config[\"num_task_batch_test\"]) as pbar:\n",
Expand All @@ -439,7 +458,10 @@
" break\n",
"\n",
" test_accuracy, test_loss = test_proto(\n",
" device=config[\"device\"], task_batch=test_batch, model=model, criterion=criterion\n",
" device=config[\"device\"],\n",
" task_batch=test_batch,\n",
" model=model,\n",
" criterion=criterion,\n",
" )\n",
" sum_test_accuracies += test_accuracy\n",
" sum_test_losses += test_loss\n",
Expand Down
Loading