/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import type * as vscode from 'vscode';
import { raceCancellation } from '../../../base/common/async.js';
import { CancellationToken } from '../../../base/common/cancellation.js';
import { CancellationError } from '../../../base/common/errors.js';
import { IDisposable, toDisposable } from '../../../base/common/lifecycle.js';
import { revive } from '../../../base/common/marshalling.js';
import { generateUuid } from '../../../base/common/uuid.js';
import { IExtensionDescription } from '../../../platform/extensions/common/extensions.js';
import { IPreparedToolInvocation, IStreamedToolInvocation, isToolInvocationContext, IToolInvocation, IToolInvocationContext, IToolInvocationPreparationContext, IToolInvocationStreamContext, IToolResult, ToolInvocationPresentation } from '../../contrib/chat/common/tools/languageModelToolsService.js';
import { ExtensionEditToolId, InternalEditToolId } from '../../contrib/chat/common/tools/builtinTools/editFileTool.js';
import { InternalFetchWebPageToolId } from '../../contrib/chat/common/tools/builtinTools/tools.js';
import { SearchExtensionsToolId } from '../../contrib/extensions/common/searchExtensionsTool.js';
import { checkProposedApiEnabled, isProposedApiEnabled } from '../../services/extensions/common/extensions.js';
import { Dto, SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
import { ExtHostLanguageModelToolsShape, IMainContext, IToolDataDto, IToolDefinitionDto, MainContext, MainThreadLanguageModelToolsShape } from './extHost.protocol.js';
import { ExtHostLanguageModels } from './extHostLanguageModels.js';
import * as typeConvert from './extHostTypeConverters.js';
import { URI } from '../../../base/common/uri.js';

/**
 * Wraps the DTO that describes one language model tool and lazily builds the
 * frozen API objects that are derived from it. Both derived objects are cached
 * until the DTO is replaced via {@link Tool.update}.
 */
class Tool {

	private _data: IToolDataDto;
	private _apiObject: vscode.LanguageModelToolInformation | undefined;
	private _apiObjectWithChatParticipantAdditions: vscode.LanguageModelToolInformation | undefined;

	constructor(data: IToolDataDto) {
		this._data = data;
	}

	/** Replaces the tool data and drops both cached API objects so they are rebuilt on next access. */
	update(newData: IToolDataDto): void {
		this._apiObject = undefined;
		this._apiObjectWithChatParticipantAdditions = undefined;
		this._data = newData;
	}

	/** The DTO currently backing this tool. */
	get data(): IToolDataDto {
		return this._data;
	}

	/** Frozen tool information whose `source` is always `undefined`. */
	get apiObject(): vscode.LanguageModelToolInformation {
		let info = this._apiObject;
		if (info === undefined) {
			const { id, modelDescription, inputSchema, tags } = this._data;
			info = Object.freeze({
				name: id,
				description: modelDescription,
				inputSchema,
				tags: tags ?? [],
				source: undefined
			});
			this._apiObject = info;
		}
		return info;
	}

	/** Frozen tool information that additionally carries the converted tool `source`. */
	get apiObjectWithChatParticipantAdditions(): vscode.LanguageModelToolInformation {
		let info = this._apiObjectWithChatParticipantAdditions;
		if (info === undefined) {
			const { id, modelDescription, inputSchema, tags, source } = this._data;
			info = Object.freeze({
				name: id,
				description: modelDescription,
				inputSchema,
				tags: tags ?? [],
				source: typeConvert.LanguageModelToolSource.to(source)
			});
			this._apiObjectWithChatParticipantAdditions = info;
		}
		return info;
	}
}

/**
 * Extension host side of the language model tools feature.
 *
 * Keeps the tools that extensions registered in this extension host, mirrors the
 * list of all tools announced by the main thread, and relays tool invocations in
 * both directions through the `MainThreadLanguageModelTools` proxy.
 */
export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape {
	/** A map of tools that were registered in this EH */
	private readonly _registeredTools = new Map<string, { extension: IExtensionDescription; tool: vscode.LanguageModelTool<Object> }>();
	private readonly _proxy: MainThreadLanguageModelToolsShape;
	/** Token counting callbacks of in-flight `invokeTool` calls, removed again once the call settles */
	private readonly _tokenCountFuncs = new Map</* call ID */string, (text: string, token?: vscode.CancellationToken) => Thenable<number>>();

	/** A map of all known tools, from other EHs or registered in vscode core */
	private readonly _allTools = new Map<string, Tool>();

	constructor(
		mainContext: IMainContext,
		private readonly _languageModels: ExtHostLanguageModels,
	) {
		this._proxy = mainContext.getProxy(MainContext.MainThreadLanguageModelTools);

		// Seed the mirror of known tools; later changes arrive through $onDidChangeTools.
		this._proxy.$getTools().then(tools => {
			for (const tool of tools) {
				this._allTools.set(tool.id, new Tool(revive(tool)));
			}
		});
	}

	/**
	 * Counts tokens on behalf of an in-flight `invokeTool` call, using the
	 * `countTokens` callback that the caller of `invokeTool` supplied.
	 *
	 * @throws If no callback is registered for `callId`.
	 */
	async $countTokensForInvocation(callId: string, input: string, token: CancellationToken): Promise<number> {
		const fn = this._tokenCountFuncs.get(callId);
		if (!fn) {
			throw new Error(`Tool invocation call ${callId} not found`);
		}

		return await fn(input, token);
	}

	/**
	 * Invokes a tool by id (or tool information object) on behalf of `extension`.
	 * The call always goes through the main thread because the tool may not be
	 * registered in this extension host.
	 *
	 * @throws If `options.toolInvocationToken` is set but is not a tool invocation
	 * context, or if an edit tool is requested without the `chatParticipantPrivate`
	 * proposal being enabled for the extension.
	 */
	async invokeTool(extension: IExtensionDescription, toolIdOrInfo: string | vscode.LanguageModelToolInformation, options: vscode.LanguageModelToolInvocationOptions<any>, token?: CancellationToken): Promise<vscode.LanguageModelToolResult> {
		const toolId = typeof toolIdOrInfo === 'string' ? toolIdOrInfo : toolIdOrInfo.name;
		const callId = generateUuid();
		if (options.tokenizationOptions) {
			// Remembered so that $countTokensForInvocation can call back into the caller while the tool runs
			this._tokenCountFuncs.set(callId, options.tokenizationOptions.countTokens);
		}

		try {
			if (options.toolInvocationToken && !isToolInvocationContext(options.toolInvocationToken)) {
				throw new Error(`Invalid tool invocation token`);
			}

			const hasPrivateApi = isProposedApiEnabled(extension, 'chatParticipantPrivate');
			if ((toolId === InternalEditToolId || toolId === ExtensionEditToolId) && !hasPrivateApi) {
				throw new Error(`Invalid tool: ${toolId}`);
			}

			// Making the round trip here because not all tools were necessarily registered in this EH
			const result = await this._proxy.$invokeTool({
				toolId,
				callId,
				parameters: options.input,
				tokenBudget: options.tokenizationOptions?.tokenBudget,
				context: options.toolInvocationToken as IToolInvocationContext | undefined,
				// Options that belong to proposed API are only forwarded when the proposal is enabled
				chatRequestId: hasPrivateApi ? options.chatRequestId : undefined,
				chatInteractionId: hasPrivateApi ? options.chatInteractionId : undefined,
				subAgentInvocationId: hasPrivateApi ? options.subAgentInvocationId : undefined,
				chatStreamToolCallId: isProposedApiEnabled(extension, 'chatParticipantAdditions') ? options.chatStreamToolCallId : undefined,
				preToolUseResult: hasPrivateApi ? options.preToolUseResult : undefined,
			}, token);

			const dto: Dto<IToolResult> = result instanceof SerializableObjectWithBuffers ? result.value : result;
			return typeConvert.LanguageModelToolResult.to(revive(dto));
		} finally {
			this._tokenCountFuncs.delete(callId);
		}
	}

	/**
	 * Replaces the mirror of known tools with `tools`: known ids are updated in
	 * place, unknown ids are added, and ids missing from `tools` are removed.
	 */
	$onDidChangeTools(tools: IToolDataDto[]): void {

		const oldTools = new Set(this._allTools.keys());

		for (const tool of tools) {
			oldTools.delete(tool.id);
			const existing = this._allTools.get(tool.id);
			if (existing) {
				// Revive here as well so that updated and newly added tools hold the same kind of data
				existing.update(revive(tool));
			} else {
				this._allTools.set(tool.id, new Tool(revive(tool)));
			}
		}

		// Whatever is left was not part of the new list
		for (const id of oldTools) {
			this._allTools.delete(id);
		}
	}

	/**
	 * Returns the tool information objects visible to `extension`. Extensions
	 * without the `chatParticipantPrivate` proposal get objects without a `source`
	 * and do not see the internal edit, fetch-web-page and search-extensions tools.
	 */
	getTools(extension: IExtensionDescription): vscode.LanguageModelToolInformation[] {
		// The same proposal gates both the richer API object and the internal tools below
		const hasPrivateApi = isProposedApiEnabled(extension, 'chatParticipantPrivate');
		return Array.from(this._allTools.values())
			.map(tool => hasPrivateApi ? tool.apiObjectWithChatParticipantAdditions : tool.apiObject)
			.filter(tool => {
				switch (tool.name) {
					case InternalEditToolId:
					case ExtensionEditToolId:
					case InternalFetchWebPageToolId:
					case SearchExtensionsToolId:
						return hasPrivateApi;
					default:
						return true;
				}
			});
	}

	/**
	 * Runs a tool that was registered in this extension host.
	 *
	 * @throws If the tool is unknown. Throws a `CancellationError` when the raced
	 * invocation yields no result.
	 */
	async $invokeTool(dto: Dto<IToolInvocation>, token: CancellationToken): Promise<Dto<IToolResult> | SerializableObjectWithBuffers<Dto<IToolResult>>> {
		const item = this._registeredTools.get(dto.toolId);
		if (!item) {
			throw new Error(`Unknown tool ${dto.toolId}`);
		}

		const options: vscode.LanguageModelToolInvocationOptions<Object> = {
			input: dto.parameters,
			toolInvocationToken: revive(dto.context) as unknown as vscode.ChatParticipantToolToken | undefined,
		};
		if (isProposedApiEnabled(item.extension, 'chatParticipantPrivate')) {
			options.chatRequestId = dto.chatRequestId;
			options.chatInteractionId = dto.chatInteractionId;
			options.chatSessionId = dto.context?.sessionId;
			options.chatSessionResource = URI.revive(dto.context?.sessionResource);
			options.subAgentInvocationId = dto.subAgentInvocationId;
		}

		const hasParticipantAdditions = isProposedApiEnabled(item.extension, 'chatParticipantAdditions');
		if (hasParticipantAdditions && dto.modelId) {
			options.model = await this.getModel(dto.modelId, item.extension);
		}
		if (hasParticipantAdditions && dto.chatStreamToolCallId) {
			options.chatStreamToolCallId = dto.chatStreamToolCallId;
		}

		if (dto.tokenBudget !== undefined) {
			options.tokenizationOptions = {
				tokenBudget: dto.tokenBudget,
				// Use the caller's function directly when the call originated in this EH, otherwise ask the main thread
				countTokens: this._tokenCountFuncs.get(dto.callId) || ((value, token = CancellationToken.None) =>
					this._proxy.$countTokensForInvocation(dto.callId, value, token))
			};
		}

		let progress: vscode.Progress<{ message?: string | vscode.MarkdownString; increment?: number }> | undefined;
		if (isProposedApiEnabled(item.extension, 'toolProgress')) {
			// Sum of all reported increments; stays undefined until the first increment is reported
			let lastProgress: number | undefined;
			progress = {
				report: value => {
					if (value.increment !== undefined) {
						lastProgress = (lastProgress ?? 0) + value.increment;
					}

					this._proxy.$acceptToolProgress(dto.callId, {
						message: typeConvert.MarkdownString.fromStrict(value.message),
						progress: lastProgress === undefined ? undefined : lastProgress / 100,
					});
				}
			};
		}

		// todo: 'any' cast because TS can't handle the overloads
		// eslint-disable-next-line local/code-no-any-casts
		const extensionResult = await raceCancellation(Promise.resolve((item.tool.invoke as any)(options, token, progress!)), token);
		if (!extensionResult) {
			throw new CancellationError();
		}

		return typeConvert.LanguageModelToolResult.from(extensionResult, item.extension);
	}

	/**
	 * Resolves the chat model for `modelId`, falling back to the default language
	 * model when the id is empty or cannot be resolved.
	 *
	 * @throws If neither lookup yields a model.
	 */
	private async getModel(modelId: string, extension: IExtensionDescription): Promise<vscode.LanguageModelChat> {
		let model: vscode.LanguageModelChat | undefined;
		if (modelId) {
			model = await this._languageModels.getLanguageModelByIdentifier(extension, modelId);
		}
		if (!model) {
			model = await this._languageModels.getDefaultLanguageModel(extension);
			if (!model) {
				throw new Error('Language model unavailable');
			}
		}

		return model;
	}

	/**
	 * Forwards a tool stream update to the tool's optional `handleToolStream`.
	 *
	 * @returns The converted invocation message, or `undefined` when the tool does
	 * not implement `handleToolStream` or returns nothing.
	 * @throws If the tool is unknown, or if it implements `handleToolStream`
	 * without the `chatParticipantAdditions` proposal being enabled.
	 */
	async $handleToolStream(toolId: string, context: IToolInvocationStreamContext, token: CancellationToken): Promise<IStreamedToolInvocation | undefined> {
		const item = this._registeredTools.get(toolId);
		if (!item) {
			throw new Error(`Unknown tool ${toolId}`);
		}

		// Only call handleToolStream if it's defined on the tool
		if (!item.tool.handleToolStream) {
			return undefined;
		}

		// Ensure the chatParticipantAdditions API is enabled
		checkProposedApiEnabled(item.extension, 'chatParticipantAdditions');

		const options: vscode.LanguageModelToolInvocationStreamOptions<any> = {
			rawInput: context.rawInput,
			chatRequestId: context.chatRequestId,
			chatSessionId: context.chatSessionId,
			// Revived like in $invokeTool so the tool gets a URI instance even if the context carries the serialized form
			chatSessionResource: URI.revive(context.chatSessionResource),
			chatInteractionId: context.chatInteractionId
		};

		const result = await item.tool.handleToolStream(options, token);
		if (!result) {
			return undefined;
		}

		return {
			invocationMessage: typeConvert.MarkdownString.fromStrict(result.invocationMessage)
		};
	}

	/**
	 * Asks the tool's optional `prepareInvocation` for confirmation and progress
	 * messages and converts them for the main thread.
	 *
	 * @returns The prepared invocation, or `undefined` when the tool does not
	 * implement `prepareInvocation` or returns nothing.
	 * @throws If the tool is unknown, or if `forceConfirmationReason`,
	 * `pastTenseMessage` or `presentation` is used without the
	 * `chatParticipantPrivate` proposal being enabled.
	 */
	async $prepareToolInvocation(toolId: string, context: IToolInvocationPreparationContext, token: CancellationToken): Promise<IPreparedToolInvocation | undefined> {
		const item = this._registeredTools.get(toolId);
		if (!item) {
			throw new Error(`Unknown tool ${toolId}`);
		}

		const options: vscode.LanguageModelToolInvocationPrepareOptions<any> = {
			input: context.parameters,
			chatRequestId: context.chatRequestId,
			chatSessionId: context.chatSessionId,
			// Revived like in $invokeTool so the tool gets a URI instance even if the context carries the serialized form
			chatSessionResource: URI.revive(context.chatSessionResource),
			chatInteractionId: context.chatInteractionId,
			forceConfirmationReason: context.forceConfirmationReason
		};
		if (context.forceConfirmationReason) {
			checkProposedApiEnabled(item.extension, 'chatParticipantPrivate');
		}

		if (!item.tool.prepareInvocation) {
			return undefined;
		}

		const result = await item.tool.prepareInvocation(options, token);
		if (!result) {
			return undefined;
		}

		if (result.pastTenseMessage || result.presentation) {
			checkProposedApiEnabled(item.extension, 'chatParticipantPrivate');
		}

		return {
			confirmationMessages: result.confirmationMessages ? {
				title: typeof result.confirmationMessages.title === 'string' ? result.confirmationMessages.title : typeConvert.MarkdownString.from(result.confirmationMessages.title),
				message: typeof result.confirmationMessages.message === 'string' ? result.confirmationMessages.message : typeConvert.MarkdownString.from(result.confirmationMessages.message),
			} : undefined,
			invocationMessage: typeConvert.MarkdownString.fromStrict(result.invocationMessage),
			pastTenseMessage: typeConvert.MarkdownString.fromStrict(result.pastTenseMessage),
			presentation: result.presentation as ToolInvocationPresentation | undefined,
		};
	}

	/**
	 * Registers the implementation of a tool under `id` and announces it to the
	 * main thread, including whether the tool implements `handleToolStream`.
	 *
	 * @returns A disposable that removes the local registration and unregisters
	 * the tool on the main thread.
	 */
	registerTool(extension: IExtensionDescription, id: string, tool: vscode.LanguageModelTool<any>): IDisposable {
		this._registeredTools.set(id, { extension, tool });
		this._proxy.$registerTool(id, typeof tool.handleToolStream === 'function');

		return toDisposable(() => {
			this._registeredTools.delete(id);
			this._proxy.$unregisterTool(id);
		});
	}

	/**
	 * Registers a tool together with its definition; the tool id is
	 * `definition.name`. Requires the `languageModelToolSupportsModel` proposal.
	 *
	 * @returns A disposable that removes the local registration and unregisters
	 * the tool on the main thread.
	 * @throws If the proposal is not enabled for `extension`.
	 */
	registerToolDefinition(extension: IExtensionDescription, definition: vscode.LanguageModelToolDefinition, tool: vscode.LanguageModelTool<any>): IDisposable {
		checkProposedApiEnabled(extension, 'languageModelToolSupportsModel');

		const id = definition.name;

		// Convert the definition to a DTO
		const dto: IToolDefinitionDto = {
			id,
			displayName: definition.displayName,
			toolReferenceName: definition.toolReferenceName,
			userDescription: definition.userDescription,
			modelDescription: definition.description,
			inputSchema: definition.inputSchema as object,
			source: {
				type: 'extension',
				label: extension.displayName ?? extension.name,
				extensionId: extension.identifier,
			},
			icon: typeConvert.IconPath.from(definition.icon),
			models: definition.models,
			toolSet: definition.toolSet,
		};

		this._registeredTools.set(id, { extension, tool });
		this._proxy.$registerToolWithDefinition(extension.identifier, dto, typeof tool.handleToolStream === 'function');

		return toDisposable(() => {
			this._registeredTools.delete(id);
			this._proxy.$unregisterTool(id);
		});
	}
}
