/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.profile.query;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.lucene.search.Query;
import org.opensearch.search.profile.ContextualProfileBreakdown;
import org.opensearch.search.profile.ProfileMetric;
import org.opensearch.search.profile.ProfileResult;
import org.opensearch.search.profile.Timer;
import org.opensearch.search.profile.query.AbstractQueryProfileTree;
import org.opensearch.search.profile.query.ConcurrentQueryProfileTree;
import org.opensearch.search.profile.query.QueryProfiler;

/**
 * A {@link QueryProfiler} that keeps a separate profile tree and a separate list of
 * rewrite timers for every thread that calls into it, keyed by thread id.
 *
 * <p>The thread that constructs the profiler is registered eagerly with the supplied
 * profile tree and an empty rewrite-timer list; any other thread is registered lazily the
 * first time it calls {@link #getQueryBreakdown(Query)} or {@link #startRewriteTime()}.
 */
public final class ConcurrentQueryProfiler extends QueryProfiler {
    /** Per-thread profile trees; a synchronized, insertion-ordered map. */
    private final Map<Long, ConcurrentQueryProfileTree> threadToProfileTree;
    /** Per-thread rewrite timers, in the order they were started on that thread. */
    private final Map<Long, LinkedList<Timer>> threadToRewriteTimers;
    /** Passed to every lazily created {@link ConcurrentQueryProfileTree}. */
    private final Function<Query, Collection<Supplier<ProfileMetric>>> customPluginMetrics;

    /**
     * Creates a profiler whose trees are built with a metrics function that returns an
     * empty list for every query.
     */
    public ConcurrentQueryProfiler() {
        this(new ConcurrentQueryProfileTree(query -> List.of()), query -> List.of());
    }

    /**
     * Creates a profiler and registers {@code profileTree} for the calling thread.
     *
     * @param profileTree         tree to associate with the constructing thread; it is cast to
     *                            {@link ConcurrentQueryProfileTree}, so any other implementation
     *                            fails with a {@link ClassCastException}
     * @param customPluginMetrics function handed to each profile tree created for other threads
     */
    public ConcurrentQueryProfiler(AbstractQueryProfileTree profileTree, Function<Query, Collection<Supplier<ProfileMetric>>> customPluginMetrics) {
        super(profileTree);
        long threadId = this.getCurrentThreadId();
        this.threadToProfileTree = Collections.synchronizedMap(new LinkedHashMap<>());
        this.threadToProfileTree.put(threadId, (ConcurrentQueryProfileTree) profileTree);
        this.threadToRewriteTimers = new ConcurrentHashMap<>();
        this.threadToRewriteTimers.put(threadId, new LinkedList<>());
        this.customPluginMetrics = customPluginMetrics;
    }

    /**
     * Returns the breakdown for {@code query} from the calling thread's profile tree,
     * creating that tree first if the thread has none yet.
     *
     * @param query the query to obtain a breakdown for
     * @return the breakdown produced by the calling thread's tree
     */
    @Override
    public ContextualProfileBreakdown getQueryBreakdown(Query query) {
        ConcurrentQueryProfileTree profileTree = this.threadToProfileTree.computeIfAbsent(
            this.getCurrentThreadId(),
            k -> new ConcurrentQueryProfileTree(this.customPluginMetrics)
        );
        return (ContextualProfileBreakdown) profileTree.getProfileBreakdown(query);
    }

    /**
     * Calls {@code pollLast()} on the calling thread's profile tree. Does nothing when the
     * calling thread has no tree registered.
     */
    @Override
    public void pollLastElement() {
        ConcurrentQueryProfileTree concurrentProfileTree = this.threadToProfileTree.get(this.getCurrentThreadId());
        if (concurrentProfileTree != null) {
            concurrentProfileTree.pollLast();
        }
    }

    /**
     * Concatenates the results of every registered thread's profile tree, in the order the
     * threads were first registered.
     *
     * @return a new list holding the results of all per-thread trees
     */
    @Override
    public List<ProfileResult> getTree() {
        // Iterating a Collections.synchronizedMap view is only safe while holding the map's
        // monitor. Take a snapshot under the lock so the per-tree work runs outside of it.
        List<ConcurrentQueryProfileTree> trees;
        synchronized (this.threadToProfileTree) {
            trees = new ArrayList<>(this.threadToProfileTree.values());
        }
        List<ProfileResult> profileResults = new ArrayList<>();
        for (ConcurrentQueryProfileTree tree : trees) {
            profileResults.addAll(tree.getTree());
        }
        return profileResults;
    }

    /**
     * Creates a new rewrite timer, appends it to the calling thread's timer list (creating
     * the list if needed) and starts it.
     */
    @Override
    public void startRewriteTime() {
        Timer rewriteTimer = new Timer("rewrite_timer");
        this.threadToRewriteTimers.computeIfAbsent(this.getCurrentThreadId(), k -> new LinkedList<>()).add(rewriteTimer);
        rewriteTimer.start();
    }

    /**
     * Stops the most recently added rewrite timer of the calling thread.
     *
     * @throws IllegalStateException if the calling thread has no rewrite timer, i.e.
     *                               {@link #startRewriteTime()} was never called on it
     */
    @Override
    public void stopAndAddRewriteTime() {
        long threadId = this.getCurrentThreadId();
        LinkedList<Timer> timers = this.threadToRewriteTimers.get(threadId);
        // The constructing thread starts with an empty list and other threads with no list
        // at all; fail with a descriptive error instead of a bare NPE/NoSuchElementException.
        if (timers == null || timers.isEmpty()) {
            throw new IllegalStateException(
                "stopAndAddRewriteTime() called on thread [" + threadId + "] without a preceding startRewriteTime()"
            );
        }
        timers.getLast().stop();
    }

    /**
     * Computes the total rewrite time over all threads, counting overlapping timer
     * intervals only once.
     *
     * @return the summed length of the merged rewrite intervals
     */
    @Override
    public long getRewriteTime() {
        // NOTE(review): the per-thread lists are plain LinkedLists; this presumably runs only
        // after all rewrites have finished — confirm against callers.
        List<Timer> rewriteTimers = new LinkedList<>();
        this.threadToRewriteTimers.values().forEach(rewriteTimers::addAll);
        long totalRewriteTime = 0L;
        for (long[] interval : this.mergeRewriteTimeIntervals(rewriteTimers)) {
            totalRewriteTime += interval[1] - interval[0];
        }
        return totalRewriteTime;
    }

    /**
     * Merges the intervals described by {@code timers} into a list of non-overlapping
     * intervals ordered by start time. Each timer contributes the interval
     * {@code [getEarliestTimerStartTime(), getEarliestTimerStartTime() + getApproximateTiming()]}.
     *
     * <p>Note: {@code timers} is sorted in place by start time.
     *
     * @param timers the timers to merge; must be a modifiable list
     * @return the merged intervals as {@code {start, end}} pairs
     */
    LinkedList<long[]> mergeRewriteTimeIntervals(List<Timer> timers) {
        LinkedList<long[]> mergedIntervals = new LinkedList<>();
        timers.sort(Comparator.comparingLong(Timer::getEarliestTimerStartTime));
        for (Timer timer : timers) {
            long startTime = timer.getEarliestTimerStartTime();
            long endTime = startTime + timer.getApproximateTiming();
            if (mergedIntervals.isEmpty() || mergedIntervals.getLast()[1] < startTime) {
                // Strictly after the previous interval: open a new one.
                mergedIntervals.add(new long[] { startTime, endTime });
            } else {
                // Touches or overlaps the previous interval: extend it if this one ends later.
                long[] last = mergedIntervals.getLast();
                last[1] = Math.max(last[1], endTime);
            }
        }
        return mergedIntervals;
    }

    /** Returns the id of the calling thread, used as the key of both per-thread maps. */
    private long getCurrentThreadId() {
        return Thread.currentThread().threadId();
    }
}

