Add causalLM OpenVino models (#1290)
* added intel optimum

* added intel optimum in readme

* modified intel optimum

* modified intel optimum

* modified intel optimum

* modified install optimum

* modified path of IR file

* added openvino_device

* added openvino_device2

* changed optimum-causal to openvino-causal

* Update README.md

* Update README.md

* remove `lm_eval.base` import

* update openvino-causal -> openvino ; pass device through super().__init__()

* Update README.md

* Add optimum to tests dependencies

* apply pre-commit

* fix so tests pass

---------

Co-authored-by: Hailey Schoelkopf <[email protected]>
Co-authored-by: haileyschoelkopf <[email protected]>
3 people committed Jan 26, 2024
1 parent 5b0b8a5 commit 97a67d2
Showing 7 changed files with 154 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
@@ -56,7 +56,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
+ pip install -e '.[dev,anthropic,sentencepiece,optimum]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
5 changes: 3 additions & 2 deletions README.md
@@ -58,6 +58,7 @@ We also provide a number of optional dependencies for extended functionality. Ex
| math | For running math task answer checking |
| multilingual | For multilingual tokenizers |
| openai | For using OpenAI's models |
| optimum | For running Intel OpenVINO models |
| promptsource | For using PromptSource prompts |
| sentencepiece | For using the sentencepiece tokenizer |
| testing | For running library test suite |
@@ -189,8 +190,8 @@ Note that for externally hosted models, configs such as `--device` and `--batch_
| [Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | [All models supported by llama.cpp](https://github.com/ggerganov/llama.cpp) | `generate_until`, `loglikelihood`, (perplexity evaluation not yet implemented) |
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Mamba | :heavy_check_mark: | `mamba_ssm` | [Mamba architecture Language Models via the `mamba_ssm` package](https://huggingface.co/state-spaces) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
- | Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts GET requests using HF models and mirror's OpenAI's ChatCompletions interface | `generate_until` | | ... |
- | `local-completions` (using `openai-completions` model type) | Any server address that accepts GET requests using HF models and mirror's OpenAI's Completions interface | `generate_until` | | ... |
+ | Huggingface Optimum (Causal LMs) | ✔️ | `openvino` | Any decoder-only AutoModelForCausalLM converted with Huggingface Optimum into OpenVINO™ Intermediate Representation (IR) format | `generate_until`, `loglikelihood`, `loglikelihood_rolling` | ... |
+ | Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts GET requests using HF models and mirrors OpenAI's Completions or ChatCompletions interface | `generate_until` | | ... |
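
For reference, a minimal Python sketch (not part of this diff) of driving the new `openvino` model type, mirroring the flow in `tests/models/test_openvino.py`; the IR directory path is a placeholder:

```python
# Minimal sketch: evaluate an OpenVINO IR model via the new `openvino` model type.
# Assumes `pip install -e '.[optimum]'` and a directory containing openvino_model.xml.
import lm_eval.evaluator as evaluator
import lm_eval.tasks as tasks
from lm_eval.api.registry import get_model

tasks.initialize_tasks()

lm = get_model("openvino").create_from_arg_string(
    "pretrained=/path/to/openvino-ir-dir",  # placeholder IR directory
    {"batch_size": 1, "device": "cpu"},
)
results = evaluator.simple_evaluate(
    model=lm, tasks=["lambada_openai"], num_fewshot=0, limit=10
)
```

Passing `device=cpu` here is mapped to OpenVINO's CPU device via `self.openvino_device.upper()` in `_create_model` below.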

Models which do not supply logits or logprobs can be used with tasks of type `generate_until` only, while local models, or APIs that supply logprobs/logits of their prompts, can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.

1 change: 1 addition & 0 deletions lm_eval/models/__init__.py
@@ -6,5 +6,6 @@
from . import gguf
from . import vllm_causallms
from . import mamba_lm
from . import optimum_lm

# TODO: implement __all__
5 changes: 3 additions & 2 deletions lm_eval/models/huggingface.py
@@ -200,8 +200,9 @@ def __init__(
)

# access self._model through self.model property outside this method
- self.model.eval()
- self.model.tie_weights()
+ if isinstance(self.model, torch.nn.Module):
+     self.model.eval()
+     self.model.tie_weights()

if isinstance(pretrained, str) and (gpus >= 1 or str(self.device) == "mps"):
# TODO: can remove this whole snippet except in the mps case, perhaps?
69 changes: 69 additions & 0 deletions lm_eval/models/optimum_lm.py
@@ -0,0 +1,69 @@
from importlib.util import find_spec
from pathlib import Path

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("openvino")
class OptimumLM(HFLM):
"""
Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
Intel® architectures using OpenVINO™ runtime.
"""

def __init__(
self,
device="cpu",
**kwargs,
) -> None:
if "backend" in kwargs:
# optimum currently only supports causal models
assert (
kwargs["backend"] == "causal"
), "Currently, only OVModelForCausalLM is supported."

self.openvino_device = device

super().__init__(
device=self.openvino_device,
backend=kwargs.get("backend", "causal"),
**kwargs,
)

def _create_model(
self,
pretrained: str,
revision="main",
dtype="auto",
trust_remote_code=False,
**kwargs,
) -> None:
if not find_spec("optimum"):
raise Exception(
"package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
)
else:
from optimum.intel.openvino import OVModelForCausalLM

model_kwargs = kwargs if kwargs else {}
model_file = Path(pretrained) / "openvino_model.xml"
if model_file.exists():
export = False
        else:
            export = True
            # attach the OpenVINO config to model_kwargs so it is actually forwarded to from_pretrained
            model_kwargs["ov_config"] = {
                "PERFORMANCE_HINT": "LATENCY",
                "NUM_STREAMS": "1",
                "CACHE_DIR": "",
            }

self._model = OVModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
export=export,
device=self.openvino_device.upper(),
**model_kwargs,
)
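
For context, a hedged sketch of producing the `openvino_model.xml` IR directory that `_create_model` checks for, using the same export path as the test added in this commit; the model id and output directory are illustrative:

```python
# Sketch: export a Hugging Face checkpoint to OpenVINO IR so that
# `_create_model` finds openvino_model.xml and skips re-exporting.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "facebook/opt-125m"   # illustrative checkpoint
save_dir = "opt-125m-openvino"   # illustrative output directory

model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
model.save_pretrained(save_dir)  # writes openvino_model.xml / openvino_model.bin
AutoTokenizer.from_pretrained(model_id).save_pretrained(save_dir)
```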
1 change: 1 addition & 0 deletions pyproject.toml
@@ -62,6 +62,7 @@ mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
openai = ["openai==1.3.9", "tiktoken"]
optimum = ["optimum[openvino]"]
promptsource = [
"promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
]
76 changes: 76 additions & 0 deletions tests/models/test_openvino.py
@@ -0,0 +1,76 @@
import random
import tempfile

import pytest
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

import lm_eval.evaluator as evaluator
import lm_eval.tasks as tasks
from lm_eval.api.registry import get_model


tasks.initialize_tasks()

SUPPORTED_ARCHITECTURES_TASKS = {
"facebook/opt-125m": "lambada_openai",
"hf-internal-testing/tiny-random-gpt2": "wikitext",
}


@pytest.mark.parametrize("model_id,task", SUPPORTED_ARCHITECTURES_TASKS.items())
def test_evaluator(model_id, task):
with tempfile.TemporaryDirectory() as tmpdirname:
model = OVModelForCausalLM.from_pretrained(
model_id, export=True, use_cache=True
)
model.save_pretrained(tmpdirname)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(tmpdirname)

lm = get_model("openvino").create_from_arg_string(
f"pretrained={tmpdirname}",
{
"batch_size": 1,
"device": "cpu",
},
)

def ll_fn(reqs):
for ctx, cont in [req.args for req in reqs]:
if len(ctx) == 0:
continue
# space convention
assert ctx[-1] != " "
assert cont[0] == " " or ctx[-1] == "\n"

res = []

random.seed(42)
for _ in reqs:
res.append((-random.random(), False))

return res

def ll_perp_fn(reqs):
for (string,) in [req.args for req in reqs]:
assert isinstance(string, str)

res = []
random.seed(42)
for _ in reqs:
res.append(-random.random())

return res

lm.loglikelihood = ll_fn
lm.loglikelihood_rolling = ll_perp_fn

limit = 10
evaluator.simple_evaluate(
model=lm,
tasks=[task],
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
)
