Enable VertexAI for text to test streaming
gsans committed Oct 17, 2023
1 parent d585d6c commit 462f8f2
Showing 5 changed files with 104 additions and 16 deletions.
12 changes: 6 additions & 6 deletions src/app/app.module.ts
@@ -71,15 +71,15 @@ export function markedOptionsFactory(): MarkedOptions {
AppRoutingModule,

BrowserAnimationsModule,
// VertexModule.forRoot({
// projectId: environment.PROJECT_ID,
// accessToken: environment.GCLOUD_AUTH_PRINT_ACCESS_TOKEN,
// version: "v1" // options: v1beta1, v1
// })
VertexModule.forRoot({
projectId: environment.PROJECT_ID,
accessToken: environment.GCLOUD_AUTH_PRINT_ACCESS_TOKEN,
version: "v1" // options: v1beta1, v1
}),
PalmModule.forRoot({
apiKey: environment.API_KEY,
version: "v1beta2" // options: v1beta2
}),
MatIconModule,
MatListModule,
MatSidenavModule,
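Note: the module configuration above reads its credentials from the Angular environment file. A minimal sketch of what that file might look like, assuming only the property names visible in this diff (values are placeholders, not real credentials):

// src/environments/environment.ts (sketch; placeholder values)
export const environment = {
  production: false,
  PROJECT_ID: 'my-gcp-project',  // Google Cloud project id
  GCLOUD_AUTH_PRINT_ACCESS_TOKEN: '<paste output of: gcloud auth print-access-token>',
  API_KEY: '<PaLM API key>',     // used by PalmModule
};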
29 changes: 28 additions & 1 deletion src/app/generative-ai-vertex/v1/prediction.service.ts
Expand Up @@ -2,7 +2,7 @@ import { Injectable } from '@angular/core';
import { HttpClient, HttpHeaders } from '@angular/common/http';

import { GoogleCloudCredentials } from '../types';
import { createPrompt, TextRequest, TextResponse } from './vertex.types';
import { createPrompt, TextRequest, TextResponse, createPromptStreaming } from './vertex.types';
import { firstValueFrom } from 'rxjs';

@Injectable({
@@ -30,6 +30,14 @@ export class PredictionServiceClient {
);
}

// POST to the :serverStreamingPredict action; the response body is a JSON array of chunks
streamingPredict(text: string, model: string = "text-bison") {
let endpoint = this.buildEndpointUrlStreaming(model);
let prompt: any = createPromptStreaming(text);
let headers = this.getAuthHeadersStreaming();

return this.http.post(endpoint, prompt, { headers });
}

private buildEndpointUrl(model: string) {

let url = this.baseUrl; // base url
@@ -43,8 +51,27 @@
return url;
}

private buildEndpointUrlStreaming(model: string) {

let url = this.baseUrl; // base url
url += this.version; // api version
url += "/projects/" + this.projectID; // project id
url += "/locations/us-central1"; // google cloud region
url += "/publishers/google"; // publisher
url += "/models/" + model; // model
url += ":serverStreamingPredict"; // action

return url;
}

private getAuthHeaders() {
return new HttpHeaders()
.set('Authorization', `Bearer ${this.accessToken}`);
}

private getAuthHeadersStreaming() {
return new HttpHeaders()
.set('Authorization', `Bearer ${this.accessToken}`);
//.set('Transfer-Encoding', 'chunked');
}
}
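Note: assuming this.baseUrl holds the regional Vertex AI endpoint, https://us-central1-aiplatform.googleapis.com/ (it is assigned outside the lines shown here), buildEndpointUrlStreaming assembles a URL of this shape:

https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/text-bison:serverStreamingPredict

The non-streaming buildEndpointUrl presumably differs only in the trailing action segment.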
29 changes: 29 additions & 0 deletions src/app/generative-ai-vertex/v1/vertex.types.ts
@@ -23,6 +23,35 @@ export function createPrompt(
return request;
}

export function createPromptStreaming(
prompt: string = "What is the largest number with a name?",
temperature: number = 0,
maxOutputTokens: number = 100,
topP: number = 0.70,
topK: number = 40
): any {
const request: any = {
"inputs": [
{
"struct_val": {
"prompt": {
"string_val": [ `${prompt}`]
}
}
}
],
"parameters": {
"struct_val": {
"temperature": { "float_val": temperature },
"maxOutputTokens": { "int_val": maxOutputTokens },
"topK": { "int_val": topK },
"topP": { "float_val": topP }
}
}
};
return request;
}

// Text API
export interface TextRequest {
instances: TextInstance[];
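Note: to make the struct_val encoding concrete, createPromptStreaming('Say hello') with the default parameters produces this request body (the protobuf-style JSON that :serverStreamingPredict expects):

{
  "inputs": [
    { "struct_val": { "prompt": { "string_val": ["Say hello"] } } }
  ],
  "parameters": {
    "struct_val": {
      "temperature": { "float_val": 0 },
      "maxOutputTokens": { "int_val": 100 },
      "topK": { "int_val": 40 },
      "topP": { "float_val": 0.7 }
    }
  }
}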
10 changes: 10 additions & 0 deletions src/app/rich-text-editor/rich-text-editor.component.ts
@@ -133,6 +133,16 @@ export class RichTextEditorComponent {
}
}

// Inserts a chunk of streamed text at the current cursor position
insertStream(text: string) {
var range = this.quillInstance.getSelection();
if (range) {
if (range.length > 0) return; // a text range is selected; ignore
const index = range.index;
this.quillInstance.insertText(index, text, 'api');
this.quillInstance.update('api');
}
}

insertAndFormat(text:string) {
var range = this.quillInstance.getSelection();
if (range) {
40 changes: 31 additions & 9 deletions src/app/text/text.component.ts
@@ -1,7 +1,9 @@
import { Component, Inject, ViewChild } from '@angular/core';
//import { PREDICTION_SERVICE_CLIENT_TOKEN } from '../generative-ai-vertex/vertex.module';
import { TEXT_SERVICE_CLIENT_TOKEN } from '../generative-ai-palm/palm.module';
import { TextServiceClient } from '../generative-ai-palm/v1beta2/text.service';
import { PREDICTION_SERVICE_CLIENT_TOKEN } from '../generative-ai-vertex/vertex.module';
import { PredictionServiceClient } from '../generative-ai-vertex/v1/prediction.service';
import { TextResponse } from '../generative-ai-vertex/v1/vertex.types';
//import { TEXT_SERVICE_CLIENT_TOKEN } from '../generative-ai-palm/palm.module';
//import { TextServiceClient } from '../generative-ai-palm/v1beta2/text.service';
import { RichTextEditorComponent } from '../rich-text-editor/rich-text-editor.component';
import { AudioService } from '../read/audio.service';
const MAX_PHRASES = 10;
@@ -18,7 +20,8 @@ export class TextComponent {
playing: boolean = false;

constructor(
@Inject(TEXT_SERVICE_CLIENT_TOKEN) public client: TextServiceClient,
//@Inject(TEXT_SERVICE_CLIENT_TOKEN) public client: TextServiceClient,
@Inject(PREDICTION_SERVICE_CLIENT_TOKEN) public client: PredictionServiceClient,
private audio: AudioService
) { }

@@ -29,11 +32,30 @@
async run() {
if (!this.editor) return;
const prompt = this.editor.extractPrompt();
const response = await this.client.generateText(prompt);
const text = (response?.candidates?.[0].output || '').trim();
if (text.length > 0) {
this.editor.insertAndFormatMarkdown(text);
}

// PaLM
//const response = await this.client.generateText(prompt);
//const text = (response?.candidates?.[0].output || '').trim();

// Vertex AI
//const response: TextResponse = await this.client.predict(prompt);
//const text = (response?.predictions?.[0]?.content || '').trim();
// if (text.length > 0) {
// this.editor.insertAndFormatMarkdown(text);
// }

// Vertex AI Stream
this.client.streamingPredict(prompt).subscribe({
next: (response: any) => {
console.log('stream-chunk');
response.forEach( (element: any) => {
const text = (element.outputs?.[0]?.structVal?.content?.stringVal?.[0] || '').trim();
this.editor.insertStream(text);
});
},
complete: () => { console.log('stream-end'); },
error: (error) => { console.log(error); }
})
}

clear() {
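Note: with Angular's default HttpClient options, the POST above resolves once with the full body parsed as a JSON array of chunks, which is why the next handler can call response.forEach. Each element is expected to look roughly like this (field names inferred from the access path in the handler; a sketch, not the full schema):

[
  {
    "outputs": [
      { "structVal": { "content": { "stringVal": ["...generated text..."] } } }
    ]
  }
]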
