package org.apache.cassandra.index.sai.disk.v1.segment;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.SparseFixedBitSet;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.agrona.collections.IntArrayList;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.VectorQueryContext;
import org.apache.cassandra.index.sai.disk.PrimaryKeyMap;
import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles;
import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList;
import org.apache.cassandra.index.sai.disk.v1.vector.DiskAnn;
import org.apache.cassandra.index.sai.disk.v1.vector.OnDiskOrdinalsMap;
import org.apache.cassandra.index.sai.disk.v1.vector.OptimizeFor;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.iterators.KeyRangeListIterator;
import org.apache.cassandra.index.sai.memory.VectorMemoryIndex;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.postings.IntArrayPostingList;
import org.apache.cassandra.index.sai.postings.PostingList;
import org.apache.cassandra.index.sai.utils.AtomicRatio;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.RangeUtil;
import org.apache.cassandra.tracing.Tracing;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/cassandra/index/sai/disk/v1/segment/VectorIndexSegmentSearcher.class */
/**
 * Searches a single on-disk vector index segment for ANN expressions.
 * <p>
 * Depending on how many rows a query restricts the segment to, a search is served either by returning the
 * candidate rows themselves as a posting list (skipping ANN, "brute force"), or by a {@link DiskAnn} graph
 * search filtered to the ordinals of the permitted rows. The number of nodes each graph search actually
 * visits is compared with a prediction, and the running ratio is used to calibrate the brute-force threshold.
 */
public class VectorIndexSegmentSearcher extends IndexSegmentSearcher {
    private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

    private final DiskAnn graph;
    /** Extra cap on brute-force rows; fixed at {@link Integer#MAX_VALUE}, i.e. currently no extra cap. */
    private final int globalBruteForceRows;
    /** Running actual/expected visited-node ratio, used to scale the raw node-visit estimate. */
    private final AtomicRatio actualExpectedRatio;
    /** One reusable bitset per thread, sized to the graph; cleared before each use by {@code bitSetForSearch()}. */
    private final ThreadLocal<SparseFixedBitSet> cachedBitSets;
    private final OptimizeFor optimizeFor;

    /**
     * Outcome of restricting a search to a key range: either a possibly-{@code null} ordinal filter for an ANN
     * graph search, or a posting list that is returned directly without running ANN ({@link #skipANN()}).
     */
    public static class BitsOrPostingList {
        private final Bits bits;
        /** Predicted number of graph nodes the filtered search will visit, or -1 when no prediction was made. */
        private final int expectedNodesVisited;
        private final PostingList postingList;

        public BitsOrPostingList(@Nullable Bits bits, int expectedNodesVisited) {
            this.bits = bits;
            this.expectedNodesVisited = expectedNodesVisited;
            this.postingList = null;
        }

        public BitsOrPostingList(@Nullable Bits bits) {
            this(bits, -1);
        }

        public BitsOrPostingList(PostingList postingList) {
            this.bits = null;
            this.postingList = Preconditions.checkNotNull(postingList);
            this.expectedNodesVisited = -1;
        }

        /**
         * @return the ordinal filter for the graph search (may be {@code null})
         * @throws IllegalStateException if this instance carries a posting list instead
         */
        @Nullable
        public Bits getBits() {
            Preconditions.checkState(!skipANN());
            return bits;
        }

        /**
         * @return the posting list to use in place of an ANN search
         * @throws IllegalStateException if this instance carries bits instead
         */
        public PostingList postingList() {
            Preconditions.checkState(skipANN());
            return postingList;
        }

        /** @return {@code true} when the result is a posting list and the graph search should be skipped */
        public boolean skipANN() {
            return postingList != null;
        }
    }

    public VectorIndexSegmentSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
                                      PerColumnIndexFiles perIndexFiles,
                                      SegmentMetadata segmentMetadata,
                                      StorageAttachedIndex index) throws IOException {
        super(primaryKeyMapFactory, perIndexFiles, segmentMetadata, index);
        this.actualExpectedRatio = new AtomicRatio();
        this.graph = new DiskAnn(segmentMetadata.componentMetadatas, perIndexFiles, index.indexWriterConfig());
        this.cachedBitSets = ThreadLocal.withInitial(() -> new SparseFixedBitSet(graph.size()));
        this.globalBruteForceRows = Integer.MAX_VALUE;
        this.optimizeFor = index.indexWriterConfig().getOptimizeFor();
    }

    @Override
    public long indexFileCacheSize() {
        return graph.ramBytesUsed();
    }

    /**
     * Runs an ANN search against this segment, restricted to {@code keyRange}.
     *
     * @throws IllegalArgumentException if {@code expression} is not an ANN expression
     */
    @Override
    public KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext context) throws IOException {
        int limit = context.vectorContext().limit();

        if (logger.isTraceEnabled()) {
            logger.trace(index.identifier().logMessage("Searching on expression '{}'..."), expression);
        }

        if (expression.getIndexOperator() != Expression.IndexOperator.ANN) {
            throw new IllegalArgumentException(index.identifier().logMessage("Unsupported expression during ANN index query: " + expression));
        }

        int topK = optimizeFor.topKFor(limit);
        BitsOrPostingList filter = bitsOrPostingListForKeyRange(context.vectorContext(), keyRange, topK);
        if (filter.skipANN()) {
            return toPrimaryKeyIterator(filter.postingList(), context);
        }

        VectorPostingList results = graph.search(index.termType().decomposeVector(expression.lower().value.raw.duplicate()),
                                                 topK, limit, filter.getBits());
        // Only calibrate when a prediction was actually made for this filter.
        if (filter.expectedNodesVisited >= 0) {
            updateExpectedNodes(results.getVisitedCount(), filter.expectedNodesVisited);
        }
        return toPrimaryKeyIterator(results, context);
    }

    /**
     * Translates a key range into either a graph-search filter or, when the range is empty or small enough
     * to brute force, a posting list of segment row ids.
     */
    private BitsOrPostingList bitsOrPostingListForKeyRange(VectorQueryContext context, AbstractBounds<PartitionPosition> keyRange, int topK) throws IOException {
        try (PrimaryKeyMap primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) {
            // Unrestricted: only shadowed keys need to be filtered out.
            if (RangeUtil.coversFullRing(keyRange)) {
                return new BitsOrPostingList(context.bitsetForShadowedPrimaryKeys(metadata, primaryKeyMap, graph));
            }

            long firstRowId = primaryKeyMap.ceiling(keyRange.left.getToken());
            if (firstRowId < 0) {
                return new BitsOrPostingList(PostingList.EMPTY);
            }
            long lastRowId = getMaxSSTableRowId(primaryKeyMap, keyRange.right);
            if (firstRowId > lastRowId) {
                return new BitsOrPostingList(PostingList.EMPTY);
            }

            // Range covers the whole segment: no per-row filter is needed.
            if (firstRowId <= metadata.minSSTableRowId && lastRowId >= metadata.maxSSTableRowId) {
                return new BitsOrPostingList(context.bitsetForShadowedPrimaryKeys(metadata, primaryKeyMap, graph));
            }

            // Clamp the range to this segment's rows.
            long rangeStart = Math.max(firstRowId, metadata.minSSTableRowId);
            long rangeEnd = Math.min(lastRowId, metadata.maxSSTableRowId);
            int nRows = Math.toIntExact(rangeEnd - rangeStart + 1);
            int bruteForceThreshold = Math.min(globalBruteForceRows, maxBruteForceRows(topK, nRows, graph.size()));
            logger.trace("Search range covers {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}",
                         nRows, bruteForceThreshold, graph.size(), topK);
            Tracing.trace("Search range covers {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}",
                          nRows, bruteForceThreshold, graph.size(), topK);

            if (nRows <= bruteForceThreshold) {
                IntArrayList postings = new IntArrayList(nRows, -1);
                for (long sstableRowId = rangeStart; sstableRowId <= rangeEnd; sstableRowId++) {
                    if (context.shouldInclude(sstableRowId, primaryKeyMap)) {
                        postings.addInt(metadata.toSegmentRowId(sstableRowId));
                    }
                }
                return new BitsOrPostingList(new IntArrayPostingList(postings.toIntArray()));
            }

            // Build a bitset of the graph ordinals for the permitted rows in range.
            SparseFixedBitSet bits = bitSetForSearch();
            boolean hasMatches = false;
            try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = graph.getOrdinalsView()) {
                for (long sstableRowId = rangeStart; sstableRowId <= rangeEnd; sstableRowId++) {
                    if (context.shouldInclude(sstableRowId, primaryKeyMap)) {
                        int ordinal = ordinalsView.getOrdinalForRowId(metadata.toSegmentRowId(sstableRowId));
                        if (ordinal >= 0) {
                            bits.set(ordinal);
                            hasMatches = true;
                        }
                    }
                }
            } catch (IOException e) {
                // Scoped to the ordinals view only, so a failure closing primaryKeyMap is not wrapped here.
                throw new UncheckedIOException(e);
            }

            if (!hasMatches) {
                return new BitsOrPostingList(PostingList.EMPTY);
            }
            return new BitsOrPostingList(bits, VectorMemoryIndex.expectedNodesVisited(topK, nRows, graph.size()));
        }
    }

    /**
     * Returns the last sstable row id covered by the right bound of a key range.
     * <p>
     * A minimum right bound, or a bound for which {@code floor} finds nothing (negative result), both map to
     * this segment's {@code maxSSTableRowId}.
     * NOTE(review): the "not found" fallback widens rather than empties the range; confirm this is intended.
     */
    private long getMaxSSTableRowId(PrimaryKeyMap primaryKeyMap, PartitionPosition right) {
        if (right.isMinimum()) {
            return metadata.maxSSTableRowId;
        }
        long floor = primaryKeyMap.floor(right.getToken());
        return floor < 0 ? metadata.maxSSTableRowId : floor;
    }

    /** Returns this thread's cached bitset, cleared and ready for reuse. */
    private SparseFixedBitSet bitSetForSearch() {
        SparseFixedBitSet bits = cachedBitSets.get();
        bits.clear();
        return bits;
    }

    /**
     * Reorders the given primary keys of this segment by vector similarity and keeps the top results.
     * Presumably {@code primaryKeys} is sorted, since the segment's span is found with dropWhile/takeWhile.
     */
    @Override
    public KeyRangeIterator limitToTopKResults(QueryContext context, List<PrimaryKey> primaryKeys, Expression expression) throws IOException {
        int limit = context.vectorContext().limit();
        List<PrimaryKey> keysInRange = primaryKeys.stream()
                                                  .dropWhile(k -> k.compareTo(metadata.minKey) < 0)
                                                  .takeWhile(k -> k.compareTo(metadata.maxKey) <= 0)
                                                  .collect(Collectors.toList());
        if (keysInRange.isEmpty()) {
            return KeyRangeIterator.empty();
        }

        int topK = optimizeFor.topKFor(limit);
        if (shouldUseBruteForce(topK, limit, keysInRange.size())) {
            return new KeyRangeListIterator(metadata.minKey, metadata.maxKey, keysInRange);
        }

        try (PrimaryKeyMap primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) {
            int maxSegmentRowId = metadata.toSegmentRowId(metadata.maxSSTableRowId);
            SparseFixedBitSet bits = bitSetForSearch();
            IntArrayList rowIds = new IntArrayList();
            // try-with-resources: the view is closed even if a lookup below throws.
            try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = graph.getOrdinalsView()) {
                for (PrimaryKey primaryKey : keysInRange) {
                    long sstableRowId = primaryKeyMap.rowIdFromPrimaryKey(primaryKey);
                    // Skip rows before this segment (including negative "not found" ids).
                    if (sstableRowId < metadata.minSSTableRowId) {
                        continue;
                    }
                    // Past this segment's last row: nothing further can belong here.
                    if (sstableRowId > metadata.maxSSTableRowId) {
                        break;
                    }
                    int segmentRowId = metadata.toSegmentRowId(sstableRowId);
                    rowIds.addInt(segmentRowId);
                    int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId);
                    if (ordinal >= 0) {
                        bits.set(ordinal);
                    }
                }
            }

            if (shouldUseBruteForce(topK, limit, rowIds.size())) {
                return toPrimaryKeyIterator(new IntArrayPostingList(rowIds.toIntArray()), context);
            }

            VectorPostingList results = graph.search(index.termType().decomposeVector(expression.lower().value.raw.duplicate()),
                                                     topK, limit, bits);
            // NOTE(review): unlike search(), this passes the segment's max row id (not the permitted row count)
            // and the ratio-adjusted estimate rather than the raw one; confirm this calibration input is intended.
            updateExpectedNodes(results.getVisitedCount(), expectedNodesVisited(topK, maxSegmentRowId, graph.size()));
            return toPrimaryKeyIterator(results, context);
        }
    }

    /** Returns {@code true} when {@code numRows} candidates are few enough to skip the graph search. */
    private boolean shouldUseBruteForce(int topK, int limit, int numRows) {
        int bruteForceThreshold = Math.min(globalBruteForceRows, maxBruteForceRows(topK, numRows, graph.size()));
        logger.trace("SAI materialized {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}",
                     numRows, bruteForceThreshold, graph.size(), limit);
        Tracing.trace("SAI materialized {} rows; max brute force rows is {} for sstable index with {} nodes, LIMIT {}",
                      numRows, bruteForceThreshold, graph.size(), limit);
        return numRows <= bruteForceThreshold;
    }

    /** Brute-force threshold: at least {@code topK}, otherwise the calibrated expected node visits. */
    private int maxBruteForceRows(int topK, int nPermittedOrdinals, int graphSize) {
        return Math.max(topK, expectedNodesVisited(topK, nPermittedOrdinals, graphSize));
    }

    /**
     * Raw {@link VectorMemoryIndex#expectedNodesVisited} estimate scaled by the observed actual/expected ratio,
     * which is only trusted once at least 10 samples have been recorded.
     */
    private int expectedNodesVisited(int topK, int nPermittedOrdinals, int graphSize) {
        double observedRatio = actualExpectedRatio.getUpdateCount() >= 10 ? actualExpectedRatio.get() : 1.0d;
        return (int) (observedRatio * VectorMemoryIndex.expectedNodesVisited(topK, nPermittedOrdinals, graphSize));
    }

    /** Records an (actual, expected) visited-node sample, warning on large mispredictions. */
    private void updateExpectedNodes(int actualNodesVisited, int expectedNodesVisited) {
        assert expectedNodesVisited >= 0 : expectedNodesVisited;
        assert actualNodesVisited >= 0 : actualNodesVisited;
        // The >= 1000 guard suppresses noise from small searches and must apply to both directions;
        // long arithmetic avoids int overflow in the 2x comparisons.
        if (actualNodesVisited >= 1000
            && (actualNodesVisited > 2L * expectedNodesVisited || expectedNodesVisited > 2L * actualNodesVisited)) {
            logger.warn("Predicted visiting {} nodes, but actually visited {}", expectedNodesVisited, actualNodesVisited);
        }
        actualExpectedRatio.update(actualNodesVisited, expectedNodesVisited);
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this).add("index", index).toString();
    }

    @Override
    public void close() throws IOException {
        graph.close();
    }
}
