Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent component pre-init hook from being called "recursively" #7894

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable

Expand All @@ -87,8 +88,24 @@
logger = logging.getLogger(__name__)


# Callback inputs: component class (Type) and init parameters (as keyword arguments) (Dict[str, Any]).
_COMPONENT_PRE_INIT_CALLBACK: ContextVar[Optional[Callable]] = ContextVar("component_pre_init_callback", default=None)
@dataclass
class PreInitHookPayload:
    """
    State attached to the hook that runs before a component instance is initialized.

    :param callback:
        Callable invoked with the component class and the keyword arguments
        destined for its constructor.
    :param in_progress:
        `True` while the hook is being executed. It is checked so that a component
        instantiated from inside another component's constructor does not
        trigger the hook a second time.
    """

    # Invoked as `callback(component_cls, init_kwargs)`.
    callback: Callable
    # Raised for the duration of a hooked instantiation; guards against re-entry.
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
Expand All @@ -102,11 +119,11 @@ def _hook_component_init(callback: Callable):
:param callback:
Callback function to invoke.
"""
token = _COMPONENT_PRE_INIT_CALLBACK.set(callback)
token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
try:
yield
finally:
_COMPONENT_PRE_INIT_CALLBACK.reset(token)
_COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
Expand Down Expand Up @@ -172,17 +189,21 @@ def __call__(cls, *args, **kwargs):
This method is called when clients instantiate a Component and runs before __new__ and __init__.
"""
# This will call __new__ then __init__, giving us back the Component instance
pre_init_hook = _COMPONENT_PRE_INIT_CALLBACK.get()
if pre_init_hook is None:
pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
if pre_init_hook is None or pre_init_hook.in_progress:
instance = super().__call__(*args, **kwargs)
else:
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
kwargs.update(named_positional_args)
pre_init_hook(cls, kwargs)
instance = super().__call__(**kwargs)
try:
pre_init_hook.in_progress = True
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
kwargs.update(named_positional_args)
pre_init_hook.callback(cls, kwargs)
instance = super().__call__(**kwargs)
finally:
pre_init_hook.in_progress = False

# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
Expand Down
24 changes: 24 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def run(self, input_: str):
return {"value": input_}


@component
class FakeComponentSquared:
    """Test component whose constructor instantiates a nested `FakeComponent`."""

    def __init__(self, an_init_param: Optional[str] = None):
        """Store the init parameter and create the inner component."""
        self.an_init_param = an_init_param
        # Building another component here exercises nested instantiation.
        self.inner = FakeComponent()

    @component.output_types(value=str)
    def run(self, input_: str):
        """Return the input unchanged under the `value` key."""
        output = {"value": input_}
        return output


class TestPipeline:
"""
This class contains only unit tests for the Pipeline class.
Expand Down Expand Up @@ -414,6 +425,19 @@ def component_pre_init_callback_modify(name, component_cls, init_params):
assert greet.message == "modified test"
assert greet.log_level == "DEBUG"

# Test with a component that internally instantiates another component
def component_pre_init_callback_check_class(name, component_cls, init_params):
assert name == "fake_component_squared"
assert component_cls == FakeComponentSquared

pipe = Pipeline()
pipe.add_component("fake_component_squared", FakeComponentSquared())
pipe = Pipeline.from_dict(
pipe.to_dict(),
callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_check_class),
)
assert type(pipe.graph.nodes["fake_component_squared"]["instance"].inner) == FakeComponent

# UNIT
def test_from_dict_with_empty_dict(self):
assert Pipeline() == Pipeline.from_dict({})
Expand Down
Loading