Customize UserDefinedDagsterK8sConfig per op via tags (dagster-io#23053)
## Summary & Motivation

As stated in dagster-io#22138 , Dagster currently requires that op/job limits in
k8s be defined statically - no dynamic override is possible.
I'd like to resolve this issue for the `celery-k8s` launcher (we use it in
our production).
This PR addresses the problem by introducing a special tag
`"dagster-k8s/config-per-op"` that allows overriding
`user_defined_k8s_config`.
I expect this to work as follows:

```python
from dagster import op, job, schedule, RunRequest


@op
def op1():
    print(1 + 2)


# op defines its own limits
@op(
    tags={
        "dagster-k8s/config": {
            "container_config": {
                "resources": {
                    "requests": {"cpu": "1", "memory": "1Gi"},
                    "limits": {"cpu": "10", "memory": "10Gi"},
                }
            }
        }
    }
)
def op2():
    print(7)


# <------- one way to override op tags
@job(
    tags={
        "dagster-k8s/config-per-op": {
            "op2": {
                "container_config": {
                    # requests will be overwritten, limits will be taken from op definition
                    "resources": {
                        "requests": {"cpu": "2", "memory": "2Gi"},
                    }
                }
            },
            "op1": {
                "container_config": {
                    # op1 gets requests/limits even though it never defined them
                    "resources": {
                        "requests": {"cpu": "11", "memory": "11Gi"},
                        "limits": {"cpu": "22", "memory": "22Gi"},
                    }
                }
            },
        }
    }
)
def main_job():
    op1()
    op2()


# <------- another way to override op tags
@schedule(job=main_job, cron_schedule="* * * * *")
def main_job_schedule():
    return RunRequest(
        tags={
            "dagster-k8s/config-per-op": {
                "op2": {
                    "container_config": {
# requests will be overwritten at the schedule level, limits will be taken from the op definition
                        "resources": {
                            "requests": {"cpu": "8", "memory": "8Gi"},
                        }
                    }
                },
                "op1": {
                    "container_config": {
                        # op1 gets requests/limits from schedule definition
                        "resources": {
                            "requests": {"cpu": "88", "memory": "88Gi"},
                            "limits": {"cpu": "89", "memory": "89Gi"},
                        }
                    }
                },
            }
        }
    )
```

I could use this feature for jobs that vary in resource consumption (for
instance, jobs that accept a SQL query and download data from a database).
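
For instance, a schedule could size the per-op override at run time before
kicking off the run. A minimal sketch, assuming the tag semantics above; the
`estimate_rows` helper and the thresholds are made up for illustration:

```python
from dagster import RunRequest, schedule


def estimate_rows() -> int:
    # Hypothetical: in practice this might come from table statistics.
    return 20_000_000


def resources_for(row_estimate: int) -> dict:
    # Hypothetical sizing policy: bigger inputs get a bigger pod.
    if row_estimate > 10_000_000:
        return {
            "requests": {"cpu": "4", "memory": "16Gi"},
            "limits": {"cpu": "8", "memory": "32Gi"},
        }
    return {
        "requests": {"cpu": "1", "memory": "2Gi"},
        "limits": {"cpu": "2", "memory": "4Gi"},
    }


@schedule(job=main_job, cron_schedule="0 * * * *")
def sized_main_job_schedule():
    return RunRequest(
        tags={
            "dagster-k8s/config-per-op": {
                "op1": {"container_config": {"resources": resources_for(estimate_rows())}}
            }
        }
    )
```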

Could you please take a look at the code and check it for broken corner
cases or violated contracts?
This is my first dive into the Dagster codebase, so I'd appreciate any
comments and suggestions.

## How I Tested These Changes
I modified some unit tests in the k8s-related modules of Dagster. No
integration testing was done.

---------

Co-authored-by: gibsondan <[email protected]>
alekseik1 and gibsondan committed Jul 24, 2024
1 parent 8dea2ef commit b754b20
Showing 4 changed files with 190 additions and 1 deletion.
File 1 of 4: celery-k8s executor config schema
@@ -1,9 +1,10 @@
from dagster import Field, Float, Noneable, StringSource
from dagster import Field, Float, Map, Noneable, StringSource
from dagster._core.remote_representation import IN_PROCESS_NAME
from dagster._utils.merger import merge_dicts
from dagster_celery.executor import CELERY_CONFIG
from dagster_k8s import DagsterK8sJobConfig
from dagster_k8s.client import DEFAULT_WAIT_TIMEOUT
from dagster_k8s.job import USER_DEFINED_K8S_CONFIG_SCHEMA

CELERY_K8S_CONFIG_KEY = "celery-k8s"

@@ -52,6 +53,12 @@ def celery_k8s_executor_config():
                f" Defaults to {DEFAULT_WAIT_TIMEOUT} seconds."
            ),
        ),
        "per_step_k8s_config": Field(
            Map(str, USER_DEFINED_K8S_CONFIG_SCHEMA, key_label_name="step_name"),
            is_required=False,
            default_value={},
            description="Per op k8s configuration overrides.",
        ),
    }

    cfg = merge_dicts(CELERY_CONFIG, job_config)
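
With this schema field in place, per-step overrides can also be supplied
through the executor's run config. A minimal sketch of the shape the `Map`
above validates; the op name and resource values are illustrative:

```python
run_config = {
    "execution": {
        "config": {
            "per_step_k8s_config": {
                "op1": {
                    "container_config": {
                        "resources": {"requests": {"cpu": "500m", "memory": "512Mi"}}
                    }
                }
            }
        }
    }
}
```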
File 2 of 4: celery-k8s executor
@@ -35,6 +35,7 @@
    DagsterK8sUnrecoverableAPIError,
    DagsterKubernetesClient,
)
from dagster_k8s.container_context import K8sContainerContext
from dagster_k8s.job import (
    UserDefinedDagsterK8sConfig,
    get_k8s_job_name,
@@ -128,6 +129,7 @@ def celery_k8s_job_executor(init_context):
        kubeconfig_file=exc_cfg.get("kubeconfig_file"),
        repo_location_name=exc_cfg.get("repo_location_name"),
        job_wait_timeout=exc_cfg.get("job_wait_timeout"),
        per_step_k8s_config=exc_cfg.get("per_step_k8s_config", {}),
    )


@@ -145,6 +147,7 @@ def __init__(
        kubeconfig_file=None,
        repo_location_name=None,
        job_wait_timeout=None,
        per_step_k8s_config=None,
    ):
        if load_incluster_config:
            check.invariant(
@@ -158,6 +161,9 @@
        self.broker = check.opt_str_param(broker, "broker", default=broker_url)
        self.backend = check.opt_str_param(backend, "backend", default=result_backend)
        self.include = check.opt_list_param(include, "include", of_type=str)
        self.per_step_k8s_config = check.opt_dict_param(
            per_step_k8s_config, "per_step_k8s_config", key_type=str, value_type=dict
        )
        self.config_source = dict_wrapper(
            dict(DEFAULT_CONFIG, **check.opt_dict_param(config_source, "config_source"))
        )
@@ -222,6 +228,7 @@ def _submit_task_k8s_job(app, plan_context, step, queue, priority, known_state):
        job_config_dict=job_config.to_dict(),
        job_namespace=plan_context.executor.job_namespace,
        user_defined_k8s_config_dict=user_defined_k8s_config.to_dict(),
        per_step_k8s_config=plan_context.executor.per_step_k8s_config,
        load_incluster_config=plan_context.executor.load_incluster_config,
        job_wait_timeout=plan_context.executor.job_wait_timeout,
        kubeconfig_file=plan_context.executor.kubeconfig_file,
@@ -267,6 +274,7 @@ def _execute_step_k8s_job(
        job_namespace,
        load_incluster_config,
        job_wait_timeout,
        per_step_k8s_config=None,
        user_defined_k8s_config_dict=None,
        kubeconfig_file=None,
    ):
@@ -368,6 +376,17 @@
        labels["dagster/code-location"] = (
            dagster_run.external_job_origin.repository_origin.code_location_origin.location_name
        )
        per_op_override = per_step_k8s_config.get(step_key, {})

        tag_container_context = K8sContainerContext(run_k8s_config=user_defined_k8s_config)
        executor_config_container_context = K8sContainerContext(
            run_k8s_config=UserDefinedDagsterK8sConfig.from_dict(per_op_override)
        )
        merged_user_defined_k8s_config = tag_container_context.merge(
            executor_config_container_context
        ).run_k8s_config
        user_defined_k8s_config = merged_user_defined_k8s_config

        job = construct_dagster_k8s_job(
            job_config,
            args,
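
To illustrate the merge in the hunk above: the executor's per-step override
wins for the keys it sets, while keys it leaves unset survive from the op's
"dagster-k8s/config" tag. A minimal sketch using the classes imported in this
file; resource numbers are arbitrary, and the expected result follows the
precedence described in the Summary:

```python
from dagster_k8s.container_context import K8sContainerContext
from dagster_k8s.job import UserDefinedDagsterK8sConfig

op_tag_config = UserDefinedDagsterK8sConfig.from_dict(
    {"container_config": {"resources": {"requests": {"cpu": "1"}, "limits": {"cpu": "10"}}}}
)
per_step_override = UserDefinedDagsterK8sConfig.from_dict(
    {"container_config": {"resources": {"requests": {"cpu": "2"}}}}
)

merged = (
    K8sContainerContext(run_k8s_config=op_tag_config)
    .merge(K8sContainerContext(run_k8s_config=per_step_override))
    .run_k8s_config
)

# Expected, per the Summary above: requests come from the override,
# limits survive from the op tag:
# {"requests": {"cpu": "2"}, "limits": {"cpu": "10"}}
print(merged.container_config["resources"])
```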
File 3 of 4: new test for per-step k8s config overrides
@@ -0,0 +1,156 @@
import subprocess
from unittest import mock

from dagster import (
    _check as check,
    job,
    op,
)
from dagster._core.definitions.reconstruct import ReconstructableJob
from dagster._core.events.utils import filter_dagster_events_from_cli_logs
from dagster._core.execution.api import execute_job
from dagster._core.test_utils import instance_for_test
from dagster._serdes import serialize_value, unpack_value
from dagster_celery_k8s.executor import celery_k8s_job_executor


@op(
    tags={
        "dagster-k8s/config": {
            "container_config": {
                "resources": {
                    "requests": {"cpu": "444m", "memory": "444Mi"},
                    "limits": {"cpu": "444m", "memory": "444Mi"},
                }
            }
        }
    }
)
def op1():
    return


@job(executor_def=celery_k8s_job_executor)
def some_job():
    op1()


def celery_mock():
    # Wrap celery in a single-process queue
    celery_mock = mock.MagicMock()

    class SimpleQueueWrapper:
        queue = []

        def __init__(self, f):
            self.f = f
            self.request = mock.MagicMock()
            # self.request.hostname inside celery task definition
            self.request.hostname = "test-celery-worker-name"

        def si(self, **kwargs):
            self.queue.append(kwargs)
            return self

        def apply_async(self, **kwargs):
            from celery.result import AsyncResult

            serialized_events = []
            for task_kwargs in self.queue:
                # Run the mocked task
                self.f(self, **task_kwargs)

                # Run the step locally that would have run in the k8s job and return the serialized events
                execute_step_args_packed = task_kwargs["execute_step_args_packed"]
                execute_step_args = unpack_value(
                    check.dict_param(
                        execute_step_args_packed,
                        "execute_step_args_packed",
                    )
                )
                args = execute_step_args.get_command_args()
                result = subprocess.run(args, check=True, capture_output=True)
                raw_logs = result.stdout

                logs = raw_logs.decode("utf-8").split("\n")
                events = filter_dagster_events_from_cli_logs(logs)
                serialized_events += [serialize_value(event) for event in events]

            # apply_async must return an AsyncResult
            rv = AsyncResult(id="123", task_name="execute_step_k8s_job", backend=celery_mock)
            rv.ready = lambda: True
            rv.get = lambda: serialized_events
            return rv

    celery_mock.return_value.task.return_value = lambda f: SimpleQueueWrapper(f)
    return celery_mock


def test_per_step_k8s_config(kubeconfig_file):
    """We expect the following precedence order:
    1. celery_k8s_job_executor config is most important; it precedes everything else. If specified,
       `run_config` from a RunRequest is ignored. Within the executor config itself, the order is:
       1) config on the job definition (via executor_def=...)
       2) config on Definitions (via executor=...)
    2. Next comes run_config from the RunRequest of a schedule. If celery_k8s_job_executor is
       configured (via Definitions or on the job), the RunRequest config is ignored.
    3. Last comes the "dagster-k8s/config" tag on the op.
    This test only checks the executor and op overrides.
    """
    mock_k8s_client_batch_api = mock.MagicMock()

    default_config = dict(
        instance_config_map="dagster-instance",
        postgres_password_secret="dagster-postgresql-secret1",
        dagster_home="/opt/dagster/dagster_home",
        load_incluster_config=False,
        kubeconfig_file=kubeconfig_file,
    )

    with instance_for_test(
        overrides={
            "run_launcher": {
                "module": "dagster_celery_k8s",
                "class": "CeleryK8sRunLauncher",
                "config": default_config,
            }
        }
    ) as instance:
        run_config = {
            "execution": {
                "config": {
                    "job_image": "some-job-image:tag",
                    "per_step_k8s_config": {
                        "op1": {
                            "container_config": {
                                "resources": {
                                    "requests": {"cpu": "111m", "memory": "111Mi"},
                                    "limits": {"cpu": "222m", "memory": "222Mi"},
                                }
                            }
                        }
                    },
                    "load_incluster_config": False,
                    "kubeconfig_file": kubeconfig_file,
                }
            }
        }

        with mock.patch("dagster_celery.core_execution_loop.make_app", celery_mock()), mock.patch(
            "dagster_celery_k8s.executor.DagsterKubernetesClient.production_client",
            mock_k8s_client_batch_api,
        ):
            result = execute_job(
                ReconstructableJob.for_file(__file__, "some_job"),
                run_config=run_config,
                instance=instance,
            )
            assert result.success

            expected_mock = mock_k8s_client_batch_api().batch_api.create_namespaced_job
            assert expected_mock.called
            created_containers = expected_mock.call_args.kwargs["body"].spec.template.spec.containers
            assert len(created_containers) == 1
            container = created_containers[0]
            assert container.resources.limits == {"cpu": "222m", "memory": "222Mi"}
            assert container.resources.requests == {"cpu": "111m", "memory": "111Mi"}
File 4 of 4: celery-k8s executor config tests
@@ -38,6 +38,7 @@ def test_empty_celery_config():
        "volume_mounts": [],
        "volumes": [],
        "load_incluster_config": True,
        "per_step_k8s_config": {},
        "repo_location_name": "<<in_process>>",
        "job_wait_timeout": DEFAULT_WAIT_TIMEOUT,
    }
@@ -53,6 +54,7 @@ def test_get_validated_celery_k8s_executor_config():
        "retries": {"enabled": {}},
        "job_image": "foo",
        "load_incluster_config": True,
        "per_step_k8s_config": {},
        "repo_location_name": "<<in_process>>",
        "job_wait_timeout": DEFAULT_WAIT_TIMEOUT,
        "volume_mounts": [],
@@ -79,6 +81,7 @@
        "backend": "rpc://",
        "retries": {"enabled": {}},
        "env_config_maps": ["config-pipeline-env"],
        "per_step_k8s_config": {},
        "load_incluster_config": True,
        "job_namespace": "my-namespace",
        "repo_location_name": "<<in_process>>",
@@ -143,6 +146,7 @@ def test_get_validated_celery_k8s_executor_config():
        "config_source": {"task_annotations": """{'*': {'on_failure': my_on_failure}}"""},
        "retries": {"disabled": {}},
        "job_image": "foo",
        "per_step_k8s_config": {},
        "image_pull_policy": "IfNotPresent",
        "image_pull_secrets": [{"name": "super-secret-1"}, {"name": "super-secret-2"}],
        "service_account_name": "my-cool-service-acccount",
@@ -162,6 +166,7 @@ def test_get_validated_celery_k8s_executor_config_for_job():
        "retries": {"enabled": {}},
        "job_image": "foo",
        "load_incluster_config": True,
        "per_step_k8s_config": {},
        "repo_location_name": "<<in_process>>",
        "job_wait_timeout": DEFAULT_WAIT_TIMEOUT,
        "volume_mounts": [],
@@ -187,6 +192,7 @@ def test_get_validated_celery_k8s_executor_config_for_job():
        "retries": {"enabled": {}},
        "env_config_maps": ["config-pipeline-env"],
        "load_incluster_config": True,
        "per_step_k8s_config": {},
        "job_namespace": "my-namespace",
        "repo_location_name": "<<in_process>>",
        "job_wait_timeout": DEFAULT_WAIT_TIMEOUT,
@@ -253,6 +259,7 @@ def test_get_validated_celery_k8s_executor_config_for_job():
        "config_source": {"task_annotations": """{'*': {'on_failure': my_on_failure}}"""},
        "retries": {"disabled": {}},
        "job_image": "foo",
        "per_step_k8s_config": {},
        "image_pull_policy": "IfNotPresent",
        "image_pull_secrets": [{"name": "super-secret-1"}, {"name": "super-secret-2"}],
        "service_account_name": "my-cool-service-acccount",
