Skip to content

Commit

Permalink
add get_route_info
Browse files Browse the repository at this point in the history
Signed-off-by: Mynhardt Burger <[email protected]>
  • Loading branch information
mynhardtburger committed Jun 6, 2024
1 parent 4bf53fd commit e278915
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
36 changes: 33 additions & 3 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is for helper functions related to TGIS.
"""
"""This file is for helper functions related to TGIS."""

# Standard
from typing import Iterable
from typing import Iterable, Optional, Tuple

# Third Party
import fastapi
import grpc

# First Party
Expand All @@ -33,6 +34,7 @@
TokenizationResults,
TokenStreamDetails,
)
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend.protobufs import generation_pb2
import alog

Expand Down Expand Up @@ -683,3 +685,31 @@ def unary_tokenize(
return TokenizationResults(
token_count=response.token_count,
)


def get_route_info(
context: Optional[RuntimeServerContextType],
) -> Tuple[bool, Optional[str]]:
"""
Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in the headers/metadata.
Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the context or if context is None.
"""
if context is None:
return False, None

if isinstance(context, grpc.ServicerContext):
route_info = dict(context.invocation_metadata()).get("x-route-info")
if route_info:
return True, route_info
elif isinstance(context, fastapi.Request):
route_info = context.headers.get("x-route-info")
if route_info:
return True, route_info
else:
error.log_raise(
"<NLP92615097E>",
ValueError(f"context is of an unsupported type: {type(context)}"),
)

return False, None
37 changes: 37 additions & 0 deletions tests/toolkit/text_generation/test_tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
"""
Tests for tgis_utils
"""

# Standard
from typing import Iterable, Optional, Type

# Third Party
import fastapi
import grpc
import grpc._channel
import pytest

# First Party
from caikit.core.data_model import ProducerId
from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend.protobufs import generation_pb2

# Local
Expand Down Expand Up @@ -127,3 +130,37 @@ def test_TGISGenerationClient_rpc_errors(status_code, method):
)
rpc_err = context.value.__context__
assert isinstance(rpc_err, grpc.RpcError)


@pytest.mark.parametrize(
argnames=["context", "ok", "route_info"],
argvalues=[
(
fastapi.Request(
{"type": "http", "headers": [(b"x-route-info", b"sometext")]}
),
True,
"sometext",
),
(
fastapi.Request(
{"type": "http", "headers": [(b"route-info", b"sometext")]}
),
False,
None,
),
("should raise ValueError", False, None),
(None, False, None),
# Uncertain how to create a grpc.ServicerContext object
],
)
def test_get_route_info(
context: RuntimeServerContextType, ok: bool, route_info: Optional[str]
):
if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))):
with pytest.raises(ValueError):
tgis_utils.get_route_info(context)
else:
actual_ok, actual_route_info = tgis_utils.get_route_info(context)
assert actual_ok == ok
assert actual_route_info == route_info

0 comments on commit e278915

Please sign in to comment.