Skip to content

Commit

Permalink
Improve the accuracy of BM25 score when multiple data parts are involved
Browse files Browse the repository at this point in the history
  • Loading branch information
Mochi Xu authored and Shanfeng Pang committed Apr 22, 2024
1 parent abbf1f8 commit 40f8339
Show file tree
Hide file tree
Showing 16 changed files with 412 additions and 16 deletions.
2 changes: 1 addition & 1 deletion rust/supercrate/libs/tantivy_search
Submodule tantivy_search updated from 06a58e to eaee87
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ if (TARGET ch_contrib::jemalloc)
target_link_libraries (clickhouse_common_io PRIVATE ch_contrib::jemalloc)
endif()

if (TARGET ch_rust::supercrate)
target_link_libraries (clickhouse_common_io PRIVATE ch_rust::supercrate)
endif()

add_subdirectory(Access/Common)
add_subdirectory(Common/ZooKeeper)
add_subdirectory(Common/Config)
Expand Down
2 changes: 2 additions & 0 deletions src/Common/CurrentMetrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@
M(LocalThreadActive, "Number of threads in local thread pools running a task.") \
M(MergeTreeDataSelectExecutorThreads, "Number of threads in the MergeTreeDataSelectExecutor thread pool.") \
M(MergeTreeDataSelectExecutorThreadsActive, "Number of threads in the MergeTreeDataSelectExecutor thread pool running a task.") \
M(MergeTreeDataSelectBM25CollectThreads, "Number of threads in the thread pool for collecting statistics for text search.") \
M(MergeTreeDataSelectBM25CollectThreadsActive, "Number of threads in the thread pool for collecting statistics for text search running a task.") \
M(BackupsThreads, "Number of threads in the thread pool for BACKUP.") \
M(BackupsThreadsActive, "Number of threads in thread pool for BACKUP running a task.") \
M(RestoreThreads, "Number of threads in the thread pool for RESTORE.") \
Expand Down
32 changes: 27 additions & 5 deletions src/Storages/MergeTree/TantivyIndexStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,22 +507,44 @@ rust::cxxbridge1::Vec<std::uint8_t> TantivyIndexStore::termsQueryBitmap(String c

return ffi_query_terms_bitmap(index_files_cache_path, column_name, terms);
}
rust::cxxbridge1::Vec<RowIdWithScore> TantivyIndexStore::bm25Search(String sentence, size_t topk)

rust::cxxbridge1::Vec<RowIdWithScore> TantivyIndexStore::bm25Search(String sentence, Statistics & statistics, size_t topk)
{
if (!index_reader_status)
getTantivyIndexReader();

std::vector<uint8_t> u8_alived_bitmap;
return ffi_bm25_search(index_files_cache_path, sentence, static_cast<uint32_t>(topk), u8_alived_bitmap, false);
return ffi_bm25_search(index_files_cache_path, sentence, static_cast<uint32_t>(topk), u8_alived_bitmap, false, statistics);
}

rust::cxxbridge1::Vec<RowIdWithScore> TantivyIndexStore::bm25SearchWithFilter(
String sentence, Statistics & statistics, size_t topk, const std::vector<uint8_t> & u8_alived_bitmap)
{
if (!index_reader_status)
getTantivyIndexReader();

return ffi_bm25_search(index_files_cache_path, sentence, static_cast<uint32_t>(topk), u8_alived_bitmap, true, statistics);
}

rust::cxxbridge1::Vec<DocWithFreq> TantivyIndexStore::getDocFreq(String sentence)
{
if (!index_reader_status)
getTantivyIndexReader();
return ffi_get_doc_freq(index_files_cache_path, sentence);
}

rust::cxxbridge1::Vec<RowIdWithScore>
TantivyIndexStore::bm25SearchWithFilter(String sentence, size_t topk, const std::vector<uint8_t> & u8_alived_bitmap)
UInt64 TantivyIndexStore::getTotalNumDocs()
{
if (!index_reader_status)
getTantivyIndexReader();
return ffi_get_total_num_docs(index_files_cache_path);
}

return ffi_bm25_search(index_files_cache_path, sentence, static_cast<uint32_t>(topk), u8_alived_bitmap, true);
UInt64 TantivyIndexStore::getTotalNumTokens()
{
if (!index_reader_status)
getTantivyIndexReader();
return ffi_get_total_num_tokens(index_files_cache_path);
}

UInt64 TantivyIndexStore::getIndexedDocsNum()
Expand Down
14 changes: 11 additions & 3 deletions src/Storages/MergeTree/TantivyIndexStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,17 @@ class TantivyIndexStore
rust::cxxbridge1::Vec<std::uint8_t> termsQueryBitmap(String column_name, std::vector<String> terms);

/// For BM25Search and HybridSearch
rust::cxxbridge1::Vec<RowIdWithScore> bm25Search(String sentence, size_t topk);

rust::cxxbridge1::Vec<RowIdWithScore> bm25SearchWithFilter(String sentence, size_t topk, const std::vector<uint8_t> & u8_alived_bitmap);
/// New version
rust::cxxbridge1::Vec<RowIdWithScore> bm25Search(String sentence, Statistics & statistics, size_t topk);
rust::cxxbridge1::Vec<RowIdWithScore>
bm25SearchWithFilter(String sentence, Statistics & statistics, size_t topk, const std::vector<uint8_t> & u8_alived_bitmap);

/// Get current part sentence doc_freq, sentence will be tokenized by tokenizer with each indexed column.
rust::cxxbridge1::Vec<DocWithFreq> getDocFreq(String sentence);
/// Get current part total_num_docs, each column will have same total_num_docs.
UInt64 getTotalNumDocs();
/// Get current part total_num_tokens, each column will have it's own total_num_tokens.
UInt64 getTotalNumTokens();

/// Get the number of documents stored in the index file.
UInt64 getIndexedDocsNum();
Expand Down
9 changes: 8 additions & 1 deletion src/VectorIndex/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,11 @@ target_link_libraries(clickhouse_vector_index
dbms
PRIVATE
ch_contrib::search_index
)
)

if (TARGET ch_rust::supercrate)
target_link_libraries(clickhouse_vector_index
PRIVATE
ch_rust::supercrate
)
endif()
67 changes: 67 additions & 0 deletions src/VectorIndex/Common/BM25InfoInDataParts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include <VectorIndex/Common/BM25InfoInDataParts.h>

#include <Common/logger_useful.h>

namespace DB
{

#if USE_TANTIVY_SEARCH
UInt64 BM25InfoInDataPart::getTotalDocsCount() const
{
return total_docs;
}

UInt64 BM25InfoInDataPart::getTotalNumTokens() const
{
return total_num_tokens;
}

const RustVecDocWithFreq & BM25InfoInDataPart::getTermWithDocNums() const
{
return term_with_doc_nums;
}


UInt64 BM25InfoInDataParts::getTotalDocsCountAllParts() const
{
UInt64 result = 0;
for (const auto & part : *this)
result += part.getTotalDocsCount();
return result;
}

UInt64 BM25InfoInDataParts::getTotalNumTokensAllParts() const
{
UInt64 result = 0;
for (const auto & part : *this)
result += part.getTotalNumTokens();
return result;
}

RustVecDocWithFreq BM25InfoInDataParts::getTermWithDocNumsAllParts() const
{
/// Add number of docs containing a term in all parts based on term name and column name
using FieldIdAndTokenName = std::pair<UInt32, String>;
std::map<FieldIdAndTokenName, UInt64> field_token_name_with_docs_map;
for (const auto & part : *this)
{
auto & doc_nums_in_part = part.getTermWithDocNums();

/// Loop through the vector of Vec<DocWithFreq> in a part
for (auto & field_token_doc_freq : doc_nums_in_part)
{
FieldIdAndTokenName field_token = FieldIdAndTokenName(field_token_doc_freq.field_id, field_token_doc_freq.term_str);
field_token_name_with_docs_map[field_token] += field_token_doc_freq.doc_freq;
}
}

RustVecDocWithFreq result;
result.reserve(field_token_name_with_docs_map.size());

for (const auto & [col_token, doc_freq] : field_token_name_with_docs_map)
result.push_back({col_token.second, col_token.first, doc_freq});

return result;
}
#endif
}
50 changes: 50 additions & 0 deletions src/VectorIndex/Common/BM25InfoInDataParts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include <base/types.h>
#include <vector>
#include "config.h"

#if USE_TANTIVY_SEARCH
# include <tantivy_search.h>
#endif

namespace DB
{

#if USE_TANTIVY_SEARCH

using RustVecDocWithFreq = rust::cxxbridge1::Vec<DocWithFreq>;

struct BM25InfoInDataPart
{
UInt64 total_docs; /// Total number of documents in a data part
UInt64 total_num_tokens; /// Total number of tokens from all documents in a data part
RustVecDocWithFreq term_with_doc_nums; /// vector of terms with number of documents containing it

BM25InfoInDataPart() = default;

BM25InfoInDataPart(
const UInt64 & total_docs_,
const UInt64 & total_num_tokens_,
const RustVecDocWithFreq & term_with_doc_nums_)
: total_docs{total_docs_}
, total_num_tokens{total_num_tokens_}
, term_with_doc_nums{term_with_doc_nums_}
{}

UInt64 getTotalDocsCount() const;
UInt64 getTotalNumTokens() const;
const RustVecDocWithFreq & getTermWithDocNums() const;
};

struct BM25InfoInDataParts: public std::vector<BM25InfoInDataPart>
{
using std::vector<BM25InfoInDataPart>::vector;

UInt64 getTotalDocsCountAllParts() const;
UInt64 getTotalNumTokensAllParts() const;
RustVecDocWithFreq getTermWithDocNumsAllParts() const;
};

#endif
}
Loading

0 comments on commit 40f8339

Please sign in to comment.