Add support for euclidean_distance, dot_product, cosine_distance functions #22397
plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java
void testVectorWrite()
{
    try (TestTable table = new TestTable(onRemoteDatabase(), "test_vector_writes", "(v vector(1))")) {
        assertUpdate("INSERT INTO " + table.getName() + " VALUES ARRAY[REAL '1.0'], NULL", 2);
Do we want to allow writing vector types?
I'm concerned about the unforeseen complexities and potential bugs that come with this ability.
import static org.assertj.core.api.Assertions.assertThatThrownBy;

final class TestPostgreSqlVectorType
        extends AbstractTestQueryFramework
What happens if we push down pgvector operators (<->, <#>, <=>) to PostgreSQL databases that don't have the extension installed?
Please add a corresponding test case for this use case in a separate class.
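For reference when reading the pushdown question above: pgvector documents `<->` as L2 (Euclidean) distance, `<#>` as negative inner product, and `<=>` as cosine distance. Below is a minimal sketch of the SQL text a pushdown might render, assuming a hypothetical `PgVectorSql` helper that is not part of this PR or of Trino's API. The negation in `dotProduct` follows from `<#>` returning the negative inner product.

```java
// Hypothetical helper for illustration only; names are assumptions, not Trino code.
final class PgVectorSql
{
    private PgVectorSql() {}

    // euclidean_distance(a, b) corresponds to pgvector's L2 distance operator
    static String euclideanDistance(String left, String right)
    {
        return "(" + left + " <-> " + right + ")";
    }

    // pgvector's <#> yields the NEGATIVE inner product, so negate it
    static String dotProduct(String left, String right)
    {
        return "(-(" + left + " <#> " + right + "))";
    }

    // cosine_distance(a, b) corresponds to pgvector's cosine distance operator
    static String cosineDistance(String left, String right)
    {
        return "(" + left + " <=> " + right + ")";
    }
}
```

How such expressions behave on a server without the extension is what the requested test case would pin down.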
    return 1.0 - cosineSimilarity;
}

private static double dotProduct(double[] first, double[] second)
It looks like the callers all have floats, so this could take floats for the arguments and use a double only for the computation. This would keep temporary memory down.
Also, as @raunaqmorarka mentioned, we could likely just work directly with the blocks, or pull out the underlying arrays (see IntArrayBlock getRawValues and getRawValuesOffset). For the moment these are int[], but I expect we will move to either float[] or MemorySegment soon.
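A minimal sketch of the suggestion above, assuming a hypothetical `FloatVectorMath` class (not the PR's actual code): the inputs stay `float[]`, so no `double[]` copies are allocated, and only the accumulator is a `double`.

```java
// Illustrative only; class and method layout are assumptions, not the PR's implementation.
final class FloatVectorMath
{
    private FloatVectorMath() {}

    // Takes float inputs directly and accumulates in double precision
    static double dotProduct(float[] first, float[] second)
    {
        if (first.length != second.length) {
            throw new IllegalArgumentException("The arguments must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < first.length; i++) {
            // Widen one operand so the multiplication happens in double
            sum += (double) first[i] * second[i];
        }
        return sum;
    }

    // Returns NaN when either vector has zero magnitude (0.0 / 0.0)
    static double cosineDistance(float[] first, float[] second)
    {
        double dot = dotProduct(first, second);
        double magnitudes = Math.sqrt(dotProduct(first, first)) * Math.sqrt(dotProduct(second, second));
        double cosineSimilarity = dot / magnitudes;
        return 1.0 - cosineSimilarity;
    }
}
```

Working directly on the blocks' raw arrays, as also suggested, would remove even the `float[]` extraction step; that variant depends on engine internals and is not sketched here.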
I think that eventually we should add support for vector type to the engine as a first class citizen (along with float16 type). Cc @martint
I think we should look at it. Vector is basically a fixed length array, and knowing the fixed length could help speed up some operations. We talked about float16 and float8 last year. I believe they are coming to the JVM, but I believe there were competing standards with these, and I'm not sure what happened.
First commit looks good.
@Execution(CONCURRENT)
final class TestArrayVectorFunctions
{
    private QueryAssertions assertions;
Make this final and initialize it at the declaration site.
Updated.
Description
Adds 3 functions calculating distance: euclidean_distance, dot_product, cosine_distance.
Release notes