Skip to content

Commit

Permalink
add base path and catch generation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
socketteer committed Mar 24, 2023
1 parent b77024a commit dc9c8ab
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 72 deletions.
34 changes: 28 additions & 6 deletions components/dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,16 +1143,25 @@ def apply(self):

class AddModelDialog(Dialog):
def __init__(self, parent):
self.model_id_entry = None
self.model_name_entry = None
self.model_type_entry = None
self.base_path_entry = None
Dialog.__init__(self, parent, title="Model Configuration")

def body(self, master):
self.model_id_entry = Entry(master, master.grid_size()[1], "Model id", "", None, width=30)
self.model_name_entry = Entry(master, master.grid_size()[1], "Model name", "", None, width=30)
self.model_type_entry = Entry(master, master.grid_size()[1], "Model type", "openai-custom", None, width=30)
self.model_type_entry = Entry(master, master.grid_size()[1], "Model type", "openai", None, width=30)
self.base_path_entry = Entry(master, master.grid_size()[1], "API base", "https://api.openai.com/v1", None, width=30)

def apply(self):
self.result = {'name': self.model_name_entry.tk_variables.get(), 'type': self.model_type_entry.tk_variables.get()}
self.result = {
'id': self.model_id_entry.tk_variables.get(),
'name': self.model_name_entry.tk_variables.get(),
'type': self.model_type_entry.tk_variables.get(),
'api_base': self.base_path_entry.tk_variables.get()
}


class ModelConfigDialog(Dialog):
Expand All @@ -1169,6 +1178,7 @@ def __init__(self, parent, state):
self.ai21_api_key = None
self.gooseai_api_key = None
self.gooseai_api_key_entry = None
# self.api_base = None
Dialog.__init__(self, parent, title="Model Configuration")

def set_vars(self):
Expand All @@ -1177,6 +1187,7 @@ def set_vars(self):
self.openai_api_key = self.state.OPENAI_API_KEY if self.state.OPENAI_API_KEY else ""
self.ai21_api_key = self.state.AI21_API_KEY if self.state.AI21_API_KEY else ""
self.gooseai_api_key = self.state.GOOSEAI_API_KEY if self.state.GOOSEAI_API_KEY else ""
# self.api_base = self.state.model_config.api_base if self.state.model_config.api_base else ""

def body(self, master):
self.set_vars()
Expand All @@ -1185,6 +1196,7 @@ def body(self, master):
self.openai_api_key_entry = Entry(master, master.grid_size()[1], "OpenAI API Key", self.openai_api_key, None, width=key_length)
self.ai21_api_key_entry = Entry(master, master.grid_size()[1], "AI21 API Key", self.ai21_api_key, None, width=key_length)
self.gooseai_api_key_entry = Entry(master, master.grid_size()[1], "GooseAI API Key", self.gooseai_api_key, None, width=key_length)
# self.api_base_entry = Entry(master, master.grid_size()[1], "API Base", self.api_base, None)
models_list = self.available_models.keys()
self.model_label = ttk.Label(master, text="Model")
self.model_label.grid(row=master.grid_size()[1], column=0)
Expand All @@ -1196,16 +1208,26 @@ def body(self, master):
def add_model(self):
self.result = AddModelDialog(self).result
if self.result:
self.available_models[self.result['name']] = {'type': self.result['type']}
self.available_models_dropdown['menu'].add_command(label=self.result['name'],
command=lambda: self.selected_model.set(self.result['name']))
self.available_models[self.result['id']] = {
'name': self.result['name'],
'type': self.result['type'],
'api_base': self.result['api_base']
}
self.available_models_dropdown['menu'].add_command(label=self.result['id'],
command=lambda: self.selected_model.set(self.result['id']))

def apply(self):
self.state.update_frame(node=self.state.root(), update={'model_config': {'models': self.available_models}})
self.state.update_frame(node=self.state.root(), update={'model_config': {'models': self.available_models,
# 'OPENAI_API_KEY': self.openai_api_key_entry.tk_variables.get(),
# 'AI21_API_KEY': self.ai21_api_key_entry.tk_variables.get(),
# 'GOOSEAI_API_KEY': self.gooseai_api_key_entry.tk_variables.get(),
# 'api_base': self.api_base_entry.tk_variables.get()
}})
#'OPENAI_API_KEY': self.openai_api_key_entry.tk_variables.get(),
#'AI21_API_KEY': self.ai21_api_key_entry.tk_variables.get(),
self.state.OPENAI_API_KEY = self.openai_api_key_entry.tk_variables.get().strip()
self.state.AI21_API_KEY = self.ai21_api_key_entry.tk_variables.get().strip()
self.state.GOOSEAI_API_KEY = self.gooseai_api_key_entry.tk_variables.get().strip()
# self.state.api_base = self.api_base_entry.tk_variables.get().strip()
self.state.update_user_frame(update={'generation_settings': {'model': self.selected_model.get()}})

65 changes: 25 additions & 40 deletions gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,6 @@
? "sequence": string }
'''

# POSSIBLE_MODELS = [
# 'ada',
# 'babbage',
# 'content-filter-alpha-c4',
# 'content-filter-dev',
# 'curie',
# 'cursing-filter-v6',
# 'davinci',
# 'instruct-curie-beta',
# 'instruct-davinci-beta',
# 'j1-large',
# 'j1-jumbo',
# ]












#ai21_api_key = os.environ.get("AI21_API_KEY", None)
Expand All @@ -82,35 +59,43 @@ def gen(prompt, settings, config, **kwargs):
logit_bias = None
#if config['OPENAI_API_KEY']:
model_info = config['models'][settings['model']]
# print('model info:', model_info)
openai.api_base = model_info['api_base'] if model_info['api_base'] else "https://api.openai.com/v1"
ai21_api_key = kwargs.get('AI21_API_KEY', None)
ai21_api_key = ai21_api_key if ai21_api_key else os.environ.get("AI21_API_KEY", None)
if model_info['type'] == 'gooseai':
openai.api_base = "https://api.goose.ai/v1"
# openai.api_base = openai.api_base if openai.api_base else "https://api.goose.ai/v1"
gooseai_api_key = kwargs.get('GOOSEAI_API_KEY', None)
openai.api_key = gooseai_api_key if gooseai_api_key else os.environ.get("GOOSEAI_API_KEY", None)
elif model_info['type'] == 'openai':
openai.api_base = "https://api.openai.com/v1"
elif model_info['type'] in ('openai', 'openai-custom', 'openai-chat'):
# openai.api_base = openai.api_base if openai.api_base else "https://api.openai.com/v1"
openai_api_key = kwargs.get('OPENAI_API_KEY', None)
openai.api_key = openai_api_key if openai_api_key else os.environ.get("OPENAI_API_KEY", None)

#print('openai api key: ' + openai.api_key)
# print('openai api base: ' + openai.api_base)

# print('openai api key: ' + openai.api_key)

# if config['AI21_API_KEY']:
#TODO
# ai21_api_key = config['AI21_API_KEY']
response, error = generate(prompt=prompt,
length=settings['response_length'],
num_continuations=settings['num_continuations'],
temperature=settings['temperature'],
logprobs=settings['logprobs'],
top_p=settings['top_p'],
model=settings['model'],
stop=stop,
logit_bias=logit_bias,
config=config,
ai21_api_key=ai21_api_key,
)
return response, error
try:
response, error = generate(prompt=prompt,
length=settings['response_length'],
num_continuations=settings['num_continuations'],
temperature=settings['temperature'],
logprobs=settings['logprobs'],
top_p=settings['top_p'],
model=settings['model'],
stop=stop,
logit_bias=logit_bias,
config=config,
ai21_api_key=ai21_api_key,
)
return response, error
except Exception as e:
print(e)
return None, e


def generate(config, **kwargs):
Expand Down
122 changes: 96 additions & 26 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def wrapper(self, *args, **kwargs):
DEFAULT_WORKSPACE = {
'side_pane': {'open': True,
'modules': ["minimap"]},
'bottom_pane': {'open': True,
'modules': ["children"]},
'bottom_pane': {'open': False,
'modules': []},
'buttons': ["Edit", "Delete", "Generate", "New Child", "Next", "Prev", "Wavefunction", "Map"],
'alt_textbox': False,
'show_search': False
Expand Down Expand Up @@ -117,31 +117,100 @@ def wrapper(self, *args, **kwargs):

DEFAULT_MODEL_CONFIG = {
'models': {
'ada': {'type': 'openai'},
'babbage': {'type': 'openai'},
'content-filter-alpha-c4': {'type': 'openai'},
'content-filter-dev': {'type': 'openai'},
'curie': {'type': 'openai'},
'cursing-filter-v6': {'type': 'openai'},
'davinci': {'type': 'openai'},
'text-davinci-002': {'type': 'openai'},
'text-davinci-003': {'type': 'openai'},
'code-davinci-002': {'type': 'openai'},
'instruct-curie-beta': {'type': 'openai'},
'instruct-davinci-beta': {'type': 'openai'},
'gpt-3.5-turbo': {'type': 'openai-chat'},
'gpt-4': {'type': 'openai-chat'},
'j1-large': {'type': 'ai21'},
'j1-jumbo': {'type': 'ai21'},
'gpt-neo-1-3b': {'type': 'gooseai'},
'gpt-neo-2-7b': {'type': 'gooseai'},
'gpt-j-6b': {'type': 'gooseai'},
'gpt-neo-20b': {'type': 'gooseai'},


'ada': {
'model': 'ada',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
'babbage': {
'model': 'babbage',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
# 'content-filter-alpha-c4': {'type': 'openai'},
# 'content-filter-dev': {'type': 'openai'},
'curie': {
'model': 'curie',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
# 'cursing-filter-v6': {'type': 'openai'},
'davinci': {
'model': 'davinci',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
'text-davinci-002': {
'model': 'text-davinci-002',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
'text-davinci-003': {
'model': 'text-davinci-003',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
# 'code-davinci-002': {
# 'model': 'code-davinci-002',
# 'type': 'openai',
# 'api_base': 'https://api.openai.com/v1'
# },
'instruct-curie-beta': {
'model': 'instruct-curie-beta',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
'instruct-davinci-beta': {
'model': 'instruct-davinci-beta',
'type': 'openai',
'api_base': 'https://api.openai.com/v1'
},
'gpt-3.5-turbo': {
'model': 'gpt-3.5-turbo',
'type': 'openai-chat',
'api_base': 'https://api.openai.com/v1'
},
'gpt-4': {
'model': 'gpt-4',
'type': 'openai-chat',
'api_base': 'https://api.openai.com/v1'
},
'j1-large': {
'model': 'j1-large',
'type': 'ai21',
'api_base': None,
},
'j1-jumbo': {
'model': 'j1-jumbo',
'type': 'ai21',
'api_base': None,
},
'gpt-neo-1-3b': {
'model': 'gpt-neo-1.3B',
'type': 'gooseai',
'api_base': None,
},
'gpt-neo-2-7b': {
'model': 'gpt-neo-2.7B',
'type': 'gooseai',
'api_base': None,
},
'gpt-j-6b': {
'model': 'gpt-j-6B',
'type': 'gooseai',
'api_base': None,
},
'gpt-neo-20b': {
'model': 'gpt-neo-20B',
'type': 'gooseai',
'api_base': None,
},
},
#'OPENAI_API_KEY': os.environ.get("OPENAI_API_KEY", None),
#'AI21_API_KEY': os.environ.get("AI21_API_KEY", None),
# 'api_base': None,
# 'api_key': os.environ.get("API_KEY", ''),
# 'OPENAI_API_KEY': os.environ.get("OPENAI_API_KEY", None),
# 'AI21_API_KEY': os.environ.get("AI21_API_KEY", None),
# 'GOOSEAI_API_KEY': os.environ.get("GOOSEAI_API_KEY", None),
}

DEFAULT_INLINE_GENERATION_SETTINGS = {
Expand Down Expand Up @@ -487,6 +556,7 @@ def rebuild_tree(self):
def edit_new_nodes(self):
print('new nodes:', self.new_nodes)
self.tree_updated()
time.sleep(0.5)
for node_id in self.new_nodes[0]:
self.node(node_id)['mutable'] = True
self.tree_updated(edit=self.new_nodes[0])
Expand Down

0 comments on commit dc9c8ab

Please sign in to comment.