Skip to content

Commit

Permalink
fix: Prevent component pre-init hook from being called recursively (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeMe committed Jun 21, 2024
1 parent d80e014 commit d1f8c0d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
47 changes: 34 additions & 13 deletions haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable

Expand All @@ -87,8 +88,24 @@
logger = logging.getLogger(__name__)


# Callback inputs: component class (Type) and init parameters (as keyword arguments) (Dict[str, Any]).
_COMPONENT_PRE_INIT_CALLBACK: ContextVar[Optional[Callable]] = ContextVar("component_pre_init_callback", default=None)
@dataclass
class PreInitHookPayload:
    """
    State attached to the hook that runs before a component instance is initialized.

    :param callback:
        Invoked with the component class and the init parameters (as keyword arguments).
    :param in_progress:
        True while the hook is executing. It is checked so that the hook is not
        re-entered when a component's constructor creates another component.
    """

    callback: Callable
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
Expand All @@ -102,11 +119,11 @@ def _hook_component_init(callback: Callable):
:param callback:
Callback function to invoke.
"""
token = _COMPONENT_PRE_INIT_CALLBACK.set(callback)
token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
try:
yield
finally:
_COMPONENT_PRE_INIT_CALLBACK.reset(token)
_COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
Expand Down Expand Up @@ -172,17 +189,21 @@ def __call__(cls, *args, **kwargs):
This method is called when clients instantiate a Component and runs before __new__ and __init__.
"""
# This will call __new__ then __init__, giving us back the Component instance
pre_init_hook = _COMPONENT_PRE_INIT_CALLBACK.get()
if pre_init_hook is None:
pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
if pre_init_hook is None or pre_init_hook.in_progress:
instance = super().__call__(*args, **kwargs)
else:
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
kwargs.update(named_positional_args)
pre_init_hook(cls, kwargs)
instance = super().__call__(**kwargs)
try:
pre_init_hook.in_progress = True
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
kwargs.update(named_positional_args)
pre_init_hook.callback(cls, kwargs)
instance = super().__call__(**kwargs)
finally:
pre_init_hook.in_progress = False

# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
Expand Down
24 changes: 24 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def run(self, input_: str):
return {"value": input_}


@component
class FakeComponentSquared:
    """Test component whose constructor instantiates another component."""

    def __init__(self, an_init_param: Optional[str] = None):
        """Store the optional init parameter and build the nested ``FakeComponent``."""
        self.an_init_param = an_init_param
        self.inner = FakeComponent()

    @component.output_types(value=str)
    def run(self, input_: str):
        """Return the input unchanged under the ``value`` key."""
        outputs = {"value": input_}
        return outputs


class TestPipeline:
"""
This class contains only unit tests for the Pipeline class.
Expand Down Expand Up @@ -414,6 +425,19 @@ def component_pre_init_callback_modify(name, component_cls, init_params):
assert greet.message == "modified test"
assert greet.log_level == "DEBUG"

# Test with a component that internally instantiates another component
def component_pre_init_callback_check_class(name, component_cls, init_params):
assert name == "fake_component_squared"
assert component_cls == FakeComponentSquared

pipe = Pipeline()
pipe.add_component("fake_component_squared", FakeComponentSquared())
pipe = Pipeline.from_dict(
pipe.to_dict(),
callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_check_class),
)
assert type(pipe.graph.nodes["fake_component_squared"]["instance"].inner) == FakeComponent

# UNIT
def test_from_dict_with_empty_dict(self):
assert Pipeline() == Pipeline.from_dict({})
Expand Down

0 comments on commit d1f8c0d

Please sign in to comment.