diff --git a/src/frontends/tui/components/App.tsx b/src/frontends/tui/components/App.tsx index d4dc49f..9f0e0a9 100644 --- a/src/frontends/tui/components/App.tsx +++ b/src/frontends/tui/components/App.tsx @@ -191,6 +191,10 @@ export function App({ } const tier = resolveModelAlias(command.name); if (modelRouter.setTier(tier)) { + // Also update the agent tier so chatWithRouter uses the correct client + if (agent) { + agent.setModelTier(tier); + } setCurrentModel(tier); const successMsg: Message = { role: 'assistant', content: `Switched to model: ${tier}` }; const successWithTs = session.addMessage(successMsg); diff --git a/src/gateway/session-bridge.ts b/src/gateway/session-bridge.ts index f61ffec..9e7c021 100644 --- a/src/gateway/session-bridge.ts +++ b/src/gateway/session-bridge.ts @@ -2,7 +2,7 @@ import { randomUUID } from 'crypto'; import type { SessionManager } from '../session/manager.js'; import type { Session } from '../session/manager.js'; import type { ModelClient } from '../models/types.js'; -import type { ModelRouter } from '../models/router.js'; +import type { ModelRouter, ModelTier } from '../models/router.js'; import type { ToolRegistry } from '../tools/registry.js'; import type { ToolExecutor } from '../tools/executor.js'; import { NativeAgent } from '../backends/native/agent.js'; @@ -27,9 +27,27 @@ export class SessionBridge { private clients: Map = new Map(); private agents: Map = new Map(); private config: SessionBridgeConfig; + /** Tracks the current model tier so new agents inherit it and existing agents stay in sync. */ + private currentTier: ModelTier = 'default'; constructor(config: SessionBridgeConfig) { this.config = config; + + // If the model client is a ModelRouter, subscribe to tier changes + // so all WebChat agents stay in sync with TUI model switches. + if ('getClient' in config.modelClient) { + const router = config.modelClient as ModelRouter; + this.currentTier = router.getTier(); + router.addOnTierChange((tier: ModelTier) => this.onTierChanged(tier)); + } + } + + /** Called when the ModelRouter's active tier changes. Updates all existing agents. */ + private onTierChanged(tier: ModelTier): void { + this.currentTier = tier; + for (const agent of this.agents.values()) { + agent.setModelTier(tier); + } } /** Register a new WS connection. Returns the assigned connection ID. */ @@ -172,6 +190,8 @@ export class SessionBridge { toolRegistry: this.config.toolRegistry, toolExecutor: this.config.toolExecutor, }); + // Inherit the current model tier so the agent uses the same model as the TUI + agent.setModelTier(this.currentTier); this.agents.set(sessionId, agent); } return agent; diff --git a/src/models/router.ts b/src/models/router.ts index 9be6c0d..392c93b 100644 --- a/src/models/router.ts +++ b/src/models/router.ts @@ -26,7 +26,7 @@ export class ModelRouter implements ModelClient { private currentTier: ModelTier = 'default'; private localProviderName?: string; private retryConfig?: RetryConfig; - private onTierChange?: (tier: ModelTier) => void; + private tierChangeListeners: Array<(tier: ModelTier) => void> = []; constructor(config: ModelRouterConfig) { this.clients = new Map(); @@ -35,7 +35,10 @@ export class ModelRouter implements ModelClient { this.fallbackChain = config.fallbackChain; this.tierFallbacks = config.tierFallbacks ?? new Map(); this.retryConfig = config.retryConfig; - this.onTierChange = config.onTierChange; + + if (config.onTierChange) { + this.tierChangeListeners.push(config.onTierChange); + } this.clients.set('default', config.default); if (config.fast) this.clients.set('fast', config.fast); @@ -54,14 +57,20 @@ export class ModelRouter implements ModelClient { setTier(tier: ModelTier): boolean { if (this.clients.has(tier)) { this.currentTier = tier; - this.onTierChange?.(tier); + for (const listener of this.tierChangeListeners) { + listener(tier); + } return true; } return false; } setOnTierChange(callback: (tier: ModelTier) => void): void { - this.onTierChange = callback; + this.tierChangeListeners = [callback]; + } + + addOnTierChange(callback: (tier: ModelTier) => void): void { + this.tierChangeListeners.push(callback); } getTier(): ModelTier {