Skip to content

Commit

Permalink
[MNT] solve flake8 error in examples
Browse files Browse the repository at this point in the history
  • Loading branch information
GeneLiuXe committed Jan 11, 2024
1 parent c5f9a42 commit 8a4db13
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
6 changes: 1 addition & 5 deletions examples/dataset_image_workflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,15 @@
@torch.no_grad()
def evaluate(model, evaluate_set: Dataset, device=None, distribution=True):
device = choose_device(0) if device is None else device

if isinstance(model, nn.Module):
model.eval()
mapping = lambda m, x: m(x)
else:
mapping = lambda m, x: m.predict(x)

criterion = nn.CrossEntropyLoss(reduction="sum")
total, correct, loss = 0, 0, torch.as_tensor(0.0, dtype=torch.float32, device=device)
dataloader = DataLoader(evaluate_set, batch_size=1024, shuffle=True)
for i, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
out = mapping(model, X)
out = model(X) if isinstance(model, nn.Module) else model.predict(X)
if not torch.is_tensor(out):
out = torch.from_numpy(out).to(device)

Expand Down
6 changes: 3 additions & 3 deletions examples/dataset_image_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _plot_labeled_peformance_curves(self, all_user_curves_data):

plt.xlabel("Amout of Labeled User Data", fontsize=14)
plt.ylabel("1 - Accuracy", fontsize=14)
plt.title(f"Results on Image Experimental Scenario", fontsize=16)
plt.title("Results on Image Experimental Scenario", fontsize=16)
plt.legend(fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(self.fig_path, "image_labeled_curves.svg"), bbox_inches="tight", dpi=700)
Expand All @@ -61,7 +61,7 @@ def _prepare_market(self, rebuild=False):
self.user_semantic = client.get_semantic_specification(self.image_benchmark.learnware_ids[0])
self.user_semantic["Name"]["Values"] = ""

if len(self.image_market) == 0 or rebuild == True:
if len(self.image_market) == 0 or rebuild is True:
for learnware_id in self.image_benchmark.learnware_ids:
with tempfile.TemporaryDirectory(prefix="image_benchmark_") as tempdir:
zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
Expand All @@ -71,7 +71,7 @@ def _prepare_market(self, rebuild=False):
client.download_learnware(learnware_id, zip_path)
self.image_market.add_learnware(zip_path, semantic_spec)
break
except:
except Exception:
time.sleep(1)
continue

Expand Down
6 changes: 3 additions & 3 deletions examples/dataset_text_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _plot_labeled_peformance_curves(self, all_user_curves_data):

plt.xlabel("Amout of Labeled User Data", fontsize=14)
plt.ylabel("1 - Accuracy", fontsize=14)
plt.title(f"Results on Text Experimental Scenario", fontsize=16)
plt.title("Results on Text Experimental Scenario", fontsize=16)
plt.legend(fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.svg"), bbox_inches="tight", dpi=700)
Expand All @@ -76,7 +76,7 @@ def _prepare_market(self, rebuild=False):
self.user_semantic = client.get_semantic_specification(self.text_benchmark.learnware_ids[0])
self.user_semantic["Name"]["Values"] = ""

if len(self.text_market) == 0 or rebuild == True:
if len(self.text_market) == 0 or rebuild is True:
for learnware_id in self.text_benchmark.learnware_ids:
with tempfile.TemporaryDirectory(prefix="text_benchmark_") as tempdir:
zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
Expand All @@ -86,7 +86,7 @@ def _prepare_market(self, rebuild=False):
client.download_learnware(learnware_id, zip_path)
self.text_market.add_learnware(zip_path, semantic_spec)
break
except:
except Exception:
time.sleep(1)
continue

Expand Down

0 comments on commit 8a4db13

Please sign in to comment.