Skip to content

Commit

Permalink
Merge pull request BerriAI#4719 from BerriAI/litellm_fix_audio_transcript
Browse files Browse the repository at this point in the history

[Fix] /audio/transcription - don't write to the local file system
  • Loading branch information
ishaan-jaff committed Jul 16, 2024
2 parents 4dfc00a + a900f35 commit 979b5d8
Showing 1 changed file with 63 additions and 66 deletions.
129 changes: 63 additions & 66 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import copy
import inspect
import io
import os
import random
import secrets
Expand Down Expand Up @@ -3787,74 +3788,70 @@ async def audio_transcriptions(

router_model_names = llm_router.model_names if llm_router is not None else []

assert (
file.filename is not None
) # make sure filename passed in (needed for type)

_original_filename = file.filename
file_extension = os.path.splitext(file.filename)[1]
# rename the file to a random hash file name -> we eventually remove the file and don't want to remove any local files
file.filename = f"tmp-request" + str(uuid.uuid4()) + file_extension

# IMP - Asserts that we've renamed the uploaded file, since we run os.remove(file.filename), we should rename the original file
assert file.filename != _original_filename
if file.filename is None:
raise ProxyException(
message="File name is None. Please check your file name",
code=status.HTTP_400_BAD_REQUEST,
type="bad_request",
param="file",
)

with open(file.filename, "wb+") as f:
f.write(await file.read())
try:
data["file"] = open(file.filename, "rb")
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict,
data=data,
call_type="audio_transcription",
# Instead of writing to a file
file_content = await file.read()
file_object = io.BytesIO(file_content)
file_object.name = file.filename
data["file"] = file_object
try:
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict,
data=data,
call_type="audio_transcription",
)

## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.atranscription(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.atranscription(**data)

elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(
**data, specific_deployment=True
)

## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.atranscription(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.atranscription(**data)

elif (
llm_router is not None
and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "audio_transcriptions: Invalid model name passed in model="
+ data.get("model", "")
},
)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
os.remove(file.filename) # Delete the saved file
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "audio_transcriptions: Invalid model name passed in model="
+ data.get("model", "")
},
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
file_object.close() # close the file read in by io library

### ALERTING ###
asyncio.create_task(
Expand Down

0 comments on commit 979b5d8

Please sign in to comment.