/*
 * Decompiled with CFR 0.152.
 */
package org.graylog.shaded.opensearch2.org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Arrays;
import java.util.function.LongConsumer;
import org.graylog.shaded.opensearch2.org.apache.lucene.internal.hppc.MaxSizedFloatArrayList;
import org.graylog.shaded.opensearch2.org.apache.lucene.internal.hppc.MaxSizedIntArrayList;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.RamUsageEstimator;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

/**
 * A bounded list of neighbor node ids paired with their scores, used while building an HNSW graph.
 *
 * <p>Entries are stored in two parallel growable buffers ({@code nodes} and {@code scores}). The
 * prefix {@code [0, sortedNodeSize)} is kept ordered by score: descending when {@code descOrder} is
 * {@code true}, ascending otherwise. Entries appended with {@link #addOutOfOrder} sit unsorted after
 * that prefix until {@link #sort} merges them in. The array never holds more than {@link #maxSize()}
 * entries; adding beyond that throws {@link IllegalStateException}.
 *
 * <p>If an {@code onHeapMemoryUsageListener} is supplied, it receives the initial heap footprint on
 * construction and then every later growth of the backing buffers, in bytes.
 */
public class NeighborArray {
    private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(NeighborArray.class);
    /** Whether the sorted prefix is ordered by descending (rather than ascending) score. */
    private final boolean scoresDescOrder;
    /** Number of entries currently stored. */
    private int size;
    /** Hard capacity; adding when {@code size == maxSize} throws. */
    private final int maxSize;
    private final MaxSizedFloatArrayList scores;
    private final MaxSizedIntArrayList nodes;
    /** Length of the score-ordered prefix; entries at or after this index are not yet sorted. */
    private int sortedNodeSize;
    /** Heap bytes attributed to this instance (shallow size plus buffers), kept in step with listener reports. */
    private long ramBytesUsed = BASE_RAM_BYTES_USED;
    private final LongConsumer onHeapMemoryUsageListener;

    /**
     * Creates an empty array with no memory-usage listener.
     *
     * @param maxSize the maximum number of neighbors the array may hold
     * @param descOrder {@code true} to keep the sorted prefix in descending score order, {@code false}
     *     for ascending
     */
    public NeighborArray(int maxSize, boolean descOrder) {
        this(maxSize, descOrder, null);
    }

    /**
     * Creates an empty array.
     *
     * @param maxSize the maximum number of neighbors the array may hold
     * @param descOrder {@code true} to keep the sorted prefix in descending score order, {@code false}
     *     for ascending
     * @param onHeapMemoryUsageListener optional (may be {@code null}) sink for heap usage in bytes; it
     *     is called here with the initial footprint and afterwards with each buffer-growth delta
     */
    public NeighborArray(int maxSize, boolean descOrder, LongConsumer onHeapMemoryUsageListener) {
        this.maxSize = maxSize;
        // Second argument is presumably the initial capacity (maxSize / 8); buffers grow as entries are added.
        this.nodes = new MaxSizedIntArrayList(maxSize, maxSize / 8);
        this.scores = new MaxSizedFloatArrayList(maxSize, maxSize / 8);
        this.ramBytesUsed += this.nodes.ramBytesUsed() + this.scores.ramBytesUsed();
        this.scoresDescOrder = descOrder;
        this.onHeapMemoryUsageListener = onHeapMemoryUsageListener;
        if (onHeapMemoryUsageListener != null) {
            onHeapMemoryUsageListener.accept(this.ramBytesUsed);
        }
    }

    /**
     * Appends a neighbor whose score keeps the whole array sorted.
     *
     * <p>Only valid while every entry is sorted, i.e. before any {@link #addOutOfOrder} call. Both
     * that precondition and the score ordering are checked by assertions only.
     *
     * @param newNode the neighbor's node id
     * @param newScore the neighbor's score; must not break the configured ordering
     * @throws IllegalStateException if the array already holds {@link #maxSize()} entries
     */
    public void addInOrder(int newNode, float newScore) {
        assert (this.size == this.sortedNodeSize) : "cannot call addInOrder after addOutOfOrder";
        if (this.size == this.maxSize) {
            throw new IllegalStateException("No growth is allowed");
        }
        if (this.size > 0) {
            float previousScore = this.scores.get(this.size - 1);
            assert (this.scoresDescOrder && previousScore >= newScore || !this.scoresDescOrder && previousScore <= newScore) : "Nodes are added in the incorrect order! Comparing " + newScore + " to " + Arrays.toString(this.scores.toArray());
        }
        this.append(newNode, newScore);
        ++this.sortedNodeSize;
    }

    /**
     * Appends a neighbor to the unsorted tail; it is merged into order by a later {@link #sort}.
     *
     * @param newNode the neighbor's node id
     * @param newScore the neighbor's score, or {@code NaN} to have it computed during sorting
     * @throws IllegalStateException if the array already holds {@link #maxSize()} entries
     */
    public void addOutOfOrder(int newNode, float newScore) {
        if (this.size == this.maxSize) {
            throw new IllegalStateException("No growth is allowed");
        }
        this.append(newNode, newScore);
    }

    /** Appends to both buffers, reports any buffer growth, and increments {@code size}. */
    private void append(int newNode, float newScore) {
        int previousNodesLength = this.nodes.buffer.length;
        int previousScoresLength = this.scores.buffer.length;
        this.nodes.add(newNode);
        this.scores.add(newScore);
        this.alertOnHeapMemoryUsageChange(previousNodesLength, previousScoresLength);
        ++this.size;
    }

    /**
     * Accounts for growth of either backing buffer since the given lengths were sampled.
     *
     * <p>Each buffer is measured on its own, so the reported delta stays correct even if the node
     * and score buffers do not grow in lockstep. Only growth is reported, as before.
     *
     * @param previousNodesLength {@code nodes.buffer.length} before the mutation
     * @param previousScoresLength {@code scores.buffer.length} before the mutation
     */
    private void alertOnHeapMemoryUsageChange(int previousNodesLength, int previousScoresLength) {
        long grownBytes = (long) (this.nodes.buffer.length - previousNodesLength) * Integer.BYTES
                + (long) (this.scores.buffer.length - previousScoresLength) * Float.BYTES;
        if (grownBytes <= 0) {
            return;
        }
        this.ramBytesUsed += grownBytes;
        if (this.onHeapMemoryUsageListener != null) {
            this.onHeapMemoryUsageListener.accept(grownBytes);
        }
    }

    /**
     * Adds a neighbor and, if that fills the array, evicts one entry so the array is no longer full.
     *
     * <p>When the array becomes full, every unsorted entry is scored (if its score is {@code NaN})
     * and sorted in. Then the entry chosen by {@code findWorstNonDiverse} is removed. That is the
     * lowest-ranked entry that is at least as similar to another retained neighbor as its stored
     * score, or the last entry if none qualifies.
     *
     * @param newNode the neighbor's node id
     * @param newScore the neighbor's score, or {@code NaN} to have it computed during sorting
     * @param nodeId the node whose neighbor list this is; used to obtain the scorer for unscored entries
     * @param scorerSupplier supplies scorers for {@code nodeId} and for candidate neighbors
     * @throws IllegalStateException if the array was already full before this call
     * @throws IOException if scoring fails
     */
    public void addAndEnsureDiversity(int newNode, float newScore, int nodeId, RandomVectorScorerSupplier scorerSupplier) throws IOException {
        this.addOutOfOrder(newNode, newScore);
        if (this.size < this.maxSize) {
            return;
        }
        this.removeIndex(this.findWorstNonDiverse(nodeId, scorerSupplier));
        assert (this.size == this.maxSize - 1);
    }

    /**
     * Merges every unsorted entry into the sorted prefix.
     *
     * @param scorer used to score entries whose stored score is {@code NaN}
     * @return the final (post-sort) indexes of the previously unsorted entries in ascending order, or
     *     {@code null} if there was nothing to sort
     * @throws IOException if scoring fails
     */
    int[] sort(RandomVectorScorer scorer) throws IOException {
        if (this.size == this.sortedNodeSize) {
            return null;
        }
        assert (this.sortedNodeSize < this.size);
        int[] uncheckedIndexes = new int[this.size - this.sortedNodeSize];
        int count = 0;
        while (this.sortedNodeSize != this.size) {
            uncheckedIndexes[count] = this.insertSortedInternal(scorer);
            // Inserting at position p shifts every previously placed entry at index >= p right by one.
            for (int i = 0; i < count; ++i) {
                if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
                    uncheckedIndexes[i]++;
                }
            }
            ++count;
        }
        Arrays.sort(uncheckedIndexes);
        return uncheckedIndexes;
    }

    /**
     * Moves the first unsorted entry (at {@code sortedNodeSize}) into its place in the sorted prefix.
     *
     * @param scorer used to score the entry if its stored score is {@code NaN}; may be {@code null}
     *     only when the score is known not to be {@code NaN}
     * @return the index the entry was inserted at
     * @throws IOException if scoring fails
     */
    private int insertSortedInternal(RandomVectorScorer scorer) throws IOException {
        assert (this.sortedNodeSize < this.size) : "Call this method only when there's unsorted node";
        int tmpNode = this.nodes.get(this.sortedNodeSize);
        float tmpScore = this.scores.get(this.sortedNodeSize);
        if (Float.isNaN(tmpScore)) {
            tmpScore = scorer.score(tmpNode);
        }
        int insertionPoint = this.scoresDescOrder ? this.descSortFindRightMostInsertionPoint(tmpScore, this.sortedNodeSize) : this.ascSortFindRightMostInsertionPoint(tmpScore, this.sortedNodeSize);
        // Shift [insertionPoint, sortedNodeSize) right by one; this overwrites the slot the entry came from.
        System.arraycopy(this.nodes.buffer, insertionPoint, this.nodes.buffer, insertionPoint + 1, this.sortedNodeSize - insertionPoint);
        System.arraycopy(this.scores.buffer, insertionPoint, this.scores.buffer, insertionPoint + 1, this.sortedNodeSize - insertionPoint);
        this.nodes.buffer[insertionPoint] = tmpNode;
        this.scores.buffer[insertionPoint] = tmpScore;
        ++this.sortedNodeSize;
        return insertionPoint;
    }

    /**
     * Appends a neighbor and immediately sorts it into place.
     *
     * <p>No scorer is available here, so {@code newScore} must not be {@code NaN}. NOTE(review): this
     * assumes no unsorted entries are pending. Otherwise the first pending entry, not the new one, is
     * sorted in.
     *
     * @throws IllegalStateException if the array already holds {@link #maxSize()} entries
     */
    void insertSorted(int newNode, float newScore) throws IOException {
        this.addOutOfOrder(newNode, newScore);
        this.insertSortedInternal(null);
    }

    /** Returns the number of neighbors currently stored. */
    public int size() {
        return this.size;
    }

    /**
     * Returns the internal node-id buffer (not a copy). Only the first {@link #size()} entries are
     * populated; the array may be longer.
     */
    public int[] nodes() {
        return this.nodes.buffer;
    }

    /** Returns the score stored at index {@code i}. */
    public float getScores(int i) {
        return this.scores.get(i);
    }

    /** Removes all entries. Heap usage is not reported to the listener here. */
    public void clear() {
        this.size = 0;
        this.sortedNodeSize = 0;
        this.nodes.clear();
        this.scores.clear();
    }

    /**
     * Removes the last entry, shrinking the sorted prefix if it covered that entry.
     *
     * <p>NOTE(review): there is no emptiness check here; behavior on an empty array depends on
     * {@code MaxSizedIntArrayList.removeLast}.
     */
    void removeLast() {
        this.nodes.removeLast();
        this.scores.removeLast();
        --this.size;
        this.sortedNodeSize = Math.min(this.sortedNodeSize, this.size);
    }

    /** Removes the entry at {@code idx}, shifting later entries left and adjusting the sorted prefix. */
    void removeIndex(int idx) {
        if (idx == this.size - 1) {
            this.removeLast();
            return;
        }
        this.nodes.removeAt(idx);
        this.scores.removeAt(idx);
        if (idx < this.sortedNodeSize) {
            --this.sortedNodeSize;
        }
        --this.size;
    }

    @Override
    public String toString() {
        return "NeighborArray[" + this.size + "]";
    }

    /**
     * Finds where to insert {@code newScore} into the ascending prefix {@code [0, bound)}, placing it
     * after any entries with an equal score.
     */
    private int ascSortFindRightMostInsertionPoint(float newScore, int bound) {
        int insertionPoint = Arrays.binarySearch(this.scores.buffer, 0, bound, newScore);
        if (insertionPoint >= 0) {
            // binarySearch may land on any equal element; walk to the last one of the run.
            while (insertionPoint < bound - 1 && this.scores.get(insertionPoint + 1) == this.scores.get(insertionPoint)) {
                ++insertionPoint;
            }
            ++insertionPoint;
        } else {
            insertionPoint = -insertionPoint - 1;
        }
        return insertionPoint;
    }

    /**
     * Finds where to insert {@code newScore} into the descending prefix {@code [0, bound)}. This is the
     * first index whose score is strictly less than {@code newScore}, so equal scores stay before it.
     */
    private int descSortFindRightMostInsertionPoint(float newScore, int bound) {
        int start = 0;
        int end = bound - 1;
        while (start <= end) {
            int mid = (start + end) >>> 1;
            if (this.scores.get(mid) < newScore) {
                end = mid - 1;
            } else {
                start = mid + 1;
            }
        }
        return start;
    }

    /**
     * Sorts pending entries and picks the index to evict.
     *
     * <p>Scans from the last position toward index 1, stopping once no previously unsorted entry can
     * still be involved. It returns the first candidate that {@code isWorstNonDiverse} accepts, or
     * {@code size - 1} if none does.
     */
    private int findWorstNonDiverse(int nodeOrd, RandomVectorScorerSupplier scorerSupplier) throws IOException {
        RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd);
        int[] uncheckedIndexes = this.sort(scorer);
        assert (uncheckedIndexes != null) : "We will always have something unchecked";
        int uncheckedCursor = uncheckedIndexes.length - 1;
        for (int i = this.size - 1; i > 0 && uncheckedCursor >= 0; --i) {
            if (this.isWorstNonDiverse(i, uncheckedIndexes, uncheckedCursor, scorerSupplier)) {
                return i;
            }
            if (i == uncheckedIndexes[uncheckedCursor]) {
                // This unchecked entry has now been examined as a candidate.
                --uncheckedCursor;
            }
        }
        return this.size - 1;
    }

    /**
     * Returns whether the entry at {@code candidateIndex} is at least as similar to some earlier-ranked
     * neighbor as its own stored score.
     *
     * <p>A candidate that was itself unchecked is compared with every entry before it. A candidate
     * that was already checked is compared only with the unchecked entries at
     * {@code uncheckedIndexes[0..uncheckedCursor]}, all of which lie before it.
     */
    private boolean isWorstNonDiverse(int candidateIndex, int[] uncheckedIndexes, int uncheckedCursor, RandomVectorScorerSupplier scorerSupplier) throws IOException {
        float minAcceptedSimilarity = this.scores.get(candidateIndex);
        RandomVectorScorer scorer = scorerSupplier.scorer(this.nodes.get(candidateIndex));
        if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
            for (int i = candidateIndex - 1; i >= 0; --i) {
                float neighborSimilarity = scorer.score(this.nodes.get(i));
                if (neighborSimilarity >= minAcceptedSimilarity) {
                    return true;
                }
            }
        } else {
            assert (candidateIndex > uncheckedIndexes[uncheckedCursor]);
            for (int i = uncheckedCursor; i >= 0; --i) {
                float neighborSimilarity = scorer.score(this.nodes.get(uncheckedIndexes[i]));
                if (neighborSimilarity >= minAcceptedSimilarity) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns the maximum number of neighbors this array may hold. */
    public int maxSize() {
        return this.maxSize;
    }
}

