/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.contextmanager;

import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.contextmanager.ActivationRule;
import org.opensearch.ml.common.contextmanager.ActivationRuleFactory;
import org.opensearch.ml.common.contextmanager.ContextManager;
import org.opensearch.ml.common.contextmanager.ContextManagerContext;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.contextmanager.ContextManagerUtils;
import org.opensearch.transport.client.Client;

public class SummarizationManager
implements ContextManager {
    @Generated
    private static final Logger log = LogManager.getLogger(SummarizationManager.class);
    public static final String TYPE = "SummarizationManager";
    private static final String SUMMARY_RATIO_KEY = "summary_ratio";
    private static final String PRESERVE_RECENT_MESSAGES_KEY = "preserve_recent_messages";
    private static final String SUMMARIZATION_MODEL_ID_KEY = "summarization_model_id";
    private static final String SUMMARIZATION_SYSTEM_PROMPT_KEY = "summarization_system_prompt";
    private static final double DEFAULT_SUMMARY_RATIO = 0.3;
    private static final int DEFAULT_PRESERVE_RECENT_MESSAGES = 10;
    private static final String DEFAULT_SUMMARIZATION_PROMPT = "You are a interactions summarization agent. Summarize the provided interactions concisely while preserving key information and context.";
    protected double summaryRatio;
    protected int preserveRecentMessages;
    protected String summarizationModelId;
    protected String summarizationSystemPrompt;
    protected List<ActivationRule> activationRules;
    private Client client;

    public SummarizationManager(Client client) {
        this.client = client;
    }

    public String getType() {
        return TYPE;
    }

    public void initialize(Map<String, Object> config) {
        this.summaryRatio = this.parseDoubleConfig(config, SUMMARY_RATIO_KEY, 0.3);
        this.preserveRecentMessages = this.parseIntegerConfig(config, PRESERVE_RECENT_MESSAGES_KEY, 10);
        this.summarizationModelId = (String)config.get(SUMMARIZATION_MODEL_ID_KEY);
        this.summarizationSystemPrompt = (String)config.getOrDefault(SUMMARIZATION_SYSTEM_PROMPT_KEY, DEFAULT_SUMMARIZATION_PROMPT);
        if (this.summaryRatio < 0.1 || this.summaryRatio > 0.8) {
            log.warn("Invalid summary_ratio value: {}, using default {}", (Object)this.summaryRatio, (Object)0.3);
            this.summaryRatio = 0.3;
        }
        Map activationConfig = (Map)config.get("activation");
        this.activationRules = ActivationRuleFactory.createRules((Map)activationConfig);
        log.info("Initialized SummarizationManager: summaryRatio={}, preserveRecentMessages={}", (Object)this.summaryRatio, (Object)this.preserveRecentMessages);
    }

    public boolean shouldActivate(ContextManagerContext context) {
        if (this.activationRules == null || this.activationRules.isEmpty()) {
            return true;
        }
        for (ActivationRule rule : this.activationRules) {
            if (rule.evaluate(context)) continue;
            log.debug("Activation rule not satisfied: {}", (Object)rule.getDescription());
            return false;
        }
        log.debug("All activation rules satisfied, manager will execute");
        return true;
    }

    public void execute(ContextManagerContext context) {
        Map parameters;
        List interactions = context.getToolInteractions();
        if (interactions == null || interactions.isEmpty()) {
            return;
        }
        if (interactions.isEmpty()) {
            log.debug("No string interactions found in tool interactions");
            return;
        }
        int totalMessages = interactions.size();
        int messagesToSummarizeCount = Math.max(1, (int)((double)totalMessages * this.summaryRatio));
        if ((messagesToSummarizeCount = Math.min(messagesToSummarizeCount, totalMessages - this.preserveRecentMessages)) <= 0) {
            return;
        }
        int safeCutPoint = ContextManagerUtils.findSafePoint(interactions, messagesToSummarizeCount, false);
        if (safeCutPoint <= 0) {
            return;
        }
        ArrayList messagesToSummarize = new ArrayList(interactions.subList(0, safeCutPoint));
        ArrayList<String> remainingMessages = new ArrayList<String>(interactions.subList(safeCutPoint, totalMessages));
        String modelId = this.summarizationModelId;
        if (modelId == null && (parameters = context.getParameters()) != null) {
            modelId = (String)parameters.get("_llm_model_id");
        }
        if (modelId == null) {
            log.error("No model ID available for summarization");
            return;
        }
        HashMap<String, String> summarizationParameters = new HashMap<String, String>();
        summarizationParameters.put("prompt", "Help summarize the following" + StringUtils.toJson((Object)String.join((CharSequence)",", messagesToSummarize)));
        summarizationParameters.put("system_prompt", this.summarizationSystemPrompt);
        this.executeSummarization(context, modelId, summarizationParameters, safeCutPoint, remainingMessages, interactions);
    }

    protected void executeSummarization(ContextManagerContext context, String modelId, Map<String, String> summarizationParameters, int messagesToSummarizeCount, List<String> remainingMessages, List<String> originalInteractions) {
        CountDownLatch latch = new CountDownLatch(1);
        AtomicBoolean timedOut = new AtomicBoolean(false);
        try {
            RemoteInferenceInputDataSet inputDataset = RemoteInferenceInputDataSet.builder().parameters(summarizationParameters).build();
            MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset((MLInputDataset)inputDataset).build();
            MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().modelId(modelId).mlInput(mlInput).build();
            ActionListener listener = ActionListener.wrap(response -> {
                try {
                    if (timedOut.get()) {
                        return;
                    }
                    String summary = this.extractSummaryFromResponse((MLTaskResponse)response, context);
                    if (summary != null) {
                        this.processSummarizationResult(context, summary, messagesToSummarizeCount, remainingMessages, originalInteractions);
                    } else {
                        log.warn("Summary extraction failed, keeping original interactions");
                    }
                }
                catch (Exception e) {
                    log.warn("Summarization failed, keeping original interactions: {}", (Object)e.getMessage());
                }
                finally {
                    latch.countDown();
                }
            }, e -> {
                if (!timedOut.get()) {
                    log.warn("Summarization request failed, keeping original interactions: {}", (Object)e.getMessage());
                }
                latch.countDown();
            });
            this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)request, listener);
            boolean finished = latch.await(30L, TimeUnit.SECONDS);
            if (!finished) {
                timedOut.set(true);
                log.warn("Summarization timed out after 30s; skipping late results");
            }
        }
        catch (Exception e2) {
            log.warn("Summarization setup failed, keeping original interactions: {}", (Object)e2.getMessage());
        }
    }

    protected void processSummarizationResult(ContextManagerContext context, String summary, int messagesToSummarizeCount, List<String> remainingMessages, List<String> originalInteractions) {
        try {
            String summarizedInteraction = "{\"role\":\"assistant\",\"content\":[{\"type\": \"text\", \"text\": \"Summarized previous interactions: " + StringUtils.processTextDoc((String)summary) + "\"}]}";
            ArrayList<Object> updatedInteractions = new ArrayList<Object>();
            updatedInteractions.add(summarizedInteraction);
            updatedInteractions.addAll(remainingMessages);
            context.setToolInteractions(updatedInteractions);
            HashMap<String, CallSite> parameters = context.getParameters();
            if (parameters == null) {
                parameters = new HashMap<String, CallSite>();
            }
            parameters.put("_interactions", (CallSite)((Object)(", " + String.join((CharSequence)", ", updatedInteractions))));
            context.setParameters(parameters);
            log.info("Summarization completed: {} messages summarized, {} messages preserved", (Object)messagesToSummarizeCount, (Object)remainingMessages.size());
        }
        catch (Exception e) {
            log.error("Failed to process summarization result", (Throwable)e);
        }
    }

    /*
     * Exception decompiling
     */
    private String extractSummaryFromResponse(MLTaskResponse response, ContextManagerContext context) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [3[CATCHBLOCK]], but top level block is 2[TRYBLOCK]
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    private double parseDoubleConfig(Map<String, Object> config, String key, double defaultValue) {
        Object value = config.get(key);
        if (value == null) {
            return defaultValue;
        }
        try {
            if (value instanceof Double) {
                return (Double)value;
            }
            if (value instanceof Number) {
                return ((Number)value).doubleValue();
            }
            if (value instanceof String) {
                return Double.parseDouble((String)value);
            }
            log.warn("Invalid type for config key '{}': {}, using default {}", (Object)key, (Object)value.getClass().getSimpleName(), (Object)defaultValue);
            return defaultValue;
        }
        catch (NumberFormatException e) {
            log.warn("Invalid double value for config key '{}': {}, using default {}", (Object)key, value, (Object)defaultValue);
            return defaultValue;
        }
    }

    private int parseIntegerConfig(Map<String, Object> config, String key, int defaultValue) {
        Object value = config.get(key);
        if (value == null) {
            return defaultValue;
        }
        try {
            if (value instanceof Integer) {
                return (Integer)value;
            }
            if (value instanceof Number) {
                return ((Number)value).intValue();
            }
            if (value instanceof String) {
                return Integer.parseInt((String)value);
            }
            log.warn("Invalid type for config key '{}': {}, using default {}", (Object)key, (Object)value.getClass().getSimpleName(), (Object)defaultValue);
            return defaultValue;
        }
        catch (NumberFormatException e) {
            log.warn("Invalid integer value for config key '{}': {}, using default {}", (Object)key, value, (Object)defaultValue);
            return defaultValue;
        }
    }
}

