Skip to content

Commit

Permalink
Add streaming support to vertex AI API
Browse files Browse the repository at this point in the history
  • Loading branch information
gsans committed Dec 12, 2023
1 parent 0fdce51 commit 1dd26dc
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 47 deletions.
3 changes: 2 additions & 1 deletion src/app/app.module.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { NgModule } from '@angular/core';
import { BrowserModule } from '@angular/platform-browser';

import { HttpClientModule, HttpClient } from '@angular/common/http';
import { HttpClientModule, HttpClient, provideHttpClient, withFetch } from '@angular/common/http';

import { AppRoutingModule } from './app-routing.module';
import { AppComponent } from './app.component';
Expand Down Expand Up @@ -113,6 +113,7 @@ export function markedOptionsFactory(): MarkedOptions {
FormsModule,
],
providers: [
provideHttpClient(withFetch()),
AudioService,
RouterScrollService,
],
Expand Down
6 changes: 5 additions & 1 deletion src/app/generative-ai-vertex/v1/prediction.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ export class PredictionServiceClient {
let prompt: any = createPromptStreaming(text);
let headers = this.getAuthHeadersStreaming();

return this.http.post(endpoint, prompt, { headers });
return this.http.post(endpoint, prompt, { headers, ... {
observe: "events",
responseType: "text",
reportProgress: true,
} });
}

private buildEndpointUrl(model: string) {
Expand Down
2 changes: 1 addition & 1 deletion src/app/generative-ai-vertex/v1/vertex.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export function createPromptStreaming(
{
"struct_val": {
"prompt": {
"string_val": [ `${prompt}`]
"string_val": [`${prompt}`]
}
}
}
Expand Down
24 changes: 16 additions & 8 deletions src/app/rich-text-editor/rich-text-editor.component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,22 @@ export class RichTextEditorComponent {
}
}

insertStream(text: string) {
var range = this.quillInstance.getSelection();
if (range) {
if (range.length > 0) return; // range selected ignore
const index = range.index;
this.quillInstance.insertText(index, text, 'api');
this.quillInstance.update('api');
}
getRange() {
  // Snapshot of the user's current selection in the Quill editor
  // (null when the editor holds no selection) — captured by callers
  // before streaming starts so inserts land at a stable position.
  const selection = this.quillInstance.getSelection();
  return selection;
}

insertStream(text: string, range: any) {
  // Replaces the document content with the latest streamed snapshot,
  // anchored at the range captured when streaming began.
  // NOTE(review): `range` comes from getRange() and may be undefined if the
  // HttpEventType.Sent event was never observed — guard instead of crashing.
  if (!range) return;

  // Remove everything from the anchor to the end of the document;
  // Quill clamps the length, so over-shooting with the full doc length is safe.
  const docLength = this.quillInstance.getLength();
  this.quillInstance.deleteText(range.index, docLength);

  // Restore the caret to where streaming started.
  this.quillInstance.setSelection(range);

  // Insert the accumulated text as a 'label' embed at the anchor.
  const index = range.index || 0;
  this.quillInstance.insertEmbed(index, 'label', text, 'api');
  this.quillInstance.update('api');
}

insertAndFormat(text: string, error: boolean = false) {
Expand Down
117 changes: 81 additions & 36 deletions src/app/text/text.component.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { Component, Inject, ViewChild } from '@angular/core';
// import { PREDICTION_SERVICE_CLIENT_TOKEN } from '../generative-ai-vertex/vertex.module';
// import { PredictionServiceClient } from '../generative-ai-vertex/v1/prediction.service';
// import { TextResponse } from '../generative-ai-vertex/v1/vertex.types';
import { TEXT_SERVICE_CLIENT_TOKEN } from '../generative-ai-palm/palm.module';
import { TextServiceClient } from '../generative-ai-palm/v1beta2/text.service';
import { TextResponse } from '../generative-ai-palm/v1beta2/palm.types';
import { PREDICTION_SERVICE_CLIENT_TOKEN } from '../generative-ai-vertex/vertex.module';
import { PredictionServiceClient } from '../generative-ai-vertex/v1/prediction.service';
import { TextResponse } from '../generative-ai-vertex/v1/vertex.types';
//import { TEXT_SERVICE_CLIENT_TOKEN } from '../generative-ai-palm/palm.module';
//import { TextServiceClient } from '../generative-ai-palm/v1beta2/text.service';
//import { TextResponse } from '../generative-ai-palm/v1beta2/palm.types';
import { RichTextEditorComponent } from '../rich-text-editor/rich-text-editor.component';
import { AudioService } from '../read/audio.service';
import { HttpEvent, HttpEventType, HttpDownloadProgressEvent } from '@angular/common/http';
import { Subject, switchMap } from 'rxjs';
const MAX_PHRASES = 10;

@Component({
Expand All @@ -21,8 +23,8 @@ export class TextComponent {
playing: boolean = false;

constructor(
@Inject(TEXT_SERVICE_CLIENT_TOKEN) public client: TextServiceClient,
//@Inject(PREDICTION_SERVICE_CLIENT_TOKEN) public client: PredictionServiceClient,
//@Inject(TEXT_SERVICE_CLIENT_TOKEN) public client: TextServiceClient,
@Inject(PREDICTION_SERVICE_CLIENT_TOKEN) public client: PredictionServiceClient,
private audio: AudioService
) { }

Expand All @@ -35,34 +37,77 @@ export class TextComponent {
const prompt = this.editor.extractPrompt();

// PaLM
const response = await this.client.generateText(prompt);
if (response.filters && response.filters.length > 0){
this.logBlockedResponse(prompt, response);
this.editor.insertAndFormat("Response was blocked. Try changing your prompt to avoid any derogatory, toxic, sexual, violent, dangerous or medical related content.", true);
} else {
const text = (response?.candidates?.[0].output || '').trim();
this.editor.insertAndFormat(text);
}
// const response = await this.client.generateText(prompt);
// if (response.filters && response.filters.length > 0){
// this.logBlockedResponse(prompt, response);
// this.editor.insertAndFormat("Response was blocked. Try changing your prompt to avoid any derogatory, toxic, sexual, violent, dangerous or medical related content.", true);
// } else {
// const text = (response?.candidates?.[0].output || '').trim();
// this.editor.insertAndFormat(text);
// }

// Vertex AI
//const response: TextResponse = await this.client.predict(prompt);
//const text = (response?.predictions?.[0].content).trim();
// const response: TextResponse = await this.client.predict(prompt);
// const text = (response?.predictions?.[0].content).trim();
// if (text.length > 0) {
// this.editor.insertAndFormatMarkdown(text);
// }

// Vertex AI Stream
// this.client.streamingPredict(prompt).subscribe({
// next: (response: any) => {
// console.log('stream-chunk');
// response.forEach( (element: any) => {
// const text = (element.outputs?.[0].structVal?.content?.stringVal?.[0]).trim();
// this.editor.insertStream(text);
// });
// },
// complete: () => { console.log('stream-end'); },
// error: (error) => { console.log(error); }
// })
let state: any = {
range: undefined,
text: ''
};
this.client.streamingPredict(prompt).subscribe({
next: (event: HttpEvent<string>) => {
state = this.handleStreamEvent(event, state);
},
complete: () => { console.log('stream-end'); },
error: (error) => { console.log(error); }
})
}

private showTextEditor(text: string) {
  // Debug aid: dump the text that would be shown in the editor.
  console.log(text);
}

private handleStreamEvent(event: HttpEvent<string>, state: any): any {
  // Reducer for the streaming prediction: consumes one HttpEvent and
  // returns the next { range, text } state. Requires the request to use
  // observe: 'events', responseType: 'text', reportProgress: true.
  switch (event.type) {
    case HttpEventType.Sent: {
      // Capture the caret position once, before any chunk arrives.
      const range = this.editor.getRange();
      return { range, text: state.text };
    }
    case HttpEventType.DownloadProgress: {
      // partialText accumulates everything received so far; it may be
      // undefined on the very first progress tick.
      const partial = (event as HttpDownloadProgressEvent).partialText ?? '';
      const text = this.extractCompletion(this.fixPartialAnswer(partial));
      this.editor.insertStream(text, state.range);
      return { range: state.range, text };
    }
    case HttpEventType.Response: {
      // Final, complete body — render it as-is.
      const text = this.extractCompletion(event.body ?? '');
      this.editor.insertStream(text, state.range);
      return { range: state.range, text };
    }
    default:
      // Ignore header/upload events; keep state unchanged.
      return state;
  }
}

private fixPartialAnswer(fragment: string): string {
  // Streamed chunks are a prefix of a JSON array; repair the fragment so
  // JSON.parse can succeed: trim trailing whitespace (a newline would defeat
  // the ']' check), drop a trailing comma left by a cut between array
  // elements, then close the array if it is still open.
  let repaired = fragment.trimEnd();
  if (repaired.endsWith(',')) {
    repaired = repaired.slice(0, -1);
  }
  if (!repaired.endsWith(']')) {
    repaired += ']';
  }
  return repaired;
}

private extractCompletion(fragment: string): string {
  // Joins the per-chunk Vertex AI completions into one string.
  // Each array element is expected to look like
  // { outputs: [{ structVal: { content: { stringVal: ['...'] } } }] };
  // elements missing that path are skipped instead of crashing
  // (the original's chain ended before .stringVal, so a partial element
  // threw TypeError on .trim()). JSON.parse errors still propagate to the
  // subscriber's error handler, matching the original contract.
  const response = JSON.parse(fragment);
  if (!Array.isArray(response)) return '';
  const parts: string[] = [];
  for (const element of response) {
    const value = element?.outputs?.[0]?.structVal?.content?.stringVal?.[0];
    if (typeof value === 'string') parts.push(value.trim());
  }
  return parts.join(' ');
}

clear() {
Expand Down Expand Up @@ -95,13 +140,13 @@ export class TextComponent {
return text;
}

private logBlockedResponse(prompt: string, response: TextResponse) {
if (!response.filters || response.filters.length == 0) return;
// private logBlockedResponse(prompt: string, response: TextResponse) {
// if (!response.filters || response.filters.length == 0) return;

console.log("Response was blocked.");
console.log(`Original prompt: ${prompt}`);
console.log(`Filters applied:\n${JSON.stringify(response.filters, null, " ")}`);
console.log(`Safety settings and category ratings:\n${JSON.stringify(response.safetyFeedback, null, " ")}`);
}
// console.log("Response was blocked.");
// console.log(`Original prompt: ${prompt}`);
// console.log(`Filters applied:\n${JSON.stringify(response.filters, null, " ")}`);
// console.log(`Safety settings and category ratings:\n${JSON.stringify(response.safetyFeedback, null, " ")}`);
// }
}

0 comments on commit 1dd26dc

Please sign in to comment.