Skip to content

Commit

Permalink
🐛✅ Fix fixture for tgis tests
Browse files Browse the repository at this point in the history
Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat committed May 17, 2024
1 parent 401510d commit fc5ebda
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
6 changes: 3 additions & 3 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def Tokenize(self, request, **kwargs):
return StubTGISClient.tokenize(request)

@staticmethod
def unary_generate(request):
def unary_generate(request, **kwargs):
fake_response = mock.Mock()
fake_result = mock.Mock()
fake_result.stop_reason = 5
Expand All @@ -229,7 +229,7 @@ def unary_generate(request):
return fake_response

@staticmethod
def stream_generate(request):
def stream_generate(request, **kwargs):
fake_stream = mock.Mock()
fake_stream.stop_reason = 5
fake_stream.generated_token_count = 1
Expand All @@ -250,7 +250,7 @@ def stream_generate(request):
yield fake_stream

@staticmethod
def tokenize(request):
def tokenize(request, **kwargs):
fake_response = mock.Mock()
fake_result = mock.Mock()
fake_result.token_count = 1
Expand Down
9 changes: 3 additions & 6 deletions tests/toolkit/text_generation/test_tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,19 @@ def _maybe_raise(self, error_type: Type[grpc.RpcError], *args):
)

def Generate(
self,
request: generation_pb2.BatchedGenerationRequest,
self, request: generation_pb2.BatchedGenerationRequest, **kwargs
) -> generation_pb2.BatchedGenerationResponse:
self._maybe_raise(grpc._channel._InactiveRpcError)
return generation_pb2.BatchedGenerationResponse()

def GenerateStream(
self,
request: generation_pb2.SingleGenerationRequest,
self, request: generation_pb2.SingleGenerationRequest, **kwargs
) -> Iterable[generation_pb2.GenerationResponse]:
self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None)
yield generation_pb2.GenerationResponse()

def Tokenize(
self,
request: generation_pb2.BatchedTokenizeRequest,
self, request: generation_pb2.BatchedTokenizeRequest, **kwargs
) -> generation_pb2.BatchedTokenizeResponse:
self._maybe_raise(grpc._channel._InactiveRpcError)
return generation_pb2.BatchedTokenizeResponse()
Expand Down

0 comments on commit fc5ebda

Please sign in to comment.