Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions agent/middleware/apiBasedTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import {
} from "../toolCallEvents.js";
import { ALWAYS_AVAILABLE_API_TOOL_NAMES } from "../tools/index.js";
import { createApiTool } from "../tools/apiTool.js";
import type { AgentEventEmitter } from "../../agentEvents.js";
import type { SequenceDebugCollector } from "./sequenceDebug.js";

function getEnabledApiToolNames(messages: unknown[]) {
const enabledToolNames = new Set<string>();
Expand Down Expand Up @@ -80,11 +82,19 @@ export function createApiBasedToolsMiddleware(
async wrapToolCall(request, handler) {
const startedAt = Date.now();
const toolInput = JSON.stringify(request.toolCall.args ?? {});
const { adminUser, emitToolCallEvent, userTimeZone } = request.runtime.context as {
const { adminUser, emit, sequenceDebugSink, userTimeZone } = request.runtime.context as {
adminUser: AdminUser;
emitToolCallEvent: ToolCallEventSink;
emit?: AgentEventEmitter;
sequenceDebugSink: SequenceDebugCollector;
userTimeZone: string;
};
const emitToolCall: ToolCallEventSink = (event) => {
sequenceDebugSink.handleToolCallEvent(event);
void emit?.({
type: "tool-call",
data: event,
});
};
const toolArgs = (request.toolCall.args ?? {}) as Record<string, unknown>;
let toolInfo: string | undefined;

Expand All @@ -102,7 +112,7 @@ export function createApiBasedToolsMiddleware(
});
}
const toolCallTracker = createToolCallTracker({
emit: emitToolCallEvent,
emit: emitToolCall,
toolCallId: request.toolCall.id,
toolName: request.toolCall.name,
toolInfo,
Expand Down
9 changes: 9 additions & 0 deletions agent/models/AgentModeResolver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import type { PluginOptions } from "../../types.js";

export class AgentModeResolver {
constructor(private readonly options: PluginOptions) {}

resolve(modeName?: string | null) {
return this.options.modes.find((mode) => mode.name === modeName) ?? this.options.modes[0];
}
}
28 changes: 28 additions & 0 deletions agent/models/AgentModelFactory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import type { CompletionAdapter } from "adminforth";
import { createAgentChatModel } from "../simpleAgent.js";
import type { AgentTurnModels } from "../turn/turnTypes.js";

export class AgentModelFactory {
constructor(private readonly maxTokens: number) {}

async create(completionAdapter: CompletionAdapter): Promise<AgentTurnModels> {
const [primaryModelSpec, summaryModelSpec] = await Promise.all([
createAgentChatModel({
adapter: completionAdapter,
maxTokens: this.maxTokens,
purpose: "primary",
}),
createAgentChatModel({
adapter: completionAdapter,
maxTokens: this.maxTokens,
purpose: "summary",
}),
]);

return {
model: primaryModelSpec.model,
summaryModel: summaryModelSpec.model,
modelMiddleware: primaryModelSpec.middleware,
};
}
}
30 changes: 30 additions & 0 deletions agent/runtime/AgentContext.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import type { AdminUser } from "adminforth";
import { z } from "zod";
import type { AgentEventEmitter } from "../../agentEvents.js";
import type { SequenceDebugCollector } from "../middleware/sequenceDebug.js";
import type { CurrentPageContext } from "../tools/getUserLocation.js";
import type { AgentTurnContext } from "../turn/turnTypes.js";

export const contextSchema = z.object({
adminUser: z.custom<AdminUser>(),
userTimeZone: z.string(),
sessionId: z.string(),
turnId: z.string(),
abortSignal: z.custom<AbortSignal>().optional(),
currentPage: z.custom<CurrentPageContext>().optional(),
chatSurface: z.string().optional(),
adminBaseUrl: z.string().optional(),
adminPublicOrigin: z.string().optional(),
emit: z.custom<AgentEventEmitter>().optional(),
sequenceDebugSink: z.custom<SequenceDebugCollector>(),
});

export function toLangchainAgentContext(
context: AgentTurnContext & {
adminBaseUrl: string;
emit?: AgentEventEmitter;
sequenceDebugSink: SequenceDebugCollector;
},
) {
return context;
}
68 changes: 68 additions & 0 deletions agent/runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import type { IAdminForth } from "adminforth";
import { createAgent, summarizationMiddleware } from "langchain";
import type { BaseCheckpointSaver } from "@langchain/langgraph";
import { createApiBasedToolsMiddleware } from "../middleware/apiBasedTools.js";
import { createSequenceDebugMiddleware } from "../middleware/sequenceDebug.js";
import { createAgentLlmMetricsLogger } from "../simpleAgent.js";
import type { AgentToolProvider } from "../tools/AgentToolProvider.js";
import type { AgentRuntimeRunInput } from "../turn/turnTypes.js";
import { contextSchema, toLangchainAgentContext } from "./AgentContext.js";

export type AgentRuntimeOptions = {
name: string;
getAdminforth: () => IAdminForth;
getCheckpointer: () => BaseCheckpointSaver;
toolProvider: AgentToolProvider;
};

export class AgentRuntime {
constructor(private readonly options: AgentRuntimeOptions) {}

async stream(input: AgentRuntimeRunInput) {
const apiBasedTools = this.options.toolProvider.getApiBasedTools();
const tools = await this.options.toolProvider.getTools(apiBasedTools);
const adminforth = this.options.getAdminforth();
const apiBasedToolsMiddleware = createApiBasedToolsMiddleware(
apiBasedTools,
adminforth,
);
const sequenceDebugMiddleware = createSequenceDebugMiddleware(
input.observability.sequenceDebugSink,
);
const middleware = [
apiBasedToolsMiddleware,
...(input.models.modelMiddleware ?? []),
sequenceDebugMiddleware,
summarizationMiddleware({
model: input.models.summaryModel,
trigger: { tokens: 1024 * 64 },
keep: { messages: 10 },
}),
] as const;

const agent = createAgent({
name: this.options.name,
model: input.models.model,
checkpointer: this.options.getCheckpointer(),
tools,
contextSchema,
middleware,
});

return agent.stream({ messages: input.messages } as any, {
streamMode: "messages",
recursionLimit: 100,
callbacks: [createAgentLlmMetricsLogger()],
signal: input.context.abortSignal,
configurable: {
thread_id: input.context.sessionId,
},
context: toLangchainAgentContext({
...input.context,
adminBaseUrl: adminforth.config.baseUrlSlashed,
emit: input.observability.emit,
sequenceDebugSink: input.observability.sequenceDebugSink,
}),
});
}
}
131 changes: 2 additions & 129 deletions agent/simpleAgent.ts
Original file line number Diff line number Diff line change
@@ -1,39 +1,13 @@
import type { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { createAgent, summarizationMiddleware } from "langchain";
import {
logger,
type AdminUser,
type CompletionAdapter,
type IAdminForth,
} from "adminforth";
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import {type BaseCheckpointSaver, type Messages } from "@langchain/langgraph";
import type { LLMResult } from "@langchain/core/outputs";
import { z } from "zod";
import { createAgentTools } from "./tools/index.js";
import { createApiBasedToolsMiddleware } from "./middleware/apiBasedTools.js";
import {
createSequenceDebugMiddleware,
type SequenceDebugModelCallSink,
} from "./middleware/sequenceDebug.js";
import type { ApiBasedTool } from "../apiBasedTools.js";
import type { ToolCallEventSink } from "./toolCallEvents.js";
import type { CurrentPageContext } from "./tools/getUserLocation.js";
import type { AgentEventEmitter } from "../agentEvents.js";

export const contextSchema = z.object({
adminUser: z.custom<AdminUser>(),
userTimeZone: z.string(),
sessionId: z.string(),
turnId: z.string(),
abortSignal: z.custom<AbortSignal>().optional(),
currentPage: z.custom<CurrentPageContext>().optional(),
chatSurface: z.string().optional(),
adminBaseUrl: z.string().optional(),
adminPublicOrigin: z.string().optional(),
emitToolCallEvent: z.custom<ToolCallEventSink>(),
emitAgentEvent: z.custom<AgentEventEmitter>().optional(),
});

export type AgentChatModel = BaseChatModel<any, any>;
export type AgentModelPurpose = "primary" | "summary";
Expand All @@ -50,7 +24,7 @@ export type AgentModeCompletionAdapter = CompletionAdapter & {
};
};

type AgentMiddleware = ReturnType<typeof createSequenceDebugMiddleware>;
export type AgentMiddleware = ReturnType<typeof createSequenceDebugMiddleware>;

type AgentChatModelSpec = {
model: AgentChatModel;
Expand Down Expand Up @@ -202,7 +176,7 @@ class AgentLlmMetricsLogger extends BaseCallbackHandler {
}
}

function createAgentLlmMetricsLogger() {
export function createAgentLlmMetricsLogger() {
return new AgentLlmMetricsLogger();
}

Expand All @@ -223,104 +197,3 @@ export async function createAgentChatModel(params: {
purpose: params.purpose,
});
}

export async function callAgent(params: {
name: string;
model: AgentChatModel;
summaryModel: AgentChatModel;
modelMiddleware?: AgentMiddleware[];
checkpointer?: BaseCheckpointSaver;
messages: Messages;
adminUser: AdminUser;
adminforth: IAdminForth;
apiBasedTools: Record<string, ApiBasedTool>;
customComponentsDir: string;
pluginCustomFolderPaths: string[];
sessionId: string;
turnId: string;
currentPage?: CurrentPageContext;
chatSurface?: string;
adminPublicOrigin?: string;
userTimeZone: string;
abortSignal?: AbortSignal;
emitToolCallEvent: ToolCallEventSink;
emitAgentEvent?: AgentEventEmitter;
sequenceDebugSink: SequenceDebugModelCallSink;
}) {
const {
name,
model,
summaryModel,
modelMiddleware = [],
checkpointer,
messages,
adminUser,
adminforth,
apiBasedTools,
customComponentsDir,
pluginCustomFolderPaths,
sessionId,
turnId,
currentPage,
chatSurface,
adminPublicOrigin,
userTimeZone,
abortSignal,
emitToolCallEvent,
emitAgentEvent,
sequenceDebugSink,
} = params;

const tools = await createAgentTools(
customComponentsDir,
apiBasedTools,
pluginCustomFolderPaths,
);
const apiBasedToolsMiddleware = createApiBasedToolsMiddleware(apiBasedTools, adminforth);
const sequenceDebugMiddleware = createSequenceDebugMiddleware(
sequenceDebugSink,
);

const middleware = [
apiBasedToolsMiddleware,
...modelMiddleware,
sequenceDebugMiddleware,
summarizationMiddleware({
model: summaryModel,
trigger: { tokens: 1024 * 64 },
keep: { messages: 10 },
}),
] as const;

const agent = createAgent<undefined, typeof contextSchema, typeof middleware>({
name,
model,
checkpointer,
tools,
contextSchema,
middleware,
});

return await agent.stream({ messages } as any, {
streamMode: "messages",
recursionLimit: 100,
callbacks: [createAgentLlmMetricsLogger()],
signal: abortSignal,
configurable: {
thread_id: sessionId,
},
context: {
adminUser,
userTimeZone,
sessionId,
turnId,
abortSignal,
currentPage,
chatSurface,
adminBaseUrl: adminforth.config.baseUrlSlashed,
adminPublicOrigin,
emitToolCallEvent,
emitAgentEvent,
},
});
}
Loading