package org.apache.cassandra.cql3.functions;

import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.nio.ByteBuffer;
import java.util.List;
import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.ProtocolVersion;

/**
 * Native CQL scalar functions that compute a similarity score between two float vectors.
 *
 * <p>Registers {@code similarity_cosine}, {@code similarity_euclidean} and
 * {@code similarity_dot_product}. Each takes two float-vector arguments, which must have the same
 * dimension, and returns a {@code float} computed by the corresponding jvector
 * {@link VectorSimilarityFunction}. A {@code null} argument yields a {@code null} result.
 */
public class VectorFcts {
    /**
     * Registers the vector similarity functions.
     *
     * @param nativeFunctions the registry the function factories are added to
     */
    public static void addFunctionsTo(NativeFunctions nativeFunctions) {
        // Cosine similarity divides by the vectors' magnitudes, so it is the only one that
        // rejects all-zero vectors (supportsZeroVectors = false).
        nativeFunctions.add(createSimilarityFunctionFactory("similarity_cosine", VectorSimilarityFunction.COSINE, false));
        nativeFunctions.add(createSimilarityFunctionFactory("similarity_euclidean", VectorSimilarityFunction.EUCLIDEAN, true));
        nativeFunctions.add(createSimilarityFunctionFactory("similarity_dot_product", VectorSimilarityFunction.DOT_PRODUCT, true));
    }

    /**
     * Creates the factory for a single similarity function.
     *
     * @param functionName        the CQL name of the function
     * @param similarityFunction  the jvector function used to compute the score
     * @param supportsZeroVectors whether all-zero vector arguments are accepted; when {@code false}
     *                            they are rejected with an {@link InvalidRequestException} at
     *                            execution time
     * @return a factory that builds the concrete function once the argument types are known
     */
    private static FunctionFactory createSimilarityFunctionFactory(String functionName,
                                                                   final VectorSimilarityFunction similarityFunction,
                                                                   final boolean supportsZeroVectors) {
        // Both parameters are float vectors; each is declared "same as" the other, presumably so the
        // type of one argument can be inferred from the other (e.g. for a bind marker) — TODO confirm.
        FunctionParameter[] parameters = new FunctionParameter[]{
                FunctionParameter.sameAs(1, false, FunctionParameter.vector(CQL3Type.Native.FLOAT)),
                FunctionParameter.sameAs(0, false, FunctionParameter.vector(CQL3Type.Native.FLOAT))
        };

        return new FunctionFactory(functionName, parameters) {
            @Override
            @SuppressWarnings("unchecked") // the parameters above only admit float vectors
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                VectorType<Float> firstArgType = (VectorType<Float>) argTypes.get(0);
                int dimension = firstArgType.dimension;

                // Similarity is only defined between vectors of equal length.
                boolean sameDimension = argTypes.stream()
                                                .allMatch(type -> ((VectorType<?>) type).dimension == dimension);
                if (!sameDimension) {
                    throw new InvalidRequestException("All arguments must have the same vector dimensions");
                }

                return VectorFcts.createSimilarityFunction(this.name.name, firstArgType, similarityFunction, supportsZeroVectors);
            }
        };
    }

    /**
     * Creates a scalar function that compares two vectors of the given type.
     *
     * @param functionName        the CQL name of the function
     * @param vectorType          the type of both arguments
     * @param similarityFunction  the jvector function used to compute the score
     * @param supportsZeroVectors whether all-zero vector arguments are accepted
     * @return a function returning the similarity as a {@code float}, or {@code null} when any
     *         argument is null
     */
    private static NativeFunction createSimilarityFunction(String functionName,
                                                           final VectorType<Float> vectorType,
                                                           final VectorSimilarityFunction similarityFunction,
                                                           final boolean supportsZeroVectors) {
        return new NativeScalarFunction(functionName, FloatType.instance, new AbstractType<?>[]{vectorType, vectorType}) {
            @Override
            public Arguments newArguments(ProtocolVersion protocolVersion) {
                // Deserialize both arguments straight to float[], the form that
                // VectorSimilarityFunction.compare consumes.
                return new FunctionArguments(protocolVersion,
                                             (version, buffer) -> vectorType.composeAsFloat(buffer),
                                             (version, buffer) -> vectorType.composeAsFloat(buffer));
            }

            @Override
            public ByteBuffer execute(Arguments arguments) throws InvalidRequestException {
                if (arguments.containsNulls()) {
                    return null;
                }

                float[] left = (float[]) arguments.get(0);
                float[] right = (float[]) arguments.get(1);

                if (!supportsZeroVectors && (isAllZero(left) || isAllZero(right))) {
                    throw new InvalidRequestException("Function " + this.name + " doesn't support all-zero vectors.");
                }

                return FloatType.instance.decompose(similarityFunction.compare(left, right));
            }

            /**
             * Returns {@code true} if every component equals zero. Because {@code -0.0f == 0.0f},
             * negative zero counts as zero; NaN does not.
             */
            private boolean isAllZero(float[] vector) {
                for (float component : vector) {
                    if (component != 0.0f) {
                        return false;
                    }
                }
                return true;
            }
        };
    }
}
