Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/write target classes #90

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: write target_classes to exp folder
  • Loading branch information
ankeko committed Nov 23, 2023
commit 1b3984ca147506a7085572ab1b8aed98b978b304
9 changes: 7 additions & 2 deletions niceml/dagster/ops/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import tqdm
from dagster import Field, Noneable, OpExecutionContext, op, Out
from hydra.utils import ConvertMode, instantiate

from niceml.config.defaultremoveconfigkeys import DEFAULT_REMOVE_CONFIG_KEYS
Expand All @@ -23,8 +24,6 @@
from niceml.mlcomponents.predictionfunction.predictionfunction import PredictionFunction
from niceml.mlcomponents.predictionhandlers.predictionhandler import PredictionHandler
from niceml.utilities.fsspec.locationutils import join_fs_path, open_location
from dagster import Field, Noneable, OpExecutionContext, op, Out

from niceml.utilities.readwritelock import FileLock


Expand Down Expand Up @@ -67,6 +66,12 @@ def prediction(
data_description: DataDescription = (
exp_context.instantiate_datadescription_from_yaml()
)
if hasattr(data_description, "target_classes"):
target_class_dict = {
i: data_description.target_classes[i]
for i in range(len(data_description.target_classes))
}
exp_context.write_json(target_class_dict, "target_classes.txt")

exp_data: ExperimentData = create_expdata_from_expcontext(exp_context)
model_path: str = exp_data.get_model_path(relative_path=True)
Expand Down
20 changes: 20 additions & 0 deletions niceml/experiments/experimentcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
write_image,
write_parquet,
write_yaml,
write_json,
)
from niceml.utilities.timeutils import generate_timestamp

Expand Down Expand Up @@ -102,6 +103,25 @@ def write_csv(
if apply_last_modified:
self.update_last_modified()

def write_json(
self,
data: dict,
data_path: str,
apply_last_modified: bool = True,
**kwargs,
):
"""Writes a txt file relative to the experiment"""
with open_location(self.fs_config) as (file_system, root_path):
write_json(
data,
join(root_path, data_path),
file_system=file_system,
**kwargs,
)

if apply_last_modified:
self.update_last_modified()

def write_image(
self, image: Image.Image, data_path: str, apply_last_modified: bool = True
):
Expand Down