Skip to content

Commit

Permalink
Code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesrochabrun committed Dec 20, 2023
1 parent d88010f commit b94a23b
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 321 deletions.
306 changes: 306 additions & 0 deletions Sources/OpenAI/Public/Service/DefaultOpeanAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,309 @@ struct DefaultOpenAIService: OpenAIService {

}
}


extension DefaultOpenAIService {

/// Asynchronously fetches the contents of a file that has been uploaded to OpenAI's service.
///
/// This method is used exclusively for retrieving the content of uploaded files.
///
/// - Parameter request: The `URLRequest` describing the API request to fetch the file.
/// - Throws: An error if the request fails.
/// - Returns: A dictionary array representing the file contents.
private func fetchContentsOfFile(
request: URLRequest)
async throws -> [[String: Any]]
{
printCurlCommand(request)
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
throw APIError.requestFailed(description: "invalid response unable to get a valid HTTPURLResponse")
}
printHTTPURLResponse(httpResponse)
guard httpResponse.statusCode == 200 else {
var errorMessage = "status code \(httpResponse.statusCode)"
do {
let error = try decoder.decode(OpenAIErrorResponse.self, from: data)
errorMessage += " \(error.error.message ?? "NO ERROR MESSAGE PROVIDED")"
} catch {
// If decoding fails, proceed with a general error message
errorMessage = "status code \(httpResponse.statusCode)"
}
throw APIError.responseUnsuccessful(description: errorMessage)
}
var content: [[String: Any]] = []
if let jsonString = String(data: data, encoding: .utf8) {
let lines = jsonString.split(separator: "\n")
for line in lines {
#if DEBUG
print("DEBUG Received line:\n\(line)")
#endif
if let lineData = line.data(using: .utf8),
let jsonObject = try? JSONSerialization.jsonObject(with: lineData, options: .allowFragments) as? [String: Any] {
content.append(jsonObject)
}
}
}
return content
}

/// Asynchronously fetches audio data.
///
/// This method is used exclusively for handling audio data responses.
///
/// - Parameter request: The `URLRequest` describing the API request to fetch the file.
/// - Throws: An error if the request fails.
/// - Returns: The audio Data
private func fetchAudio(
with request: URLRequest)
async throws -> Data
{
printCurlCommand(request)
let (data, response) = try await session.data(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
throw APIError.requestFailed(description: "Invalid response: unable to get a valid HTTPURLResponse")
}
printHTTPURLResponse(httpResponse)
guard httpResponse.statusCode == 200 else {
var errorMessage = "Status code \(httpResponse.statusCode)"
do {
let errorResponse = try decoder.decode(OpenAIErrorResponse.self, from: data)
errorMessage += " \(errorResponse.error.message ?? "NO ERROR MESSAGE PROVIDED")"
} catch {
if let errorString = String(data: data, encoding: .utf8), !errorString.isEmpty {
errorMessage += " - \(errorString)"
} else {
errorMessage += " - No error message provided"
}
}
throw APIError.responseUnsuccessful(description: errorMessage)
}
return data
}

/// Asynchronously fetches a decodable data type from OpenAI's API.
///
/// - Parameters:
/// - type: The `Decodable` type that the response should be decoded to.
/// - request: The `URLRequest` describing the API request.
/// - Throws: An error if the request fails or if decoding fails.
/// - Returns: A value of the specified decodable type.
private func fetch<T: Decodable>(
type: T.Type,
with request: URLRequest)
async throws -> T
{
printCurlCommand(request)
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
throw APIError.requestFailed(description: "invalid response unable to get a valid HTTPURLResponse")
}
printHTTPURLResponse(httpResponse)
guard httpResponse.statusCode == 200 else {
var errorMessage = "status code \(httpResponse.statusCode)"
do {
let error = try decoder.decode(OpenAIErrorResponse.self, from: data)
errorMessage += " \(error.error.message ?? "NO ERROR MESSAGE PROVIDED")"
} catch {
// If decoding fails, proceed with a general error message
errorMessage = "status code \(httpResponse.statusCode)"
}
throw APIError.responseUnsuccessful(description: errorMessage)
}
#if DEBUG
print("DEBUG JSON FETCH API = \(try JSONSerialization.jsonObject(with: data, options: .allowFragments) as? [String: Any])")
#endif
do {
return try decoder.decode(type, from: data)
} catch let DecodingError.keyNotFound(key, context) {
let debug = "Key '\(key.stringValue)' not found: \(context.debugDescription)"
let codingPath = "codingPath: \(context.codingPath)"
let debugMessage = debug + codingPath
#if DEBUG
print(debugMessage)
#endif
throw APIError.dataCouldNotBeReadMissingData(description: debugMessage)
} catch {
throw APIError.jsonDecodingFailure(description: error.localizedDescription)
}
}

/// Asynchronously fetches a stream of decodable data types from OpenAI's API for chat completions.
///
/// This method is primarily used for streaming chat completions.
///
/// - Parameters:
/// - type: The `Decodable` type that each streamed response should be decoded to.
/// - request: The `URLRequest` describing the API request.
/// - Throws: An error if the request fails or if decoding fails.
/// - Returns: An asynchronous throwing stream of the specified decodable type.
private func fetchStream<T: Decodable>(
type: T.Type,
with request: URLRequest)
async throws -> AsyncThrowingStream<T, Error>
{
printCurlCommand(request)

let (data, response) = try await session.bytes(for: request)
try Task.checkCancellation()
guard let httpResponse = response as? HTTPURLResponse else {
throw APIError.requestFailed(description: "invalid response unable to get a valid HTTPURLResponse")
}
printHTTPURLResponse(httpResponse)
guard httpResponse.statusCode == 200 else {
var errorMessage = "status code \(httpResponse.statusCode)"
do {
let data = try await data.reduce(into: Data()) { data, byte in
data.append(byte)
}
let error = try decoder.decode(OpenAIErrorResponse.self, from: data)
errorMessage += " \(error.error.message ?? "NO ERROR MESSAGE PROVIDED")"
} catch {
// If decoding fails, proceed with a general error message
errorMessage = "status code \(httpResponse.statusCode)"
}
throw APIError.responseUnsuccessful(description: errorMessage)
}
return AsyncThrowingStream { continuation in
Task {
do {
for try await line in data.lines {
try Task.checkCancellation()
if line.hasPrefix("data:") && line != "data: [DONE]",
let data = line.dropFirst(5).data(using: .utf8) {
#if DEBUG
print("DEBUG JSON STREAM LINE = \(try JSONSerialization.jsonObject(with: data, options: .allowFragments) as? [String: Any])")
#endif
do {
let decoded = try self.decoder.decode(T.self, from: data)
continuation.yield(decoded)
} catch let DecodingError.keyNotFound(key, context) {
let debug = "Key '\(key.stringValue)' not found: \(context.debugDescription)"
let codingPath = "codingPath: \(context.codingPath)"
let debugMessage = debug + codingPath
#if DEBUG
print(debugMessage)
#endif
throw APIError.dataCouldNotBeReadMissingData(description: debugMessage)
} catch {
#if DEBUG
debugPrint("CONTINUATION ERROR DECODING \(error.localizedDescription)")
#endif
continuation.finish(throwing: error)
}
}
}
continuation.finish()
} catch let DecodingError.keyNotFound(key, context) {
let debug = "Key '\(key.stringValue)' not found: \(context.debugDescription)"
let codingPath = "codingPath: \(context.codingPath)"
let debugMessage = debug + codingPath
#if DEBUG
print(debugMessage)
#endif
throw APIError.dataCouldNotBeReadMissingData(description: debugMessage)
} catch {
#if DEBUG
print("CONTINUATION ERROR DECODING \(error.localizedDescription)")
#endif
continuation.finish(throwing: error)
}
}
}
}

// MARK: Debug Helpers

private func prettyPrintJSON(
_ data: Data)
-> String?
{
guard
let jsonObject = try? JSONSerialization.jsonObject(with: data, options: []),
let prettyData = try? JSONSerialization.data(withJSONObject: jsonObject, options: [.prettyPrinted]),
let prettyPrintedString = String(data: prettyData, encoding: .utf8)
else { return nil }
return prettyPrintedString
}

private func printCurlCommand(
_ request: URLRequest)
{
guard let url = request.url, let httpMethod = request.httpMethod else {
debugPrint("Invalid URL or HTTP method.")
return
}

var baseCommand = "curl \(url.absoluteString)"

// Add method if not GET
if httpMethod != "GET" {
baseCommand += " -X \(httpMethod)"
}

// Add headers if any, masking the Authorization token
if let headers = request.allHTTPHeaderFields {
for (header, value) in headers {
let maskedValue = header.lowercased() == "authorization" ? maskAuthorizationToken(value) : value
baseCommand += " \\\n-H \"\(header): \(maskedValue)\""
}
}

// Add body if present
if let httpBody = request.httpBody, let bodyString = prettyPrintJSON(httpBody) {
// The body string is already pretty printed and should be enclosed in single quotes
baseCommand += " \\\n-d '\(bodyString)'"
}

// Print the final command
#if DEBUG
print(baseCommand)
#endif
}

private func prettyPrintJSON(
_ data: Data)
-> String
{
guard
let jsonObject = try? JSONSerialization.jsonObject(with: data, options: []),
let prettyData = try? JSONSerialization.data(withJSONObject: jsonObject, options: [.prettyPrinted]),
let prettyPrintedString = String(data: prettyData, encoding: .utf8) else { return "Could not print JSON - invalid format" }
return prettyPrintedString
}

private func printHTTPURLResponse(
_ response: HTTPURLResponse,
data: Data? = nil)
{
#if DEBUG
print("\n- - - - - - - - - - INCOMING RESPONSE - - - - - - - - - -\n")
print("URL: \(response.url?.absoluteString ?? "No URL")")
print("Status Code: \(response.statusCode)")
print("Headers: \(response.allHeaderFields)")
if let mimeType = response.mimeType {
print("MIME Type: \(mimeType)")
}
if let data = data, response.mimeType == "application/json" {
print("Body: \(prettyPrintJSON(data))")
} else if let data = data, let bodyString = String(data: data, encoding: .utf8) {
print("Body: \(bodyString)")
}
print("\n- - - - - - - - - - - - - - - - - - - - - - - - - - - -\n")
#endif
}

private func maskAuthorizationToken(_ token: String) -> String {
if token.count > 6 {
let prefix = String(token.prefix(3))
let suffix = String(token.suffix(3))
return "\(prefix)................\(suffix)"
} else {
return "INVALID TOKEN LENGTH"
}
}

}
Loading

0 comments on commit b94a23b

Please sign in to comment.