Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix(mcp): harden notification system against race conditions
- Guard concurrent connect() calls in connection manager with connectingServers Set
- Suppress post-disconnect notification handler firing in MCP client
- Clean up Redis event listeners in pub/sub dispose()
- Add tests for all three hardening fixes (11 new tests)
  • Loading branch information
waleedlatif1 committed Feb 8, 2026
commit cf25fb084389e082b2e3ff033e77b34f9d95c732
111 changes: 111 additions & 0 deletions apps/sim/lib/mcp/client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/**
* @vitest-environment node
*/
import { loggerMock } from '@sim/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/logger', () => loggerMock)

/**
* Capture the notification handler registered via `client.setNotificationHandler()`.
* This lets us simulate the MCP SDK delivering a `tools/list_changed` notification.
*/
let capturedNotificationHandler: (() => Promise<void>) | null = null

vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
Client: vi.fn().mockImplementation(() => ({
connect: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
getServerVersion: vi.fn().mockReturnValue('2025-06-18'),
getServerCapabilities: vi.fn().mockReturnValue({ tools: { listChanged: true } }),
setNotificationHandler: vi
.fn()
.mockImplementation((_schema: unknown, handler: () => Promise<void>) => {
capturedNotificationHandler = handler
}),
listTools: vi.fn().mockResolvedValue({ tools: [] }),
})),
}))

vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
StreamableHTTPClientTransport: vi.fn().mockImplementation(() => ({
onclose: null,
sessionId: 'test-session',
})),
}))

vi.mock('@modelcontextprotocol/sdk/types.js', () => ({
ToolListChangedNotificationSchema: { method: 'notifications/tools/list_changed' },
}))

vi.mock('@/lib/core/execution-limits', () => ({
getMaxExecutionTimeout: vi.fn().mockReturnValue(30000),
}))

import { McpClient } from './client'
import type { McpServerConfig } from './types'

function createConfig(): McpServerConfig {
return {
id: 'server-1',
name: 'Test Server',
transport: 'streamable-http',
url: 'https://test.example.com/mcp',
}
}

describe('McpClient notification handler', () => {
beforeEach(() => {
capturedNotificationHandler = null
})

it('fires onToolsChanged when a notification arrives while connected', async () => {
const onToolsChanged = vi.fn()

const client = new McpClient({
config: createConfig(),
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
onToolsChanged,
})

await client.connect()

expect(capturedNotificationHandler).not.toBeNull()

await capturedNotificationHandler!()

expect(onToolsChanged).toHaveBeenCalledTimes(1)
expect(onToolsChanged).toHaveBeenCalledWith('server-1')
})

it('suppresses notifications after disconnect', async () => {
const onToolsChanged = vi.fn()

const client = new McpClient({
config: createConfig(),
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
onToolsChanged,
})

await client.connect()
expect(capturedNotificationHandler).not.toBeNull()

await client.disconnect()

// Simulate a late notification arriving after disconnect
await capturedNotificationHandler!()

expect(onToolsChanged).not.toHaveBeenCalled()
})

it('does not register a notification handler when onToolsChanged is not provided', async () => {
const client = new McpClient({
config: createConfig(),
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
})

await client.connect()

expect(capturedNotificationHandler).toBeNull()
})
})
81 changes: 63 additions & 18 deletions apps/sim/lib/mcp/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@

import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import type { ListToolsResult, Tool } from '@modelcontextprotocol/sdk/types.js'
import {
type ListToolsResult,
type Tool,
ToolListChangedNotificationSchema,
} from '@modelcontextprotocol/sdk/types.js'
import { createLogger } from '@sim/logger'
import { getMaxExecutionTimeout } from '@/lib/core/execution-limits'
import {
type McpClientOptions,
McpConnectionError,
type McpConnectionStatus,
type McpConsentRequest,
Expand All @@ -24,6 +29,7 @@ import {
type McpTool,
type McpToolCall,
type McpToolResult,
type McpToolsChangedCallback,
type McpVersionInfo,
} from '@/lib/mcp/types'

Expand All @@ -35,6 +41,7 @@ export class McpClient {
private config: McpServerConfig
private connectionStatus: McpConnectionStatus
private securityPolicy: McpSecurityPolicy
private onToolsChanged?: McpToolsChangedCallback
private isConnected = false

private static readonly SUPPORTED_VERSIONS = [
Expand All @@ -44,23 +51,36 @@ export class McpClient {
]

/**
* Creates a new MCP client
*
* No session ID parameter (we disconnect after each operation).
* The SDK handles session management automatically via Mcp-Session-Id header.
* Creates a new MCP client.
*
* @param config - Server configuration
* @param securityPolicy - Optional security policy
* Accepts either the legacy (config, securityPolicy?) signature
* or a single McpClientOptions object with an optional onToolsChanged callback.
*/
constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) {
this.config = config
this.connectionStatus = { connected: false }
this.securityPolicy = securityPolicy ?? {
requireConsent: true,
auditLevel: 'basic',
maxToolExecutionsPerHour: 1000,
constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy)
constructor(options: McpClientOptions)
constructor(
configOrOptions: McpServerConfig | McpClientOptions,
securityPolicy?: McpSecurityPolicy
) {
if ('config' in configOrOptions) {
this.config = configOrOptions.config
this.securityPolicy = configOrOptions.securityPolicy ?? {
requireConsent: true,
auditLevel: 'basic',
maxToolExecutionsPerHour: 1000,
}
this.onToolsChanged = configOrOptions.onToolsChanged
} else {
this.config = configOrOptions
this.securityPolicy = securityPolicy ?? {
requireConsent: true,
auditLevel: 'basic',
maxToolExecutionsPerHour: 1000,
}
}

this.connectionStatus = { connected: false }

if (!this.config.url) {
throw new McpError('URL required for Streamable HTTP transport')
}
Expand All @@ -79,16 +99,15 @@ export class McpClient {
{
capabilities: {
tools: {},
// Resources and prompts can be added later
// resources: {},
// prompts: {},
},
}
)
}

/**
* Initialize connection to MCP server
* Initialize connection to MCP server.
* If an `onToolsChanged` callback was provided, registers a notification handler
* for `notifications/tools/list_changed` after connecting.
*/
async connect(): Promise<void> {
logger.info(`Connecting to MCP server: ${this.config.name} (${this.config.transport})`)
Expand All @@ -100,6 +119,15 @@ export class McpClient {
this.connectionStatus.connected = true
this.connectionStatus.lastConnected = new Date()

if (this.onToolsChanged) {
this.client.setNotificationHandler(ToolListChangedNotificationSchema, async () => {
if (!this.isConnected) return
logger.info(`[${this.config.name}] Received tools/list_changed notification`)
this.onToolsChanged?.(this.config.id)
})
logger.info(`[${this.config.name}] Registered tools/list_changed notification handler`)
}

const serverVersion = this.client.getServerVersion()
logger.info(`Successfully connected to MCP server: ${this.config.name}`, {
protocolVersion: serverVersion,
Expand Down Expand Up @@ -241,6 +269,23 @@ export class McpClient {
return !!serverCapabilities?.[capability]
}

/**
* Check if the server declared `capabilities.tools.listChanged: true` during initialization.
*/
hasListChangedCapability(): boolean {
const caps = this.client.getServerCapabilities()
const toolsCap = caps?.tools as Record<string, unknown> | undefined
return !!toolsCap?.listChanged
}

/**
* Register a callback to be invoked when the underlying transport closes.
* Used by the connection manager for reconnection logic.
*/
onClose(callback: () => void): void {
this.transport.onclose = callback
}

/**
* Get server configuration
*/
Expand Down
Loading