Merge pull request #1169 from stanfordnlp/mipro_v2
MIPRO optimizer updates for paper release
XenonMolecule committed Jun 21, 2024
2 parents 01c8de0 + 1ee5479 commit 015c649
Showing 14 changed files with 10,195 additions and 191 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -8,6 +8,13 @@
/docs/downloads/
/docs/experiments/

/examples/qa/hotpot/MIPRO_notebook_cache/
/examples/nli/scone/MIPRO_notebook_cache/
/examples/nli/scone/ScoNe/
/examples/nli/scone/compiled_program.dspy
/examples/qa/hotpot/compiled_program.dspy
/ScoNe/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
88 changes: 35 additions & 53 deletions dsp/primitives/predict.py
@@ -61,6 +61,41 @@ def _generate(template: Template, **kwargs) -> Callable:

    generator = dsp.settings.lm

    def extend_generation(completion: Example, field_names: list[str], stage: str, max_depth: int, original_example: Example):
        """If the required fields are not present in the completion, extend the generation."""
        assert max_depth > 0, "Max depth exceeded - failed to complete in one pass - increase max_tokens"
        # Remove the first missing or empty field and everything after it, to avoid half-completed content.
        for field_name in get_all_fields_following_missing_field(completion, field_names):
            completion.pop(field_name, None)

        # Recurse with greedy decoding and a shorter length.
        max_tokens = (kwargs.get("max_tokens") or
                      kwargs.get("max_output_tokens") or
                      dsp.settings.lm.kwargs.get("max_tokens") or
                      dsp.settings.lm.kwargs.get("max_output_tokens"))

        if max_tokens is None:
            raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
        max_tokens = min(max(75, max_tokens // 2), max_tokens)
        keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
        max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
        new_kwargs = {
            **kwargs,
            max_tokens_key: max_tokens,
            "n": 1,
            "temperature": 0.0,
        }

        _, finished_completion = generate(template, **new_kwargs)(
            completion,
            stage=stage,
            max_depth=max_depth - 1,
            original_example=original_example,
        )
        return finished_completion.data[0]
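
Aside: the recursive retry above and the inline retry in do_generate below shrink the token budget with the same clamp. A minimal standalone sketch of that policy (shrink_budget is our name, for illustration only):

# Illustration only: the max_tokens clamp used by both retry paths.
def shrink_budget(max_tokens: int) -> int:
    # Halve the budget, but keep at least 75 tokens and never exceed the original.
    return min(max(75, max_tokens // 2), max_tokens)

assert shrink_budget(4000) == 2000  # large budgets are halved
assert shrink_budget(100) == 75     # the 75-token floor kicks in
assert shrink_budget(50) == 50      # budgets already under 75 stay unchanged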


    def do_generate(
        example: Example, stage: str, max_depth: int = 2, original_example=None,
    ):
@@ -112,19 +77,45 @@ def do_generate
        completions: list[dict[str, Any]] = generator(prompt, **kwargs)
        completions: list[Example] = [template.extract(example, p) for p in completions]

        # Find the completions that are unfinished.
        # Find the completions that are most complete.
        field_names: list[str] = [field.input_variable for field in template.fields]

        finished_completions = []
        for completion in completions:
            if all((completion.get(key, "") != "") for key in field_names):
                finished_completions.append(completion)
                continue
            finished_completions.append(
                extend_generation(completion, field_names, stage, max_depth, original_example),
            )

        last_field_idx = 0
        for field_idx, key in enumerate(field_names):
            completions_ = [
                c for c in completions if key in c.keys() and c[key] is not None
            ]

            # Filter out completions that are missing fields that are present in at least one completion.
            if len(completions_):
                completions = completions_
                last_field_idx = field_idx + 1

        # If none of the completions is completed (i.e., none has the final field set).
        if last_field_idx < len(field_names):
            # Pick the first completion that has gone farthest.
            completion = completions[0]
            completion[field_names[last_field_idx]] = ""

            # Recurse with greedy decoding and a shorter length.
            max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"])
            max_tokens = min(max(75, max_tokens // 2), max_tokens)
            new_kwargs = {
                **kwargs,
                "max_tokens": max_tokens,
                "n": 1,
                "temperature": 0.0,
            }

            assert max_depth > 0
            return generate(template, **new_kwargs)(
                completion,
                stage=stage,
                max_depth=max_depth - 1,
                original_example=original_example,
            )

        completions = Completions(finished_completions, template=template)
        completions = Completions(completions, template=template)
        example = example.copy(completions=completions)

        if len(completions) == 1:
@@ -161,15 +152,6 @@ def do_generate

    return do_generate

def get_all_fields_following_missing_field(completion: Example, field_names: list[str]) -> list[str]:
    """Returns every field following the first missing or empty field."""
    for i, field_name in enumerate(field_names):
        if field_name not in completion:
            return field_names[i:]
        if completion[field_name] == "":
            return field_names[i:]
    return []
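
For intuition, a quick check of this helper using plain dicts as stand-ins for Example objects (the field names below are made up):

# Illustration only: dicts stand in for dsp Example objects, which support the same lookups.
fields = ["rationale", "answer", "confidence"]
partial = {"rationale": "Let's think step by step.", "answer": ""}
# "answer" is present but empty, so it and every later field are flagged for regeneration.
assert get_all_fields_following_missing_field(partial, fields) == ["answer", "confidence"]
complete = {"rationale": "...", "answer": "42", "confidence": "high"}
assert get_all_fields_following_missing_field(complete, fields) == []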


def generate_sc(
    example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs,
1 change: 1 addition & 0 deletions dspy/propose/__init__.py
@@ -0,0 +1 @@
from .grounded_proposer import GroundedProposer
86 changes: 86 additions & 0 deletions dspy/propose/dataset_summary_generator.py
@@ -0,0 +1,86 @@
import re

import dspy
from dspy.propose.utils import strip_prefix


class ObservationSummarizer(dspy.Signature):
    ("""Given a series of observations I have made about my dataset, please summarize them into a brief 2-3 sentence summary which highlights only the most important details.""")
    observations = dspy.InputField(desc="Observations I have made about my dataset")
    summary = dspy.OutputField(desc="Two to three sentence summary of only the most significant highlights of my observations")

class DatasetDescriptor(dspy.Signature):
    ("""Given several examples from a dataset, please write observations about trends that hold for most or all of the samples. """
     """Some areas you may consider in your observations: topics, content, syntax, conciseness, etc. """
     """It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative.""")

    examples = dspy.InputField(desc="Sample data points from the dataset")
    observations = dspy.OutputField(desc="Some things that hold true for most or all of the data you observed")

class DatasetDescriptorWithPriorObservations(dspy.Signature):
    ("""Given several examples from a dataset, please write observations about trends that hold for most or all of the samples. """
     """I will also provide you with a few observations I have already made. Please add your own observations or, if you feel the observations are comprehensive, say 'COMPLETE'. """
     """Some areas you may consider in your observations: topics, content, syntax, conciseness, etc. """
     """It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative.""")

    examples = dspy.InputField(desc="Sample data points from the dataset")
    prior_observations = dspy.InputField(desc="Some prior observations I made about the data")
    observations = dspy.OutputField(desc="Some things that hold true for most or all of the data you observed, or COMPLETE if you have nothing to add")

def order_input_keys_in_string(unordered_repr):
    # Regex pattern to match the input keys structure
    pattern = r"input_keys=\{([^\}]+)\}"

    # Function to reorder keys
    def reorder_keys(match):
        # Extracting the keys from the match
        keys_str = match.group(1)
        # Splitting the keys, stripping extra spaces, and sorting them
        keys = sorted(key.strip() for key in keys_str.split(','))
        # Formatting the sorted keys back into the expected structure
        return f"input_keys={{{', '.join(keys)}}}"

    # Using re.sub to find all matches of the pattern and replace them using the reorder_keys function
    ordered_repr = re.sub(pattern, reorder_keys, unordered_repr)

    return ordered_repr
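
A small usage sketch (the repr string below is a made-up stand-in for an Example's repr output):

# Illustration only: keys inside input_keys={...} come back alphabetically sorted,
# so repeated repr() calls over the same example yield an identical string.
s = "Example({'question': '...'}) (input_keys={'question', 'context'})"
print(order_input_keys_in_string(s))
# -> Example({'question': '...'}) (input_keys={'context', 'question'})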

def create_dataset_summary(trainset, view_data_batch_size, prompt_model, log_file=None):
    upper_lim = min(len(trainset), view_data_batch_size)
    with dspy.settings.context(lm=prompt_model):
        observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=order_input_keys_in_string(repr(trainset[0:upper_lim])))
    observations = observation["observations"]

    if log_file:
        log_file.write("PRODUCING DATASET SUMMARY\n")

    skips = 0
    try:
        max_calls = 10
        calls = 0
        for b in range(view_data_batch_size, len(trainset), view_data_batch_size):
            calls += 1
            if calls >= max_calls:
                break
            print(f"b: {b}")
            upper_lim = min(len(trainset), b + view_data_batch_size)
            with dspy.settings.context(lm=prompt_model):
                output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(prior_observations=observations, examples=order_input_keys_in_string(repr(trainset[b:upper_lim])))
            if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE":
                skips += 1
                if skips >= 5:
                    break
                continue
            observations += output["observations"]

            if log_file:
                log_file.write(f"observations {observations}\n")
    except Exception as e:
        print(f"Error: {e}. Using observations from the previous round for the summary.")

    with dspy.settings.context(lm=prompt_model):
        summary = dspy.Predict(ObservationSummarizer, n=1, temperature=1.0)(observations=observations)
    print(f"summary: {summary}")
    if log_file:
        log_file.write(f"summary: {summary}\n")

    return strip_prefix(summary.summary)
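
A hedged usage sketch for the new module (the model name, data, and batch size below are placeholders, not part of this diff):

# Illustration only: any prompt model and a list of dspy.Example objects should work.
import dspy
from dspy.propose.dataset_summary_generator import create_dataset_summary

prompt_model = dspy.OpenAI(model="gpt-3.5-turbo")  # placeholder model choice
trainset = [dspy.Example(question="...", answer="...").with_inputs("question")] * 30  # placeholder data

summary = create_dataset_summary(trainset, view_data_batch_size=10, prompt_model=prompt_model)
print(summary)  # brief 2-3 sentence summary distilled by ObservationSummarizer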
