# Source code for homura.modules.functional.knn

from __future__ import annotations

import torch

from homura.utils import is_faiss_available

if is_faiss_available():
    import faiss


def _tensor_to_ptr(input: torch.Tensor):
    assert input.is_contiguous()
    assert input.dtype in {torch.float32, torch.int64}
    if input.dtype is torch.float32:
        return faiss.cast_integer_to_float_ptr(input.storage().data_ptr() + input.storage_offset() * 4)
    else:
        return faiss.cast_integer_to_long_ptr(input.storage().data_ptr() + input.storage_offset() * 8)


@torch.no_grad()
def torch_knn(keys: torch.Tensor,
              queries: torch.Tensor,
              num_neighbors: int,
              distance: str
              ) -> tuple[torch.Tensor, torch.Tensor]:
    """ k nearest neighbor using torch. Users are recommended to use `k_nearest_neighbor` instead.

    Scores are "larger is closer" for both distances:

    * `inner_product`: rows of `keys` and `queries` are L2-normalized first, so the score is
      the cosine similarity.
    * anything else: the score is `2 * k.q - |k|^2 - |q|^2`, i.e. the negative squared
      Euclidean distance. Note that every value other than `'inner_product'` takes this path;
      the name is not validated here.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`, forwarded to `topk` over the keys
    :param distance: `inner_product` or `l2`
    :return: scores, indices, each of (num_queries, num_neighbors). Both are transposed views
        of the `topk` output. Runs under `torch.no_grad()`.
    """

    if distance == 'inner_product':
        # for JIT
        keys = torch.nn.functional.normalize(keys, p=2.0, dim=1)
        queries = torch.nn.functional.normalize(queries, p=2.0, dim=1)
        scores = keys.mm(queries.t())
    else:
        # -|k - q|^2 expanded as 2 k.q - |k|^2 - |q|^2, built in place on the fresh
        # (num_keys, num_queries) product so the inputs are never modified
        scores = keys.mm(queries.t())
        scores *= 2
        scores -= (keys.pow(2)).sum(1, keepdim=True)
        scores -= (queries.pow(2)).sum(1).unsqueeze_(0)

    # keys run along dim 0, so the best `num_neighbors` keys are picked per query column
    scores, indices = scores.topk(k=num_neighbors, dim=0, largest=True)
    scores = scores.t()
    indices = indices.t()
    return scores, indices


@torch.no_grad()
def faiss_knn(keys: torch.Tensor,
              queries: torch.Tensor,
              num_neighbors: int,
              distance: str
              ) -> tuple[torch.Tensor, torch.Tensor]:
    """ k nearest neighbor using faiss. Users are recommended to use `k_nearest_neighbor` instead.

    `keys` and `queries` are handed to faiss as raw pointers through `_tensor_to_ptr`, which
    asserts that each tensor is contiguous and of dtype float32 or int64.

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: user can use str or faiss.METRIC_*. Accepted strings are `inner_product`,
        `l2`, `l1`, `linf`, `jensen_shannon` (the historical spelling `jansen_shannon` is
        still accepted). Any non-str value is passed to faiss unchanged.
    :return: scores, indices in tensor, each of (num_queries, num_neighbors); scores are
        float32 and indices are int64, allocated with `keys.new_empty`
    :raises RuntimeError: if faiss is not available
    :raises KeyError: if `distance` is a str that is not one of the names above
    """

    if not is_faiss_available():
        raise RuntimeError("faiss_knn requires faiss-gpu")

    metric_map = {"inner_product": faiss.METRIC_INNER_PRODUCT,
                  "l2": faiss.METRIC_L2,
                  "l1": faiss.METRIC_L1,
                  "linf": faiss.METRIC_Linf,
                  # misspelt key kept so existing callers keep working
                  "jansen_shannon": faiss.METRIC_JensenShannon,
                  "jensen_shannon": faiss.METRIC_JensenShannon}

    k_ptr = _tensor_to_ptr(keys)
    q_ptr = _tensor_to_ptr(queries)

    # output buffers that faiss fills in place through the raw pointers below
    scores = keys.new_empty((queries.size(0), num_neighbors), dtype=torch.float32)
    indices = keys.new_empty((queries.size(0), num_neighbors), dtype=torch.int64)

    s_ptr = _tensor_to_ptr(scores)
    i_ptr = _tensor_to_ptr(indices)

    args = faiss.GpuDistanceParams()
    args.metric = metric_map[distance] if isinstance(distance, str) else distance
    args.k = num_neighbors
    args.dims = queries.size(1)
    args.vectors = k_ptr
    args.vectorsRowMajor = True
    args.numVectors = keys.size(0)
    args.queries = q_ptr
    args.queriesRowMajor = True
    args.numQueries = queries.size(0)
    args.outDistances = s_ptr
    args.outIndices = i_ptr
    # FAISS_RES is the module-level faiss GPU resources object
    faiss.bfKnn(FAISS_RES, args)
    return scores, indices


def k_nearest_neighbor(keys: torch.Tensor,
                       queries: torch.Tensor,
                       num_neighbors: int,
                       distance: str,
                       *,
                       backend: str = "torch"
                       ) -> tuple[torch.Tensor, torch.Tensor]:
    """ k-Nearest Neighbor search. Faiss backend requires GPU. torch backend is JITtable

    :param keys: tensor of (num_keys, dim)
    :param queries: tensor of (num_queries, dim)
    :param num_neighbors: `k`
    :param distance: name of distance (`inner_product` or `l2`). Faiss backend additionally supports `l1`, `linf`,
        `jansen_shannon`.
    :param backend: backend (`faiss` or `torch`). `torch_jit` is also accepted and is handled
        exactly like `torch`. `faiss` silently falls back to the torch implementation when
        faiss is not available.
    :return: scores, indices
    :raises AssertionError: if `backend` is unknown, if either input is not 2-dimensional, or
        if the feature dimensions of `keys` and `queries` differ
    """

    assert backend in {"faiss", "torch", "torch_jit"}
    # rank is checked before `size(1)` is touched, so a non-2D input fails this
    # assertion instead of raising IndexError from the dimension lookup
    assert keys.ndim == 2 and queries.ndim == 2
    assert keys.size(1) == queries.size(1)
    f = faiss_knn if backend == "faiss" and is_faiss_available() else torch_knn
    return f(keys, queries, num_neighbors, distance)


if is_faiss_available():
    import faiss

    # Shared faiss GPU resources object, created once at import time and used by
    # `faiss_knn` for every `faiss.bfKnn` call.
    FAISS_RES = faiss.StandardGpuResources()
    # NOTE(review): presumably makes faiss run on the default (null) CUDA stream of
    # every device so that it orders with PyTorch work -- confirm against faiss docs.
    FAISS_RES.setDefaultNullStreamAllDevices()
    # 1200 * 1024 * 1024 = 1200 MiB if the argument is in bytes -- TODO confirm unit
    FAISS_RES.setTempMemory(1200 * 1024 * 1024)