Skip to content

Commit

Permalink
Initial HNSW implementation
Browse files Browse the repository at this point in the history
This commit includes the work done in collaboration with Hugo Wen from
Amazon:

    MDEV-33408 Alter HNSW graph storage and fix memory leak

    This commit changes the way HNSW graph information is stored in the
    second table. Instead of storing connections as separate records, it now
    stores neighbors for each node, leading to significant performance
    improvements and storage savings.

    Comparing with the previous approach, the insert speed is 5 times faster,
    search speed improves by 23%, and storage usage is reduced by 73%, based
    on ann-benchmark tests with random-xs-20-euclidean and
    random-s-100-euclidean datasets.

    Additionally, in previous code, vector objects were not released after
    use, resulting in excessive memory consumption (over 20GB for building
    the index with 90,000 records), preventing tests with large datasets.
    Now ensure that vectors are released appropriately during the insert and
    search functions. Note there are still some vectors that need to be
    cleaned up after search query completion. Needs to be addressed in a
    future commit.

    All new code of the whole pull request, including one or several files
    that are either new files or modified ones, are contributed under the
    BSD-new license. I am contributing on behalf of my employer Amazon Web
    Services, Inc.

As well as the commit:

    Introduce session variables to manage HNSW index parameters

    Three variables:

    hnsw_max_connection_per_layer
    hnsw_ef_constructor
    hnsw_ef_search

    ann-benchmark tool is also updated to support these variables in commit
    HugoWenTD/ann-benchmarks@e09784e for branch
    https://github.com/HugoWenTD/ann-benchmarks/tree/mariadb-configurable

    All new code of the whole pull request, including one or several files
    that are either new files or modified ones, are contributed under the
    BSD-new license. I am contributing on behalf of my employer Amazon Web
    Services, Inc.
  • Loading branch information
cvicentiu committed May 9, 2024
1 parent 1d18e5c commit f6f6059
Show file tree
Hide file tree
Showing 16 changed files with 789 additions and 146 deletions.
2 changes: 1 addition & 1 deletion sql/filesort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Bounded_queue
uchar **m_sort_keys;
size_t m_compare_length;
Sort_param *m_sort_param;
Queue<uchar*,uchar*,size_t> m_queue;
Queue<uchar*, size_t> m_queue;
};


Expand Down
1 change: 0 additions & 1 deletion sql/item.h
Original file line number Diff line number Diff line change
Expand Up @@ -6418,7 +6418,6 @@ class Item_int_with_ref :public Item_int
#include "item_subselect.h"
#include "item_xmlfunc.h"
#include "item_jsonfunc.h"
#include "item_vectorfunc.h"
#include "item_create.h"
#include "item_vers.h"
#endif
Expand Down
1 change: 1 addition & 0 deletions sql/item_create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "sp.h"
#include "sql_time.h"
#include "sql_type_geom.h"
#include "item_vectorfunc.h"
#include <mysql/plugin_function.h>


Expand Down
5 changes: 3 additions & 2 deletions sql/item_subselect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6676,8 +6676,9 @@ subselect_rowid_merge_engine::cmp_keys_by_null_selectivity(Ordered_key **k1,
*/

int
subselect_rowid_merge_engine::cmp_keys_by_cur_rownum(void *, Ordered_key *k1,
Ordered_key *k2)
subselect_rowid_merge_engine::cmp_keys_by_cur_rownum(void *,
const Ordered_key *k1,
const Ordered_key *k2)
{
rownum_t r1= k1->current();
rownum_t r2= k2->current();
Expand Down
8 changes: 5 additions & 3 deletions sql/item_subselect.h
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,7 @@ class Ordered_key : public Sql_alloc
return FALSE;
};
/* Return the current index element. */
rownum_t current()
rownum_t current() const
{
DBUG_ASSERT(key_buff_elements && cur_key_idx < key_buff_elements);
return key_buff[cur_key_idx];
Expand Down Expand Up @@ -1488,7 +1488,7 @@ class subselect_rowid_merge_engine: public subselect_partial_match_engine
Priority queue of Ordered_key indexes, one per NULLable column.
This queue is used by the partial match algorithm in method exec().
*/
Queue<Ordered_key, Ordered_key> pq;
Queue<Ordered_key> pq;
protected:
/*
Comparison function to compare keys in order of decreasing bitmap
Expand All @@ -1499,7 +1499,9 @@ class subselect_rowid_merge_engine: public subselect_partial_match_engine
Comparison function used by the priority queue pq, the 'smaller' key
is the one with the smaller current row number.
*/
static int cmp_keys_by_cur_rownum(void *arg, Ordered_key *k1, Ordered_key *k2);
static int cmp_keys_by_cur_rownum(void *arg,
const Ordered_key *k1,
const Ordered_key *k2);

bool test_null_row(rownum_t row_num);
bool exists_complementing_null_row(MY_BITMAP *keys_to_complement);
Expand Down
15 changes: 13 additions & 2 deletions sql/item_vectorfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <my_global.h>
#include "item.h"
#include "item_vectorfunc.h"

key_map Item_func_vec_distance::part_of_sortkey() const
{
Expand All @@ -48,8 +49,18 @@ double Item_func_vec_distance::val_real()
return 0;
float *v1= (float*)r1->ptr();
float *v2= (float*)r2->ptr();
return euclidean_vec_distance(v1, v2, (r1->length()) / sizeof(float));
}

double euclidean_vec_distance(float *v1, float *v2, size_t v_len)
{
float *p1= v1;
float *p2= v2;
double d= 0;
for (uint i=0; i < r1->length() / sizeof(float); i++)
d+= (v1[i] - v2[i])*(v1[i] - v2[i]);
for (size_t i= 0; i < v_len; p1++, p2++, i++)
{
float dist= *p1 - *p2;
d+= dist * dist;
}
return d;
}
6 changes: 6 additions & 0 deletions sql/item_vectorfunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */

/* This file defines all vector functions */
#include <my_global.h>
#include "item.h"
#include "lex_string.h"
#include "item_func.h"

Expand All @@ -34,6 +36,7 @@ class Item_func_vec_distance: public Item_real_func
{
return check_argument_types_or_binary(NULL, 0, arg_count);
}

public:
Item_func_vec_distance(THD *thd, Item *a, Item *b)
:Item_real_func(thd, a, b) {}
Expand All @@ -46,6 +49,9 @@ class Item_func_vec_distance: public Item_real_func
key_map part_of_sortkey() const override;
Item *get_copy(THD *thd) override
{ return get_item_copy<Item_func_vec_distance>(thd, this); }
virtual ~Item_func_vec_distance() {};
};


double euclidean_vec_distance(float *v1, float *v2, size_t v_len);
#endif
4 changes: 2 additions & 2 deletions sql/sql_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9842,7 +9842,7 @@ int TABLE::hlindex_open(uint nr)
}
TABLE *table= (TABLE*)alloc_root(&mem_root, sizeof(*table));
if (!table ||
open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, 0,
open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, EXTRA_RECORD,
in_use->open_options, table, 0))
return 1;
hlindex= table;
Expand Down Expand Up @@ -9897,7 +9897,7 @@ int TABLE::hlindex_first(uint nr, Item *item, ulonglong limit)

DBUG_ASSERT(hlindex->in_use == in_use);

return mhnsw_first(this, item, limit);
return mhnsw_first(this, key_info + s->keys, item, limit);
}

int TABLE::hlindex_next()
Expand Down
5 changes: 5 additions & 0 deletions sql/sql_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,11 @@ typedef struct system_variables
my_bool binlog_alter_two_phase;

Charset_collation_map_st character_set_collations;

/* Temporary for HNSW tests */
uint hnsw_max_connection_per_layer;
uint hnsw_ef_constructor;
uint hnsw_ef_search;
} SV;

/**
Expand Down
19 changes: 15 additions & 4 deletions sql/sql_hset.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,28 @@ class Hash_set
@retval FALSE OK. The value either was inserted or existed
in the hash.
*/
bool insert(T *value)
bool insert(const T *value)
{
return my_hash_insert(&m_hash, reinterpret_cast<const uchar*>(value));
}
bool remove(T *value)
bool remove(const T *value)
{
return my_hash_delete(&m_hash, reinterpret_cast<uchar*>(value));
return my_hash_delete(&m_hash,
reinterpret_cast<uchar*>(const_cast<T*>(value)));
}
T *find(const void *key, size_t klen) const
{
return (T*)my_hash_search(&m_hash, reinterpret_cast<const uchar *>(key), klen);
}

T *find(const T *other) const
{
DBUG_ASSERT(m_hash.get_key);
size_t klen;
uchar *key= m_hash.get_key(reinterpret_cast<const uchar *>(other),
&klen, false);
return find(key, klen);
}
/** Is this hash set empty? */
bool is_empty() const { return m_hash.records == 0; }
/** Returns the number of unique elements. */
Expand All @@ -82,7 +92,8 @@ class Hash_set
void clear() { my_hash_reset(&m_hash); }
const T* at(size_t i) const
{
return reinterpret_cast<T*>(my_hash_element(const_cast<HASH*>(&m_hash), i));
return reinterpret_cast<const T*>(
my_hash_element(const_cast<HASH*>(&m_hash), i));
}
/** An iterator over hash elements. Is not insert-stable. */
class Iterator
Expand Down
4 changes: 2 additions & 2 deletions sql/sql_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ template <class T> class List :public base_list
inline List() :base_list() {}
inline List(const List<T> &tmp, MEM_ROOT *mem_root) :
base_list(tmp, mem_root) {}
inline bool push_back(T *a) { return base_list::push_back(a); }
inline bool push_back(T *a, MEM_ROOT *mem_root)
inline bool push_back(const T *a) { return base_list::push_back((void *)a); }
inline bool push_back(const T *a, MEM_ROOT *mem_root)
{ return base_list::push_back((void*) a, mem_root); }
inline bool push_front(T *a) { return base_list::push_front(a); }
inline bool push_front(T *a, MEM_ROOT *mem_root)
Expand Down
11 changes: 6 additions & 5 deletions sql/sql_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,25 @@
#ifndef QUEUE_INCLUDED
#define QUEUE_INCLUDED

#include <my_global.h>
#include "queues.h"

/**
A typesafe wrapper of QUEUE, a priority heap
*/
template<typename Element, typename Key, typename Param=void>
template<typename Element, typename Param=void>
class Queue
{
public:
typedef int (*Queue_compare)(Param *, Key *, Key *);
typedef int (*Queue_compare)(Param *, const Element *, const Element *);

Queue() { m_queue.root= 0; }
~Queue() { delete_queue(&m_queue); }
int init(uint max_elements, uint offset_to_key, bool max_at_top,
Queue_compare compare, Param *param= 0)
{
return init_queue(&m_queue, max_elements, offset_to_key, max_at_top,
(queue_compare)compare, param, 0, 0);
(queue_compare)compare, (void *)param, 0, 0);
}

size_t elements() const { return m_queue.elements; }
Expand All @@ -42,11 +43,11 @@ class Queue
bool is_empty() const { return elements() == 0; }
Element *top() const { return (Element*)queue_top(&m_queue); }

void push(Element *element) { queue_insert(&m_queue, (uchar*)element); }
void push(const Element *element) { queue_insert(&m_queue, (uchar*)element); }
Element *pop() { return (Element *)queue_remove_top(&m_queue); }
void clear() { queue_remove_all(&m_queue); }
void propagate_top() { queue_replace_top(&m_queue); }
void replace_top(Element *element)
void replace_top(const Element *element)
{
queue_top(&m_queue)= (uchar*)element;
propagate_top();
Expand Down
20 changes: 20 additions & 0 deletions sql/sys_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7330,3 +7330,23 @@ static Sys_var_enum Sys_block_encryption_mode(
"AES_ENCRYPT() and AES_DECRYPT() functions",
SESSION_VAR(block_encryption_mode), CMD_LINE(REQUIRED_ARG),
block_encryption_mode_values, DEFAULT(0));

/* Temporary for HNSW tests */
static Sys_var_uint Sys_hnsw_ef_search(
"hnsw_ef_search",
"hnsw_ef_search",
SESSION_VAR(hnsw_ef_search), CMD_LINE(NO_ARG),
VALID_RANGE(0, UINT_MAX), DEFAULT(10),
BLOCK_SIZE(1));
static Sys_var_uint Sys_hnsw_ef_constructor(
"hnsw_ef_constructor",
"hnsw_ef_constructor",
SESSION_VAR(hnsw_ef_constructor), CMD_LINE(NO_ARG),
VALID_RANGE(0, UINT_MAX), DEFAULT(10),
BLOCK_SIZE(1));
static Sys_var_uint Sys_hnsw_max_connection_per_layer(
"hnsw_max_connection_per_layer",
"hnsw_max_connection_per_layer",
SESSION_VAR(hnsw_max_connection_per_layer), CMD_LINE(NO_ARG),
VALID_RANGE(0, UINT_MAX), DEFAULT(50),
BLOCK_SIZE(1));

0 comments on commit f6f6059

Please sign in to comment.