Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#4] Add VESSL Callback to Post Metrics to VESSL AI #6

Merged
merged 15 commits into from
Apr 23, 2024
Merged
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ generated-members=numpy.*, torch.*


[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
disable=missing-function-docstring, line-too-long, import-error, too-many-lines,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added too-many-lines because data.py is already almost hitting the 1000 lines limit, and with puree dataset logic it exceeds the limit.

too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
3 changes: 3 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.vessl_ import setup_vessl_env_vars
from axolotl.utils.wandb_ import setup_wandb_env_vars

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
Expand Down Expand Up @@ -369,6 +370,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

setup_mlflow_env_vars(cfg)

setup_vessl_env_vars(cfg)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please note that usually, it is better you check the condition from where you call the function when the function does nothing when the condition is not met.
LGTM for now to keep it consistent since line 371 does the same as yours.


return cfg


Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,11 @@ def get_callbacks(self) -> List[TrainerCallback]:
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)

if self.cfg.vessl_credential_path:
from axolotl.utils.callbacks.vessl_ import VesslLogMetricsCallback

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you tell me why you added a suffix _ to vessl_?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They added underscore too on mlflow (mlflow_) and wandb (wandb_) package, so I think it's the convention for external integration in this repo.


callbacks.append(VesslLogMetricsCallback(self.cfg.vessl_credential_path))

return callbacks

@abstractmethod
Expand Down
27 changes: 27 additions & 0 deletions src/axolotl/utils/callbacks/vessl_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Vessl module for trainer callbacks"""
import logging
from typing import Dict

import vessl
from transformers import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

LOG = logging.getLogger("axolotl.callbacks")


class VesslLogMetricsCallback(TrainerCallback):
"""Callback to send training metrics to VESSL AI"""

def __init__(self, credential_path: str) -> None:
vessl.configure(credentials_file=credential_path)

def on_log(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
logs: Dict[str, float],
**kwargs # pylint: disable=unused-argument
):
if state.is_world_process_zero:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain where you copied this code from, please?
I am wondering why it is different to the following
https://github.com/vessl-ai/examples/blob/ebeae1c430509d619c380c56923c645cbd02f610/llama-factory/src/llmtuner/extras/callbacks.py#L162-L187

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vessl.log(logs, state.global_step)
7 changes: 7 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,12 @@ def check_wandb_run(cls, data):
return data


class VesslConfig(BaseModel):
"""Vessl AI configuration subset"""

vessl_credential_path: Optional[str] = None


# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
Expand All @@ -404,6 +410,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
VesslConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
Expand Down
13 changes: 13 additions & 0 deletions src/axolotl/utils/vessl_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Module for vessl utilities"""

import os

from axolotl.utils.dict import DictDefault


def setup_vessl_env_vars(cfg: DictDefault):
# VESSL_RUN_INITIAL_CONFIG is a variable that contain path to

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot find any references explaining this variable. Can you attach a document pointing this variable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it from container environment variables, will try to attach a screenshot

# default credential inside a VESSL Run
credential_path = os.environ.get("VESSL_RUN_INITIAL_CONFIG")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not override cfg.vessl_credential_path if it is already set

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if cfg.vessl_credential_path:
    return
credential_path = os.environ.get("VESSL_RUN_INITIAL_CONFIG")
if credential_path:
    cfg.vessl_credential_path = credential_path

if credential_path:
cfg.vessl_credential_path = credential_path
Loading