Skip to content

Commit

Permalink
Merge pull request #1 from michelle123lam/poonam
Browse files Browse the repository at this point in the history
seed changes to lloom
  • Loading branch information
michelle123lam committed May 2, 2024
2 parents 24a7f5b + bea84b9 commit 1f74f45
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 10 deletions.
6 changes: 4 additions & 2 deletions text_lloom/src/text_lloom/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

# CONCEPT class ================================
class Concept:
def __init__(self, name, prompt, example_ids, active, summary=None):
def __init__(self, name, prompt, example_ids, active, summary=None, seed=None):
concept_id = str(uuid.uuid4())
self.id = concept_id
self.name = name
self.prompt = prompt
self.example_ids = example_ids
self.active = active
self.summary = summary
self.seed = seed

def to_dict(self):
return {
Expand All @@ -22,5 +23,6 @@ def to_dict(self):
"prompt": self.prompt,
"example_ids": list(self.example_ids),
"active": self.active,
"summary": self.summary
"summary": self.summary,
"seed": self.seed
}
37 changes: 29 additions & 8 deletions text_lloom/src/text_lloom/concept_induction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# CONSTANTS ================================
NAN_SCORE = 0 # Numerical score to use in place of NaN values for matrix viz
OUTLIER_CRITERIA = "Did the example not match any of the above concepts?"
SCORE_DF_OUT_COLS = ["doc_id", "text", "concept_id", "concept_name", "concept_prompt", "score", "rationale", "highlight"]
SCORE_DF_OUT_COLS = ["doc_id", "text", "concept_id", "concept_name", "concept_prompt", "score", "rationale", "highlight", "concept_seed"]


# HELPER functions ================================
Expand Down Expand Up @@ -326,6 +326,8 @@ def get_n_concepts_phrase(cur_set):
seed_phrase = f"If possible, please make the patterns RELATED TO {seed.upper()}."
else:
seed_phrase = ""

seed_label = seed
arg_dicts = []
cluster_ids = cluster_df[cluster_id_col].unique()
cluster_dfs = {} # Store each cluster's dataframe by cluster_id
Expand Down Expand Up @@ -381,14 +383,15 @@ def get_n_concepts_phrase(cur_set):
prompt=concept_dict["prompt"],
example_ids=ex_ids,
active=False,
seed=seed_label
)
concepts[concept.id] = concept

for ex_id in ex_ids:
# doc_id, text, concept_id, concept_name, concept_prompt
cur_key = (ex_id, cur_cluster_id)
if cur_key in ex_id_to_ex:
row = [ex_id, ex_id_to_ex[cur_key], concept.id, concept.name, concept.prompt]
row = [ex_id, ex_id_to_ex[cur_key], concept.id, concept.name, concept.prompt, concept.seed]
rows.append(row)
# Print intermediate results
examples = cluster_dfs[cur_cluster_id][doc_col].tolist()
Expand All @@ -398,7 +401,7 @@ def get_n_concepts_phrase(cur_set):
if verbose:
print(cur_log)
# doc_id, text, concept_id, concept_name, concept_prompt
concept_df = pd.DataFrame(rows, columns=[doc_id_col, doc_col, concept_col_prefix, f"{concept_col_prefix}_name", f"{concept_col_prefix}_prompt"])
concept_df = pd.DataFrame(rows, columns=[doc_id_col, doc_col, concept_col_prefix, f"{concept_col_prefix}_name", f"{concept_col_prefix}_prompt", "seed"])

concept_df[f"{concept_col_prefix}_namePrompt"] = concept_df[f"{concept_col_prefix}_name"] + ": " + concept_df[f"{concept_col_prefix}_prompt"]
if dedupe:
Expand Down Expand Up @@ -663,6 +666,7 @@ def get_score_df(res, in_df, concept, concept_id, text_col, doc_id_col, get_high
res_dict = json_load(res, top_level_key="pattern_results")
concept_name = concept.name
concept_prompt = concept.prompt
concept_seed = concept.seed
if res_dict is not None:
rows = []
for ex in res_dict:
Expand All @@ -684,9 +688,9 @@ def get_score_df(res, in_df, concept, concept_id, text_col, doc_id_col, get_high
rationale = "" # Set rationale to empty string

if get_highlights and ("quote" in ex):
row = [doc_id, text, concept_id, concept_name, concept_prompt, ans, rationale, ex["quote"]]
row = [doc_id, text, concept_id, concept_name, concept_prompt, ans, rationale, ex["quote"], concept_seed]
else:
row = [doc_id, text, concept_id, concept_name, concept_prompt, ans, rationale, ""] # Set quote to empty string
row = [doc_id, text, concept_id, concept_name, concept_prompt, ans, rationale, "", concept_seed] # Set quote to empty string
rows.append(row)

out_df = pd.DataFrame(rows, columns=SCORE_DF_OUT_COLS)
Expand All @@ -699,6 +703,7 @@ def get_empty_score_df(in_df, concept, concept_id, text_col, doc_id_col):
# Cols: doc_id, text, concept_id, concept_name, concept_prompt, score, highlight
concept_name = concept.name
concept_prompt = concept.prompt
concept_seed = concept.seed
out_df = in_df.copy()
out_df["doc_id"] = out_df[doc_id_col]
out_df["text"] = out_df[text_col]
Expand All @@ -708,6 +713,7 @@ def get_empty_score_df(in_df, concept, concept_id, text_col, doc_id_col):
out_df["score"] = NAN_SCORE
out_df["rationale"] = ""
out_df["highlight"] = ""
out_df["concept_seed"] = concept.seed
return out_df[SCORE_DF_OUT_COLS]

# Performs scoring for one concept
Expand Down Expand Up @@ -886,7 +892,7 @@ def get_covered_by_generic(score_df, doc_id_col, threshold=1.0, generic_threshol
# Determines generic concepts
df = score_df.copy()
df["score"] = df["score"].apply(lambda x: 1 if x >= threshold else 0)
df_generic = df.groupby("concept_id").mean().reset_index()
df_generic = df.groupby("concept_id")['score'].mean().reset_index()
df_generic.rename(columns={"score": "pos_frac"}, inplace=True)
df_generic = df_generic[df_generic["pos_frac"] >= generic_threshold]
generic_concepts = df_generic["concept_id"].unique().tolist()
Expand Down Expand Up @@ -915,7 +921,6 @@ def loop(score_df, doc_col, doc_id_col, debug=False):
# TODO: Allow users to decide on custom filtering conditions
# TODO: Save generic concepts to session (to avoid later)
n_initial = len(score_df[doc_id_col].unique())

underrep_ids = get_not_covered(score_df, doc_id_col)
generic_ids = get_covered_by_generic(score_df, doc_id_col)
ids_to_include = underrep_ids + generic_ids
Expand All @@ -932,7 +937,6 @@ def loop(score_df, doc_col, doc_id_col, debug=False):
return None
return text_df


def trace():
# Input: concept_df (columns: doc_id, text, concept_id, ...), text_dfs (columns: doc_id, text)
# Output: trace_df (columns: doc_id, text, concept_id, score, text1, text2)
Expand Down Expand Up @@ -1442,3 +1446,20 @@ def edit_concept(concepts, concept_id, new_name=None, new_prompt=None, new_ex_id

# TODO: handle concept_df
return concepts

async def check_concept_seed(concepts, seed, model_name="gpt-3.5-turbo"):
    """Ask the LLM which of the given concepts do NOT match the seed theme.

    Args:
        concepts: dict of concept names -> inclusion criteria (stringified
            and interpolated into the prompt).
        seed: the seed theme to check the concepts against.
        model_name: OpenAI model to query.

    Returns:
        The list parsed from the "remove" key of the model's JSON response,
        i.e. the concept names the model judged unrelated to the seed.
    """
    query_args = [{
        "concepts": str(concepts),
        "seed": str(seed),
    }]

    # Single query over the full concept dict using the matching prompt
    res_text, _ = await multi_query_gpt_wrapper(match_concept_prompt, query_args, model_name)
    return json_load(res_text[0], top_level_key="remove")
19 changes: 19 additions & 0 deletions text_lloom/src/text_lloom/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,22 @@
]
}}
"""

# Match Concept ========================
# Prompt used by check_concept_seed() to ask the model which generated
# concepts are unrelated to the user-provided seed theme.
# Format args: {concepts} = dict of concept name -> inclusion criteria,
# {seed} = the seed theme. The model must answer with a JSON object whose
# "remove" key lists the non-matching concept names (parsed downstream via
# json_load(..., top_level_key="remove")).
match_concept_prompt = """
I have this dict of CONCEPTS (keys) and their corresponding inclusion criteria (values), as follows:
{concepts}
I have the following theme:
{seed}
Please identify any CONCEPTS that do not match the THEME. If there are no such concepts, please leave the list empty.
Please respond ONLY with a valid JSON in the following format:
{{
"remove": [
"<CONCEPT_NAME_5>",
"<CONCEPT_NAME_6>"
]
}}
"""

0 comments on commit 1f74f45

Please sign in to comment.