/*
 * Decompiled with CFR 0.152.
 */
package org.graylog.shaded.opensearch2.org.apache.lucene.search.join;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.FloatVectorValues;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.LeafReaderContext;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.QueryTimeout;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.DocIdSetIterator;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.HitQueue;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.IndexSearcher;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.KnnCollector;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.KnnFloatVectorQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Query;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.ScoreDoc;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopDocs;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopDocsCollector;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TotalHits;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.VectorScorer;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.join.BitSetProducer;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.join.DiversifyingNearestChildrenKnnCollectorManager;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.knn.KnnCollectorManager;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.BitSet;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.Bits;

/**
 * k-nearest-neighbour query over a float vector field of child documents that diversifies hits
 * by parent document.
 *
 * <p>Exact search keeps at most one hit per parent: the best-scoring accepted child. Approximate
 * search delegates to {@code DiversifyingNearestChildrenKnnCollectorManager}, which presumably
 * enforces the same one-child-per-parent rule (NOTE(review): confirm against that class).
 * Parents are marked by the bit set produced by {@code parentsFilter}. A child is attributed to
 * the first parent bit set at or after its own doc id.
 */
public class DiversifyingChildrenFloatKnnVectorQuery
extends KnnFloatVectorQuery {
    /** Shared empty result returned whenever a segment cannot contribute any hits. */
    private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
    private final BitSetProducer parentsFilter;
    private final Query childFilter;
    private final int k;
    private final float[] query;

    /**
     * Creates a parent-diversified kNN query over child documents.
     *
     * @param field the float vector field to search
     * @param query the query vector
     * @param childFilter filter restricting which child documents are eligible; forwarded to the
     *     superclass as its filter (compared null-safely in {@code equals}/{@code hashCode})
     * @param k the maximum number of hits (one per parent) to return
     * @param parentsFilter produces, per segment, the bit set that marks parent documents
     */
    public DiversifyingChildrenFloatKnnVectorQuery(String field, float[] query, Query childFilter, int k, BitSetProducer parentsFilter) {
        super(field, query, k, childFilter);
        this.childFilter = childFilter;
        this.parentsFilter = parentsFilter;
        this.k = k;
        this.query = query;
    }

    /**
     * Brute-force scores every accepted child in the segment and keeps the top {@code k} parents,
     * each represented by its highest-scoring accepted child.
     *
     * @param context the segment to search
     * @param acceptIterator iterator over the child doc ids eligible for scoring; its
     *     {@code cost()} bounds the queue size and is reported as the total hit count
     * @param queryTimeout optional timeout (may be null); when it fires, scoring stops early and
     *     the total-hits relation becomes {@code GREATER_THAN_OR_EQUAL_TO}
     * @return the collected hits, best first. The result is empty when the segment has no vectors
     *     for the field, no parent bit set, or no scorer for the query vector.
     * @throws IOException if reading vectors or parent bits fails
     */
    @Override
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException {
        FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(this.field);
        if (floatVectorValues == null) {
            // Throws if the field exists but is not a float vector field; otherwise nothing to do.
            FloatVectorValues.checkField(context.reader(), this.field);
            return NO_RESULTS;
        }
        BitSet parentBitSet = this.parentsFilter.getBitSet(context);
        if (parentBitSet == null) {
            return NO_RESULTS;
        }
        VectorScorer floatVectorScorer = floatVectorValues.scorer(this.query);
        if (floatVectorScorer == null) {
            return NO_RESULTS;
        }
        DiversifyingChildrenVectorScorer vectorScorer = new DiversifyingChildrenVectorScorer(acceptIterator, parentBitSet, floatVectorScorer);
        int queueSize = Math.min(this.k, Math.toIntExact(acceptIterator.cost()));
        // prePopulate=true fills the queue with sentinel entries, so top() is always a
        // replaceable slot and no size check is needed while collecting.
        HitQueue queue = new HitQueue(queueSize, true);
        TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
        ScoreDoc topDoc = queue.top();
        while (vectorScorer.nextParent() != Integer.MAX_VALUE) {
            if (queryTimeout != null && queryTimeout.shouldExit()) {
                // Partial result: the reported hit count is only a lower bound.
                relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
                break;
            }
            float score = vectorScorer.score();
            // Strict '>' also rejects NaN scores.
            if (score > topDoc.score) {
                topDoc.score = score;
                topDoc.doc = vectorScorer.bestChild();
                topDoc = queue.updateTop();
            }
        }
        // Drop sentinels that were never replaced. This assumes sentinels carry a negative score
        // and that real similarity scores are non-negative; genuine negative scores would also be
        // removed here.
        while (queue.size() > 0 && queue.top().score < 0.0f) {
            queue.pop();
        }
        // pop() yields the least competitive hit first, so fill from the end to get best-first order.
        ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
        for (int i = topScoreDocs.length - 1; i >= 0; --i) {
            topScoreDocs[i] = queue.pop();
        }
        TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
        return new TopDocs(totalHits, topScoreDocs);
    }

    /**
     * Returns a collector manager that diversifies approximate-search hits by parent, using
     * {@code parentsFilter}.
     *
     * @param k the number of hits to collect
     * @param searcher the searcher executing the query
     * @return a {@code DiversifyingNearestChildrenKnnCollectorManager} bound to the parents filter
     */
    @Override
    protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return new DiversifyingNearestChildrenKnnCollectorManager(k, this.parentsFilter, searcher);
    }

    /**
     * Runs the reader's approximate nearest-vector search using a collector obtained from
     * {@code knnCollectorManager}.
     *
     * @param context the segment to search
     * @param acceptDocs documents allowed to match
     * @param visitedLimit limit passed to the collector on the number of visited nodes
     * @param knnCollectorManager source of the per-segment collector
     * @return the collector's top docs, or an empty result when no collector is produced
     * @throws IOException if the vector search fails
     */
    @Override
    protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException {
        FloatVectorValues.checkField(context.reader(), this.field);
        KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
        if (collector == null) {
            return NO_RESULTS;
        }
        context.reader().searchNearestVectors(this.field, this.query, collector, acceptDocs);
        return collector.topDocs();
    }

    /**
     * Renders the query as {@code SimpleName:field[firstComponent,...][k]}.
     *
     * <p>A zero-length query vector is rendered as {@code SimpleName:field[][k]} instead of
     * throwing {@code ArrayIndexOutOfBoundsException}.
     *
     * @param field ignored; the query's own field is always printed
     * @return a short human-readable description of this query
     */
    @Override
    public String toString(String field) {
        String vectorPreview = this.query.length == 0 ? "" : this.query[0] + ",...";
        return this.getClass().getSimpleName() + ":" + this.field + "[" + vectorPreview + "][" + this.k + "]";
    }

    /**
     * Two queries are equal when the superclass considers them equal and they share {@code k},
     * parents filter, child filter and query vector contents.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        DiversifyingChildrenFloatKnnVectorQuery that = (DiversifyingChildrenFloatKnnVectorQuery)o;
        return this.k == that.k && Objects.equals(this.parentsFilter, that.parentsFilter) && Objects.equals(this.childFilter, that.childFilter) && Arrays.equals(this.query, that.query);
    }

    /** Hash consistent with {@link #equals(Object)}; the vector is hashed by contents. */
    @Override
    public int hashCode() {
        int result = Objects.hash(super.hashCode(), this.parentsFilter, this.childFilter, this.k);
        result = 31 * result + Arrays.hashCode(this.query);
        return result;
    }

    /**
     * Iterates parents that have at least one accepted child. For each such parent it scores all
     * of that parent's accepted children and remembers the best child and its score.
     */
    static class DiversifyingChildrenVectorScorer {
        private final VectorScorer vectorScorer;
        private final DocIdSetIterator vectorIterator;
        private final DocIdSetIterator acceptedChildrenIterator;
        private final BitSet parentBitSet;
        private int currentParent = -1;
        private int bestChild = -1;
        private float currentScore = Float.NEGATIVE_INFINITY;

        /**
         * @param acceptedChildrenIterator child doc ids eligible for scoring, in increasing order
         * @param parentBitSet bits marking parent documents in the segment
         * @param vectorScorer scorer for the query vector; its iterator is advanced to each child
         */
        protected DiversifyingChildrenVectorScorer(DocIdSetIterator acceptedChildrenIterator, BitSet parentBitSet, VectorScorer vectorScorer) {
            this.acceptedChildrenIterator = acceptedChildrenIterator;
            this.vectorScorer = vectorScorer;
            this.vectorIterator = vectorScorer.iterator();
            this.parentBitSet = parentBitSet;
        }

        /** @return the doc id of the best-scoring child of the current parent, or -1 before any. */
        public int bestChild() {
            return this.bestChild;
        }

        /**
         * Moves to the parent of the next accepted child and scores all of that parent's accepted
         * children, recording the best one.
         *
         * <p>A child's parent is the first parent bit set at or after the child's doc id, so
         * children are expected to precede their parent. This block-join layout is assumed, not
         * checked here.
         *
         * @return the current parent's doc id, or {@code Integer.MAX_VALUE} (NO_MORE_DOCS) when no
         *     accepted children remain. A child with no parent after it also yields
         *     {@code Integer.MAX_VALUE}, which ends iteration in the caller.
         * @throws IOException if advancing or scoring fails
         */
        public int nextParent() throws IOException {
            // The previous call leaves the iterator on the first child of the next parent, so reuse
            // that position. Advance only if the iterator has not been started yet (docID == -1).
            int nextChild = this.acceptedChildrenIterator.docID();
            if (nextChild == -1) {
                nextChild = this.acceptedChildrenIterator.nextDoc();
            }
            if (nextChild == Integer.MAX_VALUE) {
                this.currentParent = Integer.MAX_VALUE;
                return this.currentParent;
            }
            this.currentScore = Float.NEGATIVE_INFINITY;
            this.currentParent = this.parentBitSet.nextSetBit(nextChild);
            do {
                // Position the vector iterator on this child so score() evaluates it.
                // NOTE(review): assumes every accepted child has a vector; advance() would
                // otherwise land on a later doc.
                this.vectorIterator.advance(nextChild);
                float score = this.vectorScorer.score();
                if (score > this.currentScore) {
                    this.bestChild = nextChild;
                    this.currentScore = score;
                }
            } while ((nextChild = this.acceptedChildrenIterator.nextDoc()) != Integer.MAX_VALUE && nextChild < this.currentParent);
            return this.currentParent;
        }

        /** @return the best score among the current parent's accepted children. */
        public float score() throws IOException {
            return this.currentScore;
        }
    }
}

