forked from microsoft/autogen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_utils.py
136 lines (108 loc) · 5.93 KB
/
graph_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from typing import Dict, List
import logging
from autogen.agentchat.groupchat import Agent
def has_self_loops(allowed_speaker_transitions: Dict) -> bool:
    """Return True if any key appears in its own collection of transitions.

    Args:
        allowed_speaker_transitions: Mapping from a speaker to the collection
            of speakers it is allowed to transition to.

    Returns:
        True when at least one key is contained in its own value (a self
        loop); False otherwise, including for an empty mapping.
    """
    # A generator expression lets any() stop at the first self loop instead of
    # first materializing a list with one boolean per key.
    return any(key in value for key, value in allowed_speaker_transitions.items())
def check_graph_validity(
    allowed_speaker_transitions_dict: Dict,
    agents: List[Agent],
):
    """Validate a speaker-transition graph against the list of known agents.

    Args:
        allowed_speaker_transitions_dict: Mapping from an Agent to the list of
            Agents it may transition to. Both the keys and the list items are
            Agent objects, not agent names.
        agents: The Agents the graph is checked against.

    Raises:
        ValueError: If any of the following hold:
            1. `allowed_speaker_transitions_dict` is not a dict, or one of its
               values is not a list.
            2. A key is not present in `agents`.
            3. A value contains an item that is not an Agent (e.g. a string).

    Warnings (logged through `logging.warning`, never raised):
        1. There are isolated agents, i.e. agents with neither incoming nor
           outgoing edges.
        2. The set of agents that appear in the graph does not match `agents`.
        3. A value of `allowed_speaker_transitions_dict` lists the same agent
           more than once.
    """
    # --- Errors ---

    # Check 1. The dictionary must have a structure of keys and lists as values.
    if not isinstance(allowed_speaker_transitions_dict, dict):
        raise ValueError("allowed_speaker_transitions_dict must be a dictionary.")

    if not all(isinstance(value, list) for value in allowed_speaker_transitions_dict.values()):
        raise ValueError("allowed_speaker_transitions_dict must be a dictionary with lists as values.")

    # Check 2. Every key exists in agents.
    if not all(key in agents for key in allowed_speaker_transitions_dict):
        raise ValueError("allowed_speaker_transitions_dict has keys not in agents.")

    # Check 3. Every value is a list of Agents or an empty list (not strings).
    if not all(
        isinstance(agent, Agent)
        for agent_list in allowed_speaker_transitions_dict.values()
        for agent in agent_list
    ):
        raise ValueError("allowed_speaker_transitions_dict has values that are not lists of Agents.")

    # --- Warnings ---
    # Sets are built only after validation so that malformed input is still
    # reported through the ValueErrors above. Set membership replaces the
    # repeated linear scans over lists.
    agent_set = set(agents)

    # Keys whose transition list is non-empty have at least one outgoing edge.
    has_outgoing_edge = {
        key for key, agent_list in allowed_speaker_transitions_dict.items() if agent_list
    }
    # Every agent that appears in any transition list has an incoming edge.
    has_incoming_edge = {
        agent for agent_list in allowed_speaker_transitions_dict.values() for agent in agent_list
    }

    # Warning 1. Isolated agents have neither incoming nor outgoing edges.
    isolated_agents = agent_set - has_outgoing_edge - has_incoming_edge
    if isolated_agents:
        logging.warning(
            f"""Warning: There are isolated agent nodes, there are not incoming nor outgoing edges. Isolated agents: {[agent.name for agent in isolated_agents]}"""
        )

    # Warning 2. The agents used in the graph should match `agents` exactly;
    # the symmetric difference holds agents present on only one side.
    agents_in_allowed_speaker_transitions = has_incoming_edge | has_outgoing_edge
    full_anti_join = agents_in_allowed_speaker_transitions ^ agent_set
    if full_anti_join:
        logging.warning(
            f"""Warning: The set of agents in allowed_speaker_transitions do not match agents. Offending agents: {[agent.name for agent in full_anti_join]}"""
        )

    # Warning 3. Duplicated agents inside a single transition list.
    for key, values in allowed_speaker_transitions_dict.items():
        seen = set()
        # A dict is used as an ordered set so duplicates are reported once, in
        # the order in which they are first repeated.
        repeated = {}
        for item in values:
            if item in seen:
                repeated[item] = None
            else:
                seen.add(item)
        unique_duplicates = list(repeated)
        if unique_duplicates:
            logging.warning(
                f"Agent '{key.name}' has duplicate elements: {[agent.name for agent in unique_duplicates]}. Please remove duplicates manually."
            )
def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agents: List[Agent]) -> dict:
    """Build the allowed-transition graph that complements a disallowed one.

    Every agent in `agents` starts out allowed to transition to every agent in
    `agents` (itself included). For each key of
    `disallowed_speaker_transitions_dict`, the targets listed in its value are
    then removed from that key's allowed list.

    Returns:
        A new dict mapping each agent to the list of agents it may still
        transition to.
    """
    # Phase 1: fully connected graph, one independent target list per agent.
    allowed = {}
    for source in agents:
        allowed[source] = list(agents)

    # Phase 2: drop every edge that the disallowed graph names.
    for source, blocked_targets in disallowed_speaker_transitions_dict.items():
        remaining = []
        for target in allowed[source]:
            if target not in blocked_targets:
                remaining.append(target)
        allowed[source] = remaining

    return allowed
def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: List[Agent]):
    """Draw the speaker-transition graph using networkx and matplotlib.

    Args:
        speaker_transitions_dict: Mapping from an agent to the agents it has an
            edge to; each key/target pair becomes a directed edge labelled by
            the agents' `name` attributes.
        agents: Agents to add as nodes (by `name`), so that agents without any
            edge still show up in the drawing.

    Raises:
        ImportError: If networkx or matplotlib cannot be imported. A critical
            log record with an installation hint is emitted first.
    """
    try:
        import networkx as nx
        import matplotlib.pyplot as plt
    except ImportError:
        # logging.critical is the supported spelling of the logging.fatal alias.
        logging.critical("Failed to import networkx or matplotlib. Try running 'pip install autogen[graphs]'")
        # A bare raise re-raises the active exception with its traceback intact.
        raise

    G = nx.DiGraph()

    # Add nodes
    G.add_nodes_from([agent.name for agent in agents])

    # Add edges
    for key, value in speaker_transitions_dict.items():
        for agent in value:
            G.add_edge(key.name, agent.name)

    # Visualize
    nx.draw(G, with_labels=True, font_weight="bold")
    plt.show()