
How to set a max input length? #1132

Closed

hellwigt-eq opened this issue Jun 11, 2024 · 3 comments
Comments

@hellwigt-eq

I'm using an LLM served by Hugging Face's TGI. TGI enforces a max-input-tokens setting, which is typically set to the model's context length. Any request to the endpoint that exceeds this limit is rejected, and the rejection causes DSPy to throw an exception and exit, as seen here:

Traceback (most recent call last):
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dsp/modules/hf_client.py", line 87, in _generate
    completions = [json_response["generated_text"]]
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
KeyError: 'generated_text'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/home/test/dev/test/src/dspy_module.py", line 147, in <module>
    optimize()
  File "/var/home/test/dev/test/src/dspy_module.py", line 119, in optimize
    optimized_disambiguator = optimizer.compile(
                              ^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dspy/teleprompt/mipro_optimizer.py", line 459, in compile
    instruction_candidates, _ = self._generate_first_N_candidates(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dspy/teleprompt/mipro_optimizer.py", line 247, in _generate_first_N_candidates
    self.observations = self._observe_data(devset).replace("Observations:", "").replace("Summary:", "")
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dspy/teleprompt/mipro_optimizer.py", line 175, in _observe_data
    observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=(trainset[0:upper_lim].__repr__()))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dspy/predict/predict.py", line 61, in __call__
    return self.forward(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dspy/predict/predict.py", line 103, in forward
    x, C = dsp.generate(template, **config)(x, stage=self.stage)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dsp/primitives/predict.py", line 77, in do_generate
    completions: list[dict[str, Any]] = generator(prompt, **kwargs)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dsp/modules/hf.py", line 190, in __call__
    response = self.request(prompt, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dsp/modules/lm.py", line 26, in request
    return self.basic_request(prompt, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dsp/modules/hf.py", line 147, in basic_request
    response = self._generate(prompt, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/test/miniconda3/envs/test/lib/python3.11/site-packages/dsp/modules/hf_client.py", line 102, in _generate
    raise Exception("Received invalid JSON response from server")
Exception: Received invalid JSON response from server
Failed to parse JSON response: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8192. Given: 13800 `inputs` tokens and 75 `max_new_tokens`","error_type":"validation"}

Process finished with exit code 1

Is there a way to limit the max input length of all queries that get sent to the LLM, such that they never exceed a set context length?

@hellwigt-eq hellwigt-eq changed the title How to set a max input length How to set a max input length? Jun 11, 2024
@tom-doerr
Contributor

I don't know of one, but the documentation has some tips on handling this error: https://dspy-docs.vercel.app/docs/faqs#errors

@arnavsinghvi11
Collaborator

Hi @hellwigt-eq, there isn't a clean way to set a max length on the final prompt sent to the LLM, but for this use case with TGI you can add a validation check before the request is sent. This could be added via a PR, with some exploration of flags for rejection, truncation + retry, etc.
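
For illustration, here's a rough sketch of what such a pre-send check could look like outside DSPy, assuming a plain TGI /generate endpoint. The endpoint URL, model name, and truncation strategy are placeholders, not part of DSPy; the 8192 limit comes from the error message above (inputs + max_new_tokens <= 8192).

# Sketch of a pre-send length check for a TGI endpoint (not a DSPy API).
# Endpoint URL and model name are placeholders for illustration.
import requests
from transformers import AutoTokenizer

TGI_URL = "http://localhost:8080/generate"   # placeholder TGI endpoint
MAX_TOTAL_TOKENS = 8192                      # server limit: inputs + max_new_tokens <= 8192
MAX_NEW_TOKENS = 75

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # placeholder model

def generate_with_limit(prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    # Keep the input within budget so the server-side validation never fires.
    budget = MAX_TOTAL_TOKENS - max_new_tokens
    ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    if len(ids) > budget:
        # One possible strategy: drop tokens from the front and keep the tail of
        # the prompt; rejecting the request or retrying with a shorter prompt
        # are alternatives.
        ids = ids[-budget:]
        prompt = tokenizer.decode(ids, skip_special_tokens=True)
    response = requests.post(
        TGI_URL,
        json={"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}},
        timeout=120,
    )
    response.raise_for_status()
    return response.json()["generated_text"]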

@okhat okhat closed this as completed Jun 22, 2024
@hellwigt-eq
Author

For anyone else who encounters this: I found that MIPRO() has an optional parameter called view_data_batch_size, which defaults to 10. In my case, setting it to 3 instead was enough to avoid these length errors.
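
For reference, a minimal sketch of that workaround. The metric, program, trainset, and other optimizer/compile arguments are placeholders based on the dspy.teleprompt.MIPRO interface at the time of this issue (they may differ in newer releases); only view_data_batch_size is the relevant setting.

# Sketch of the view_data_batch_size workaround; everything except that
# argument is a placeholder for illustration.
import dspy
from dspy.teleprompt import MIPRO

def exact_match(example, pred, trace=None):  # placeholder metric
    return example.answer == pred.answer

program = dspy.Predict("question -> answer")  # placeholder program

# Tiny placeholder trainset; a real run would use the full dataset.
trainset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
]

optimizer = MIPRO(
    metric=exact_match,
    view_data_batch_size=3,  # default is 10; fewer examples per dataset-observation
                             # prompt keeps it under the server's input-token limit
)

optimized = optimizer.compile(
    program,
    trainset=trainset,
    num_trials=10,
    max_bootstrapped_demos=2,
    max_labeled_demos=2,
    eval_kwargs={},
)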
