mirror of
https://github.com/marcogll/TaxHacker_s23.git
synced 2026-01-13 13:25:18 +00:00
feat: more llm provider options (google, mistral) (#28)
* feat: add google provider * fix: default for google model * feat: multiple providers * fix: defaults from env for login form * fix: add mistral to env files * chore: delete unused code * chore: revert database url to original * fix: render default value for api key from env on server * fix: type errors during compilation --------- Co-authored-by: Vasily Zubarev <me@vas3k.ru>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
"use server"
|
||||
|
||||
import { ActionState } from "@/lib/actions"
|
||||
import config from "@/lib/config"
|
||||
import OpenAI from "openai"
|
||||
import { AnalyzeAttachment } from "./attachments"
|
||||
import { updateFile } from "@/models/files"
|
||||
import { getSettings, getLLMSettings } from "@/models/settings"
|
||||
import { requestLLM } from "./providers/llmProvider"
|
||||
|
||||
export type AnalysisResult = {
|
||||
output: Record<string, string>
|
||||
@@ -15,52 +15,39 @@ export async function analyzeTransaction(
|
||||
prompt: string,
|
||||
schema: Record<string, unknown>,
|
||||
attachments: AnalyzeAttachment[],
|
||||
apiKey: string,
|
||||
fileId: string,
|
||||
userId: string
|
||||
): Promise<ActionState<AnalysisResult>> {
|
||||
const openai = new OpenAI({
|
||||
apiKey,
|
||||
})
|
||||
console.log("RUNNING AI ANALYSIS")
|
||||
console.log("PROMPT:", prompt)
|
||||
console.log("SCHEMA:", schema)
|
||||
|
||||
const settings = await getSettings(userId)
|
||||
const llmSettings = getLLMSettings(settings)
|
||||
|
||||
try {
|
||||
const response = await openai.responses.create({
|
||||
model: config.ai.modelName,
|
||||
input: [
|
||||
{
|
||||
role: "user",
|
||||
content: prompt,
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content: attachments.map((attachment) => ({
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${attachment.contentType};base64,${attachment.base64}`,
|
||||
})),
|
||||
},
|
||||
],
|
||||
text: {
|
||||
format: {
|
||||
type: "json_schema",
|
||||
name: "transaction",
|
||||
schema: schema,
|
||||
strict: true,
|
||||
},
|
||||
},
|
||||
const response = await requestLLM(llmSettings, {
|
||||
prompt,
|
||||
schema,
|
||||
attachments,
|
||||
})
|
||||
|
||||
console.log("ChatGPT response:", response.output_text)
|
||||
console.log("ChatGPT tokens used:", response.usage)
|
||||
if (response.error) {
|
||||
throw new Error(response.error)
|
||||
}
|
||||
|
||||
const result = response.output
|
||||
const tokensUsed = response.tokensUsed || 0
|
||||
|
||||
console.log("LLM response:", result)
|
||||
console.log("LLM tokens used:", tokensUsed)
|
||||
|
||||
const result = JSON.parse(response.output_text)
|
||||
|
||||
await updateFile(fileId, userId, { cachedParseResult: result })
|
||||
|
||||
return { success: true, data: { output: result, tokensUsed: response.usage?.total_tokens || 0 } }
|
||||
return {
|
||||
success: true,
|
||||
data: {
|
||||
output: result,
|
||||
tokensUsed: tokensUsed
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("AI Analysis error:", error)
|
||||
return {
|
||||
|
||||
115
ai/providers/llmProvider.ts
Normal file
115
ai/providers/llmProvider.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
import { ChatOpenAI } from "@langchain/openai"
|
||||
import { ChatGoogleGenerativeAI } from "@langchain/google-genai"
|
||||
import { ChatMistralAI } from "@langchain/mistralai"
|
||||
import { BaseMessage, HumanMessage } from "@langchain/core/messages"
|
||||
|
||||
export type LLMProvider = "openai" | "google" | "mistral"
|
||||
|
||||
export interface LLMConfig {
|
||||
provider: LLMProvider
|
||||
apiKey: string
|
||||
model: string
|
||||
}
|
||||
|
||||
export interface LLMSettings {
|
||||
providers: LLMConfig[]
|
||||
}
|
||||
|
||||
export interface LLMRequest {
|
||||
prompt: string
|
||||
schema?: Record<string, unknown>
|
||||
attachments?: any[]
|
||||
}
|
||||
|
||||
export interface LLMResponse {
|
||||
output: Record<string, string>
|
||||
tokensUsed?: number
|
||||
provider: LLMProvider
|
||||
error?: string
|
||||
}
|
||||
|
||||
async function requestLLMUnified(config: LLMConfig, req: LLMRequest): Promise<LLMResponse> {
|
||||
try {
|
||||
const temperature = 0;
|
||||
let model: any;
|
||||
if (config.provider === "openai") {
|
||||
model = new ChatOpenAI({
|
||||
apiKey: config.apiKey,
|
||||
model: config.model,
|
||||
temperature: temperature,
|
||||
});
|
||||
} else if (config.provider === "google") {
|
||||
model = new ChatGoogleGenerativeAI({
|
||||
apiKey: config.apiKey,
|
||||
model: config.model,
|
||||
temperature: temperature,
|
||||
});
|
||||
} else if (config.provider === "mistral") {
|
||||
model = new ChatMistralAI({
|
||||
apiKey: config.apiKey,
|
||||
model: config.model,
|
||||
temperature: temperature,
|
||||
});
|
||||
} else {
|
||||
return {
|
||||
output: {},
|
||||
provider: config.provider,
|
||||
error: "Unknown provider",
|
||||
};
|
||||
}
|
||||
|
||||
const structuredModel = model.withStructuredOutput(req.schema, { 'name': 'transaction'});
|
||||
|
||||
let message_content: any = [{ type: "text", text: req.prompt }];
|
||||
if (req.attachments && req.attachments.length > 0) {
|
||||
const images = req.attachments.map(att => ({
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${att.contentType};base64,${att.base64}`
|
||||
},
|
||||
}));
|
||||
message_content.push(...images);
|
||||
}
|
||||
const messages: BaseMessage[] = [
|
||||
new HumanMessage({ content: message_content })
|
||||
];
|
||||
|
||||
const response = await structuredModel.invoke(messages);
|
||||
|
||||
return {
|
||||
output: response,
|
||||
provider: config.provider,
|
||||
};
|
||||
} catch (error: any) {
|
||||
return {
|
||||
output: {},
|
||||
provider: config.provider,
|
||||
error: error instanceof Error ? error.message : `${config.provider} request failed`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export async function requestLLM(settings: LLMSettings, req: LLMRequest): Promise<LLMResponse> {
|
||||
for (const config of settings.providers) {
|
||||
if (!config.apiKey || !config.model) {
|
||||
console.info('Skipping provider:', config.provider);
|
||||
continue;
|
||||
}
|
||||
console.info('Use provider:', config.provider);
|
||||
|
||||
const response = await requestLLMUnified(config, req);
|
||||
|
||||
if (!response.error) {
|
||||
return response;
|
||||
}
|
||||
else {
|
||||
console.error(response.error)
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
output: {},
|
||||
provider: settings.providers[0]?.provider || "openai",
|
||||
error: "All LLM providers failed or are not configured",
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user