Skip to content

Commit

Permalink
Move utility functions from _enqueue_next_runnable_component
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jun 19, 2024
1 parent 28902c4 commit 1672d3c
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,42 +954,19 @@ def _enqueue_next_runnable_component(
if name not in inputs_by_component:
inputs_by_component[name] = {}

# Small utility function to check if a Component has a Variadic input that is not greedy.
def is_lazy_variadic(c: Component) -> bool:
is_variadic = any(
socket.is_variadic
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
if not is_variadic:
return False
return not getattr(c, "__haystack_is_greedy__", False)

# Small utility function to check if a Component has all inputs with defaults.
def has_all_inputs_with_defaults(c: Component) -> bool:
return all(
not socket.is_mandatory
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)

# Updates the inputs with the default values for the inputs that are missing
def add_missing_input_defaults(name: str, comp: Component, inputs_by_component: Dict[str, Dict[str, Any]]):
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in inputs_by_component[name]:
inputs_by_component[name][input_socket.name] = input_socket.default_value

all_lazy_variadic = True
all_with_default_inputs = True

filtered_waiting_for_input = []

for name, comp in waiting_for_input:
if not is_lazy_variadic(comp):
if not _is_lazy_variadic(comp):
# Components with variadic inputs that are not greedy must be removed only if there's nothing else to
# run at this stage.
# We need to wait as long as possible to run them, so we can collect as most inputs as we can.
all_lazy_variadic = False

if not has_all_inputs_with_defaults(comp):
if not _has_all_inputs_with_defaults(comp):
# Components that have defaults for all their inputs must be treated the same identical way as we treat
# lazy variadic components. If there are only components with defaults we can run them.
# If we don't do this the order of execution of the Pipeline's Components will be affected cause we
Expand All @@ -998,7 +975,7 @@ def add_missing_input_defaults(name: str, comp: Component, inputs_by_component:
# logic A must be executed after B it could run instead before if we don't do this check.
all_with_default_inputs = False

if not is_lazy_variadic(comp) and not has_all_inputs_with_defaults(comp):
if not _is_lazy_variadic(comp) and not _has_all_inputs_with_defaults(comp):
# Keep track of the Components that are not lazy variadic and don't have all inputs with defaults.
# We'll handle these later if necessary.
filtered_waiting_for_input.append((name, comp))
Expand All @@ -1008,7 +985,7 @@ def add_missing_input_defaults(name: str, comp: Component, inputs_by_component:
pair = waiting_for_input.pop(0)
to_run.append(pair)
# Add missing input defaults if needed, this is a no-op for Components with Variadic inputs
add_missing_input_defaults(name, comp, inputs_by_component)
_add_missing_input_defaults(name, comp, inputs_by_component)
return

for name, comp in filtered_waiting_for_input:
Expand Down Expand Up @@ -1053,3 +1030,33 @@ def _connections_status(
receiver_sockets_list = "\n".join(receiver_sockets_entries)

return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"


def _is_lazy_variadic(c: Component) -> bool:
"""
Small utility function to check if a Component has a Variadic input that is not greedy
"""
is_variadic = any(
socket.is_variadic
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
if not is_variadic:
return False
return not getattr(c, "__haystack_is_greedy__", False)


def _has_all_inputs_with_defaults(c: Component) -> bool:
"""
Small utility function to check if a Component has all inputs with defaults.
"""
return all(
not socket.is_mandatory
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)


def _add_missing_input_defaults(name: str, comp: Component, inputs_by_component: Dict[str, Dict[str, Any]]):
"""Updates the inputs with the default values for the inputs that are missing"""
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in inputs_by_component[name]:
inputs_by_component[name][input_socket.name] = input_socket.default_value

0 comments on commit 1672d3c

Please sign in to comment.