Skip to content

Commit

Permalink
Adding required
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesrochabrun committed May 4, 2024
1 parent 9d2e7d3 commit c3ad62e
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions Sources/OpenAI/Public/Shared/ToolChoice.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@
import Foundation

/// string `none` means the model will not call a function and instead generates a message.
///
/// `auto` means the model can pick between generating a message or calling a function.
///
/// `object` Specifies a tool the model should use. Use to force the model to call a specific function. The type of the tool. Currently, only` function` is supported. `{"type: "function", "function": {"name": "my_function"}}`
///
/// `required` To force the model to always call one or more functions, you can set tool_choice: "required". The model will then select which function(s) to call.
///
/// [Function Calling](https://platform.openai.com/docs/guides/function-calling)
public enum ToolChoice: Codable, Equatable {
case none
case auto
case required
case function(type: String = "function", name: String)

enum CodingKeys: String, CodingKey {
case none = "none"
case auto = "auto"
case required = "required"
case type = "type"
case function = "function"
}
Expand All @@ -34,32 +42,38 @@ public enum ToolChoice: Codable, Equatable {
case .auto:
var container = encoder.singleValueContainer()
try container.encode(CodingKeys.auto.rawValue)
case .required:
var container = encoder.singleValueContainer()
try container.encode(CodingKeys.required.rawValue)
case .function(let type, let name):
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(type, forKey: .type)
var functionContainer = container.nestedContainer(keyedBy: FunctionCodingKeys.self, forKey: .function)
try functionContainer.encode(name, forKey: .name)

}
}

public init(from decoder: Decoder) throws {
// Handle the 'function' case:
if let container = try? decoder.container(keyedBy: CodingKeys.self),
let functionContainer = try? container.nestedContainer(keyedBy: FunctionCodingKeys.self, forKey: .function) {
let name = try functionContainer.decode(String.self, forKey: .name)
self = .function(type: "function", name: name)
return
let name = try functionContainer.decode(String.self, forKey: .name)
self = .function(type: "function", name: name)
return
}

// Handle the 'auto' and 'none' cases
let container = try decoder.singleValueContainer()
switch try container.decode(String.self) {
case "none":
self = .none
case "auto":
self = .auto
case "none":
self = .none
case "auto":
self = .auto
case "required":
self = .required
default:
throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid tool_choice structure")
throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid tool_choice structure")
}
}
}

0 comments on commit c3ad62e

Please sign in to comment.