Discovering Data Structures:
Nearest Neighbor Search and Beyond

Omar Salemohamed
Université de Montréal, Mila
Laurent Charlin
HEC Montréal, Mila
Shivam Garg
Microsoft Research
Vatsal Sharan
University of Southern California
Gregory Valiant
Stanford University
Corresponding Author: [email protected]. Listed in alphabetical order.
Abstract

We propose a general framework for end-to-end learning of data structures. Our framework adapts to the underlying data distribution and provides fine-grained control over query and space complexity. Crucially, the data structure is learned from scratch, and does not require careful initialization or seeding with candidate data structures. We first apply this framework to the problem of nearest neighbor search. In several settings, we are able to reverse-engineer the learned data structures and query algorithms. For 1D nearest neighbor search, the model discovers optimal distribution (in)dependent algorithms such as binary search and variants of interpolation search. In higher dimensions, the model learns solutions that resemble k-d trees in some regimes, while in others, elements of locality-sensitive hashing emerge. Additionally, the model learns useful representations of high-dimensional data and exploits them to design effective data structures. We also adapt our framework to the problem of estimating frequencies over a data stream, and believe it could be a powerful discovery tool for new problems.

1 Introduction

Can deep learning models be trained to discover data structures from scratch?

There are several motivations for this question. The first is scientific. Deep learning models are increasingly performing tasks once considered exclusive to humans, from image recognition and mastering the game of Go to engaging in natural language conversations. Designing data structures and algorithms, along with solving complex math problems, are particularly challenging tasks. They require searching through a vast combinatorial space with a difficult-to-define structure. It is therefore natural to ask what it would take for deep learning models to solve such problems. There are already promising signs: these models have discovered fast matrix-multiplication algorithms (Fawzi et al., 2022), solved SAT problems (Selsam et al., 2018), and learned optimization algorithms for various learning tasks (Garg et al., 2022; Akyürek et al., 2022; Fu et al., 2023; Von Oswald et al., 2023). In this work, we investigate the problem of data structure discovery, with a focus on nearest neighbor search.

The second motivation is practical. Data structures are ubiquitous objects that enable efficient querying. Traditionally, they have been designed to be worst-case optimal and therefore agnostic to the underlying data and query distributions. However, in many applications there are patterns in these distributions that can be exploited to design more efficient data structures. This has motivated recent work on learning-augmented data structures which leverages knowledge of the data distribution to modify existing data structures with predictions (Lykouris & Vassilvitskii, 2018; Ding et al., 2020; Lin et al., 2022a; Mitzenmacher & Vassilvitskii, 2022). In much of this work, the goal of the learning algorithm is to learn distributional properties of the data, while the underlying query algorithm/data structure is hand-designed. Though this line of work clearly demonstrates the potential of leveraging distributional information, it still relies on expert knowledge to incorporate learning into such structures. In our work, we ask if it is possible to go one step further and let deep learning models discover entire data structures and query algorithms in an end-to-end manner.

1.1 Framework for data structure discovery

Figure 1: Our model has two components: 1) A data-processing network that transforms raw data into structured data, arranging it for efficient querying and generating additional statistics when given extra space (not shown in the figure). 2) A query-execution network that performs $M$ lookups into the output of the data-processing network in order to retrieve the answer to some query $q$. Each lookup $i$ is managed by a separate query model $Q^i$, which takes $q$ and the lookup history $H_i$, and outputs a one-hot lookup vector $m_i$ indicating the position to query.

Data structure problems can often be decomposed into two steps: 1) data structure construction and 2) query execution. The first step transforms a raw dataset $D$ into a structured database $\hat{D}$, while query execution performs lookups into $\hat{D}$ to retrieve the answer for some query $q$. The performance of a data structure is typically quantified in terms of two measures: space complexity (how much memory is required to store the data structure) and query complexity (how many lookups into the data structure are required to answer some query). One can typically trade off larger space complexity for smaller query complexity, and vice versa. We focus on these criteria as they are widely studied and have clear practical connections to efficiency.

To learn such data structures, we have a data-processing network which learns how to map a raw dataset to a data structure, and a query network which learns an algorithm for using the data structure to answer queries (Figure 1). In order to learn efficient data structures and query algorithms we impose constraints on the size of the data structures and on the number of lookups that the query network can make into the data structure. Crucially, we propose end-to-end training of both networks such that the learned data structure and query algorithm are optimized for one another. Moreover, in settings where it is beneficial to learn lower-dimensional representations from high-dimensional data, end-to-end training encourages the representations to capture features of the problem that the data structure can exploit.

On the one hand, learning the data-processing network and query network jointly in an end-to-end fashion seems obvious, especially given the many successes of end-to-end learning over the past decade. On the other hand, it might be hard to imagine such learning getting off the ground. For instance, if the data-processing network produces a random garbled function of the dataset, we cannot expect the query model to do anything meaningful. This is further complicated by the fact that these data structure tasks are more discrete and combinatorial in terms of how the query model accesses the data structure.

1.2 Summary of Results

We apply this framework to the problem of nearest neighbor (NN) search in both low and high dimensions. Given the extensive theoretical work on this topic, along with its widespread practical applications, NN search is an ideal starting point for understanding the landscape of end-to-end data structure discovery. Beyond NN search, we explore the problem of frequency estimation in streaming data and discuss other potential applications of this framework. Our findings are:

Sorting and searching in 1D (Section 2.2)

For 1D nearest neighbor search, the data-processing network learns to sort, while the query network simultaneously learns to search over the sorted data. When the data follows a uniform or Zipfian distribution, the query network exploits this structure to outperform binary search. On harder distributions lacking structure, the network adapts by discovering binary search, which is worst-case optimal. Importantly, the model discovers that sorting followed by the appropriate search algorithm is effective for NN search in 1D without explicit supervision for these primitives.

K-d trees in 2D (Section 2.3)

In 2D, when the data is drawn from a uniform distribution, the model discovers a data structure that outperforms k-d trees. On harder distributions, the learned data structure shows surprising resemblance to a k-d tree. This is striking as a k-d tree is a non-trivial data structure, constructed by recursively partitioning the data and finding the median along alternating dimensions.

Useful representations in high dimensions (Section 2.4)

For high-dimensional data, the model learns representations that make NN search efficient. For example, with data from a uniform distribution on a 30-dimensional hypersphere, the model partitions the space by projecting onto a pair of vectors, similar to locality-sensitive hashing. When trained on an extended 3-digit MNIST dataset, the model finds features that capture the relative ordering of the digits, sorts the images using these features, and performs a search on the sorted images—all of which is learned jointly from scratch.

Trading off space and query efficiency (Section 2.5)

An ideal data structure can use extra space to improve query efficiency by storing additional statistics. The learned model demonstrates this behavior, with performance improving monotonically as more space is provided, in both low and high dimensions. Thus, the model learns to effectively trade off space for query efficiency.

Beyond nearest neighbor search (Section 3)

We also explore the classical problem of frequency estimation, where a memory-constrained model observes a stream of items and must approximate the frequency of a query item. The learned structure exploits the underlying data distribution to outperform baselines like CountMin sketch, demonstrating the broader applicability of the framework beyond nearest neighbor search.

2 Nearest Neighbor Search

Given a dataset $D=\{x_1,\dots,x_N\}$ of $N$ points where $x_i\in\mathbb{R}^d$ and a query $q\in\mathbb{R}^d$, the nearest neighbor $y$ of $q$ is defined as $y=\arg\min_{x_i\in D}\mathrm{dist}(x_i,q)$. We mostly focus on the case where $\mathrm{dist}(\cdot)$ corresponds to the Euclidean distance. Our objective is to learn a data structure $\hat{D}$ for $D$ such that given $q$ and a budget of $M$ lookups, we can output an (approximate) nearest neighbor of $q$ by querying at most $M$ elements in $\hat{D}$. When $M\geq N$, $y$ can be trivially recovered via linear search so $\hat{D}=D$ is sufficient. Instead, we are interested in the case when $M\ll N$. (Footnote 1: E.g., in 1D, binary search requires $M=\log(N)$ lookups given a sorted list.)
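
As a point of reference, the definition above and the trivial case where the budget covers the whole dataset can be written as a linear scan. This is a minimal NumPy sketch for illustration; the function name `nearest_neighbor` is ours and does not come from the paper.

```python
import numpy as np

def nearest_neighbor(D, q):
    """Exact nearest neighbor of q in D by linear scan (uses N lookups)."""
    D = np.asarray(D, dtype=float).reshape(len(D), -1)   # shape (N, d)
    q = np.asarray(q, dtype=float).reshape(-1)           # shape (d,)
    dists = np.linalg.norm(D - q, axis=1)                # Euclidean distances
    idx = int(np.argmin(dists))
    return idx, D[idx]

D = np.array([[0.0, 0.0], [1.0, 1.0], [-0.5, 0.25]])
idx, y = nearest_neighbor(D, [0.9, 0.8])
```

The interesting regime is the one where only a small number of the rows of the stored structure may be read, so the arrangement of the rows has to carry the information that the scan obtains by brute force.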

2.1 Setup

Data-processing Network

Recall that the role of the data-processing network is to transform a raw dataset into a data structure. The backbone of our data-processing network is an 8-layer transformer model based on the NanoGPT architecture (Karpathy, 2024). In the case of NN search, we want the data structure to preserve the original inputs and just reorder them appropriately, as the answer to the nearest neighbor query should be one of the elements in the dataset. The model achieves this by outputting a rank associated with each element in the dataset, which is then used to reorder the elements. More precisely, the transformer takes as input the dataset $D$ and outputs a scalar $o_i\in\mathbb{R}$ representing the rank for each point $x_i\in D$. These rankings $\{o_1,\dots,o_N\}$ are then sorted using a differentiable sort function, $\mathrm{sort}(\{o_1,o_2,\dots,o_N\})$ (Grover et al., 2019; Cuturi et al., 2019; Petersen et al., 2022), which produces a permutation matrix $P$ that encodes the order based on the rankings. By applying $P$ to the input dataset $D$, we obtain $\hat{D}_P$, where the input data points are arranged in order of their rankings. By learning to rank rather than directly outputting the transformed dataset, the transformer avoids the need to reproduce the exact inputs.
Note that this division into a ranking model followed by sorting is without loss of generality as the overall model can represent any arbitrary ordering of the inputs.
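
To make the rank-then-sort step concrete, here is a small NumPy sketch of a relaxed permutation matrix in the style of the NeuralSort relaxation of Grover et al. (2019), one of the cited options. Which relaxation and temperature the authors actually use is not stated in this section, so the function `soft_permutation`, the temperature `tau`, and the hand-picked scores standing in for the transformer output are all assumptions made for illustration.

```python
import numpy as np

def soft_permutation(scores, tau=1.0):
    """NeuralSort-style relaxation: row i is a softmax that concentrates on
    the index of the i-th largest score as tau -> 0."""
    s = np.asarray(scores, dtype=float)
    n = s.shape[0]
    # A1[j] = sum_k |s_j - s_k|
    A1 = np.abs(s[:, None] - s[None, :]).sum(axis=1)
    # Row i (1-indexed) uses the coefficient n + 1 - 2i.
    coef = n + 1 - 2 * np.arange(1, n + 1)
    logits = (coef[:, None] * s[None, :] - A1[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)      # numerical stability
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)

D = np.array([[0.3], [-0.7], [0.9], [0.1]])   # raw 1D dataset, N = 4
o = np.array([0.3, -0.7, 0.9, 0.1])           # stand-in for the transformer's ranks
P = soft_permutation(-o, tau=0.01)            # negate scores for ascending order
D_hat_P = P @ D                               # (softly) permuted dataset
```

With a small temperature each row of `P` is close to one-hot, so `P @ D` is close to a true reordering of the inputs, while at larger temperatures the rows blend several inputs and gradients can flow back to the scores.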

We also consider scenarios where the data structure can use additional space. To support this use case, the transformer outputs $T$ extra vectors $b_1,\dots,b_T\in\mathbb{R}^d$ which can be retrieved by the query-execution network. We form the data structure $\hat{D}$ by concatenating the permuted inputs and the extra vectors: $\hat{D}=[\hat{D}_P,b_1,\dots,b_T]$.

Query Execution Network

The role of the query-execution network is to output an (approximate) nearest neighbor of some query $q$ given a budget of $M$ lookups into the data structure $\hat{D}$. The query-execution network consists of $M$ MLP query models $Q^1,\dots,Q^M$. The query models do not share weights. Each query model $Q^i$ outputs a one-hot vector $m_i\in\mathbb{R}^{N+T}$ which represents a lookup position in $\hat{D}$. To execute the lookup, we compute the value $v_i$ at the position denoted by $m_i$ in $\hat{D}$ as $v_i=m_i^{\top}\hat{D}$.
In addition to the query $q$, each query model $Q^i$ also takes as input the query execution history $H_i=\{(m_1,v_1),\dots,(m_{i-1},v_{i-1})\}$ where $H_1=\emptyset$. The final answer of the network for the nearest-neighbor query is given by $\hat{y}=m_M^{\top}\hat{D}$.
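
The lookup mechanics can be sketched in a few lines. In the sketch below the query models are hypothetical hand-written functions, not trained MLPs: the first always probes the middle slot and the later ones step left or right depending on the last value read. The names `run_queries`, `probe_middle`, and `make_step` are ours, and no extra vectors are used (so $T=0$).

```python
import numpy as np

def run_queries(D_hat, q, query_models):
    """Run M lookups; query model i sees q and the history H_i."""
    history = []                                # H_1 is empty
    for Q in query_models:
        logits = Q(q, history)                  # scores over the N + T slots
        m = np.zeros(D_hat.shape[0])
        m[int(np.argmax(logits))] = 1.0         # one-hot lookup vector m_i
        v = m @ D_hat                           # v_i = m_i^T D_hat
        history.append((m, v))
    y_hat = history[-1][1]                      # y_hat = m_M^T D_hat
    return y_hat, history

def probe_middle(q, history):
    # Toy first query model: ignore q and look at the middle of 7 slots.
    logits = np.zeros(7)
    logits[3] = 1.0
    return logits

def make_step(step, n_slots=7):
    # Toy adaptive query model: move right if the last value was too small.
    def Q(q, history):
        m_prev, v_prev = history[-1]
        pos = int(np.argmax(m_prev))
        pos = pos + step if v_prev[0] < q[0] else pos - step
        logits = np.zeros(n_slots)
        logits[min(max(pos, 0), n_slots - 1)] = 1.0
        return logits
    return Q

D_hat = np.array([[-0.9], [-0.6], [-0.3], [0.0], [0.3], [0.6], [0.9]])
q = np.array([0.35])
y_hat, history = run_queries(D_hat, q, [probe_middle, make_step(2), make_step(1)])
positions = [int(np.argmax(m)) for m, _ in history]
```

Because each lookup vector is one-hot, the product with the stored structure reads exactly one row, which is what ties the number of query models to the query complexity.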

To restrict our model to exactly $M$ lookups, we require each lookup vector $m_i$ to be a one-hot vector. Enforcing this constraint during training poses a challenge as it is a non-differentiable operation. Instead, during training, our model outputs soft lookups where $m_i$ is the output of the softmax function and $\sum_j m_{ij}=1$. This alone, however, leads to non-sparse queries. To address this, we add noise to the logits before the softmax operation (only during training). This regularizes the query network, encouraging it to produce sparser solutions (see App. C.1 for details on why this occurs). Intuitively, the network learns a function that is robust to noise, and the softmax output becomes robust when the logits are well-separated. Well-separated logits, in turn, lead to sparser solutions. At inference time, we do not add this noise and we ensure the lookup vector $m_i$ is a one-hot vector by applying a hardmax function to the network output instead of a softmax.
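
The difference between the training-time and inference-time lookup can be sketched as follows. The noise distribution and scale are not specified in this section, so the Gaussian noise and the `noise_std` parameter below are assumptions, as is the function name `lookup_vector`.

```python
import numpy as np

def lookup_vector(logits, training, noise_std=1.0, rng=None):
    """Soft lookup with noisy logits in training, hard one-hot at inference."""
    logits = np.asarray(logits, dtype=float)
    if training:
        rng = np.random.default_rng() if rng is None else rng
        # Assumed Gaussian noise, added to the logits before the softmax.
        noisy = logits + rng.normal(0.0, noise_std, size=logits.shape)
        z = np.exp(noisy - noisy.max())         # numerically stable softmax
        return z / z.sum()
    m = np.zeros_like(logits)
    m[int(np.argmax(logits))] = 1.0             # hardmax: exactly one lookup
    return m

rng = np.random.default_rng(0)
soft = lookup_vector([10.0, 0.0, 0.0], training=True, rng=rng)
hard = lookup_vector([10.0, 0.0, 0.0], training=False)
```

When two logits are close, a small perturbation can change which position receives most of the softmax mass, so a network that must perform well under noise is pushed toward logits with a clear winner.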

Data Generation and Training

Each training example is a tuple $(D,q,y)$ consisting of a dataset $D$, query $q$, and nearest neighbor $y$ generated as follows: (i) sample dataset $D=\{x_1,\dots,x_N\}$ from dataset distribution $P_D$, (ii) sample query $q$ from query distribution $P_q$, (iii) compute nearest neighbor $y=\arg\min_{x_i\in D}\mathrm{dist}(x_i,q)$. Unless otherwise specified, $\mathrm{dist}$ corresponds to the Euclidean distance. The dataset and query distributions $P_D,P_q$ vary across the different settings we consider and are defined later. Given a training example $(D,q,y)$, the data-processing network transforms $D$ into the data structure $\hat{D}$. Subsequently, the query-execution network, conditioned on $q$, queries the data structure to output $\hat{y}$. We use SGD to minimize either the squared loss between $y$ and $\hat{y}$, or the cross-entropy loss between the corresponding vectors encoding their positions.
This is an empirical choice, and in some settings one loss function performs better than the other. All models are trained for at most 4 million gradient steps with early stopping, using a batch size of 1024. After training, we test our model on 10k inputs $(D,q,y)$ generated in the same way. We describe the exact model architecture and training hyper-parameters in App. A.1.
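
The three generation steps can be sketched as below. The sketch assumes the uniform distribution on $(-1,1)$ for both the dataset and the query, which is only one of the settings considered later; the function names are ours.

```python
import numpy as np

def sample_example(N, d, rng):
    """Return one (D, q, y) tuple under an assumed uniform P_D and P_q."""
    D = rng.uniform(-1.0, 1.0, size=(N, d))           # (i) sample the dataset
    q = rng.uniform(-1.0, 1.0, size=d)                # (ii) sample the query
    y = D[np.argmin(np.linalg.norm(D - q, axis=1))]   # (iii) exact nearest neighbor
    return D, q, y

def squared_loss(y, y_hat):
    """Squared Euclidean error between the target and the network's answer."""
    return float(np.sum((np.asarray(y) - np.asarray(y_hat)) ** 2))

rng = np.random.default_rng(0)
D, q, y = sample_example(N=100, d=1, rng=rng)
```

The label `y` is computed by brute force at data-generation time, so the supervision only says what the answer is and never how to arrange the data or where to look.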

Evaluation and Baselines

We evaluate our end-to-end model (referred to as E2E) on one-dimensional, two-dimensional, and high-dimensional nearest-neighbor problems. We primarily focus on data structures that do not use extra space, but in Section 2.5, we also explore scenarios with additional space.

We compare against suitable NN data structures in each setting (e.g., sorting followed by binary search in 1D), and also against several ablations to study the impact of various model components. The E2E (frozen) model does not train the data-processing network, relying on rankings generated by the initial weights. The E2E (no-permute) model removes the permutation component of the data-processing network so that the transformer has to learn to transform the data points directly. The E2E (non-adaptive) ablation conditions each query model $Q^i$ on only the query $q$ and not the query history $H_i$. From the $M$ query models, we select the prediction that is closest to the query as the final prediction $\hat{y}$.

2.2 One-dimensional data

Figure 2: (Left) Our model (E2E) trained with 1D data from the uniform distribution over $(-1,1)$ outperforms binary search and several ablations. (Center) Distribution of lookups by the first query model. Unlike binary search, the model does not always start in the middle but rather closer to the query’s likely position in the sorted data. (Right) When trained on data from a “hard” distribution for which the query value does not reveal information about the query’s relative position, the model finds a solution similar to binary search. The figure shows an example of the model performing binary search (’X’ denotes the nearest neighbor location).
Uniform Distribution

We consider a setting where the data distribution $P_D$ and query distribution $P_q$ correspond to the uniform distribution over $(-1,1)$, $N=100$ and $M=7$. We plot the accuracy, which refers to zero-one loss in identifying the nearest neighbor, after each lookup in Figure 2 (Left). (Footnote 2: We include MSE plots as well in App. B.) Recall that $v_i$ corresponds to the output of the $i$-th lookup. Let $v_i^*$ denote the closest element to the query so far: $v_i^*=\arg\min_{v\in\{v_1,\dots,v_i\}}\|v-q\|_2^2$. At each lookup index we plot the nearest neighbor accuracy corresponding to $v_i^*$. We do this for all the methods.
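
The per-lookup metric can be sketched as follows for a single test example; averaging the returned flags over many examples gives the accuracy curve. The function name `best_so_far_hits` and the toy lookup values are ours.

```python
import numpy as np

def best_so_far_hits(lookups, q, y):
    """For each i, report whether v_i^* (closest lookup so far) equals y."""
    lookups = np.asarray(lookups, dtype=float).reshape(len(lookups), -1)
    q = np.asarray(q, dtype=float).reshape(-1)
    d2 = np.sum((lookups - q) ** 2, axis=1)       # squared distances to q
    hits = []
    for i in range(1, len(lookups) + 1):
        v_star = lookups[np.argmin(d2[:i])]       # closest among v_1..v_i
        hits.append(bool(np.allclose(v_star, y)))
    return hits

hits = best_so_far_hits([[0.0], [0.6], [0.3]], q=[0.35], y=[0.3])
```

Because the best element so far can only get closer to the query, a hit at lookup $i$ stays a hit at every later lookup, so the resulting accuracy curve is non-decreasing in the lookup index.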

A key component in being able to do NN search in 1D is sorting. We observe that the trained model does indeed learn to sort. We verify this by measuring the fraction of inputs that are mapped to the correct position in the sorted order, averaged over multiple datasets. The trained model correctly positions approximately 99.5% of the inputs. This is interesting as the model never received explicit feedback to sort the inputs and figured it out in the end-to-end training. The separate sorting function aids the process, but the model still has to learn to output the correct rankings.
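
The sorting check described above can be sketched as a position-wise comparison against the true sorted order. The sketch assumes ascending order and represents the learned arrangement as an index array `order`; both are assumptions for illustration, and the function name is ours.

```python
import numpy as np

def fraction_correctly_placed(D, order):
    """Fraction of positions whose stored value matches the sorted dataset.

    order[j] is the index into D of the element stored at position j.
    """
    D = np.asarray(D, dtype=float)
    arranged = D[np.asarray(order)]
    return float(np.mean(arranged == np.sort(D)))

D = np.array([0.3, -0.7, 0.9, 0.1])
perfect = fraction_correctly_placed(D, [1, 3, 0, 2])   # fully sorted
swapped = fraction_correctly_placed(D, [1, 0, 3, 2])   # two elements swapped
```

Averaging this quantity over many sampled datasets gives a figure comparable to the percentages reported in this section.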

The second key component is the ability to search over the sorted inputs. Here, our model learns a search algorithm that outperforms binary search, which is designed for the worst case. This is because unlike binary search, our model exploits knowledge of the data distribution to start its search closer to the nearest neighbor, similar to interpolation search (Peterson, 1957). For instance, if the query $q\approx 1$, the model begins its search near the end of the list (Figure 2 (Center)). The minor sorting error ($\sim 0.5\%$) our model makes likely explains its worse performance on the final query.
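
The intuition behind starting closer to the answer can be illustrated with a distribution-aware first probe. This is a hand-written sketch of the interpolation idea under a uniform prior on $(-1,1)$, not the learned query model; the function name `first_probe` is ours.

```python
def first_probe(q, N, low=-1.0, high=1.0):
    """Guess the position of q in a sorted list of N uniform samples."""
    frac = (q - low) / (high - low)          # expected fraction of points below q
    pos = round(frac * (N - 1))              # scale to an index in 0..N-1
    return min(max(pos, 0), N - 1)           # clamp queries outside the range

start_near_end = first_probe(1.0, 100)
start_near_front = first_probe(-1.0, 100)
start_for_09 = first_probe(0.9, 100)
```

For $N=100$ a binary search would always begin at index 49, whereas this probe begins at index 94 for $q=0.9$, which is already close to where the nearest neighbor is expected to sit.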

Figure 3: For 1D Zipfian query distribution, our model performs slightly better than the learning-augmented treap algorithm from Hsu et al. (2019) and both methods significantly outperform binary search.

To understand the relevance of different model components, we compare against various ablations. The E2E (frozen) model (untrained transformer) positions only about 9% of inputs correctly, explaining its under-performance. This shows that the transformer must learn to rank the inputs, and that merely using a separate function for sorting the transformer output is insufficient. The E2E (non-adaptive) baseline, lacking query history access, underperforms as it fails to learn adaptive solutions crucial for 1D NN search. The E2E (no-permute) ablation does not fully retain inputs and so we do not measure accuracy for this baseline. We verify this by measuring the average minimum distance between each of the transformer’s inputs and its outputs. These ablations highlight the crucial role of both learned orderings and query adaptivity for our model.

Zipfian Distribution

Prior work has shown that several real-world query distributions follow a Zipfian trend whereby a few elements are queried far more frequently than others, leading to the development of learning-augmented algorithms aimed at exploiting this (Lin et al., 2022b). We consider a setting where $P_D$ is the discrete uniform distribution over $\{1,\dots,200\}$ and $P_q$ is a Zipfian distribution over $\{1,\dots,200\}$ skewed towards smaller numbers such that the number $i$ is sampled with probability proportional to $\frac{1}{i^{\alpha}}$. We set $\alpha=1.2$. Again, in this setting $N=100$ and $M=7$.
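
The query distribution in this setting can be sketched directly from its definition. Sampling the dataset i.i.d. (with replacement) is an assumption, since the text does not say whether dataset elements are distinct; the function name `zipf_probs` is ours.

```python
import numpy as np

def zipf_probs(n=200, alpha=1.2):
    """Probabilities over {1, ..., n} with p(i) proportional to 1 / i**alpha."""
    w = 1.0 / np.arange(1, n + 1) ** alpha
    return w / w.sum()

rng = np.random.default_rng(0)
p = zipf_probs()
D = rng.integers(1, 201, size=100)          # P_D: uniform over {1..200}, N = 100
q = rng.choice(np.arange(1, 201), p=p)      # P_q: Zipfian, skewed to small values
```

With $\alpha=1.2$ the value 1 is about $2^{1.2}\approx 2.3$ times more likely than the value 2, so most queries land near the front of the sorted list and a search that starts there needs fewer lookups on average.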

In Figure 3 we compare our model to both binary search and the learning-augmented treap from Lin et al. (2022a). Our model performs slightly better than the learning-augmented treap and both algorithms significantly outperform binary search with fewer than $\log(N)$ queries. This setting highlights a crucial difference in spirit between our work and much of the existing work on learning-augmented algorithms. While the Zipfian treap incorporates learning in the algorithm, the authors still had to figure out how an existing data structure could be modified to support learning. On the other hand, by learning end-to-end, our framework altogether removes the need for a human in the loop. This is promising as it could be useful in settings where we lack insight on appropriate data structures. The flip side, however, is that learning-augmented data structures usually come with provable guarantees which are difficult to get when training models in an end-to-end fashion.

Hard Distribution

To verify that our model can also learn worst-case optimal algorithms such as binary search, we set $P_D$ to be a “hard” distribution, with the property that for any given query there does not exist a strong prior over the position of its nearest neighbor in the sorted data (see App. B.1 for more details). To produce a problem instance, we first sample a dataset from $P_D$. We then generate the query by sampling a point (uniformly at random) from this dataset, and adding standard Gaussian noise to it. The hard distribution generates numbers at several scales, and this makes it challenging to train the model with larger $N$. Thus, we use $N=15$ and $M=3$. In general, we find that training models is easier when there is more structure in the distribution to be exploited.

The model does indeed discover a search algorithm similar to binary search. In Figure 2 (Right), we show a representative example of the model’s search behavior, resembling binary search (see Figure 16 for more examples). The error curve in Figure 14 also closely matches that of binary search.
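
For comparison, a binary search restricted to a fixed number of lookups can be sketched as below. How the baseline in the experiments handles the budget is not described in this section, so returning the closest element seen so far is an assumption, and the function name is ours.

```python
import numpy as np

def budgeted_binary_search(sorted_D, q, M):
    """Binary search over a sorted 1D array using at most M lookups.

    Returns the index of the closest element seen and the positions probed.
    """
    lo, hi = 0, len(sorted_D) - 1
    best, best_dist, lookups = None, float("inf"), []
    for _ in range(M):
        if lo > hi:
            break
        mid = (lo + hi) // 2
        v = sorted_D[mid]                       # one lookup
        lookups.append(mid)
        if abs(v - q) < best_dist:
            best, best_dist = mid, abs(v - q)
        if v < q:
            lo = mid + 1                        # continue in the right half
        elif v > q:
            hi = mid - 1                        # continue in the left half
        else:
            break                               # exact match
    return best, lookups

sorted_D = np.arange(15.0)                      # N = 15 sorted values 0..14
best3, path3 = budgeted_binary_search(sorted_D, 10.2, M=3)
best4, path4 = budgeted_binary_search(sorted_D, 10.2, M=4)
```

In this toy run the three-lookup search probes positions 7, 11, and 9 and stops one step short of the true nearest neighbor at position 10, which a fourth lookup reaches.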

In summary, in all the above settings, starting from scratch, the data-processing network discovers that the optimal way to arrange the data is in sorted order. Simultaneously, the query-execution network learns to efficiently query this sorted data, leveraging the properties of the data distribution.

2.3 Two-dimensional data

Figure 4: Our model’s learned data structure for an instance from the uniform distribution in 2D. While the original order of the stored points showed no structure, the learned data structure arranges points that are close together in the Euclidean plane next to each other.
Figure 5: The learned data structure resembles a k-d tree in 2D. We show the average pairwise distances (along the first, second, and both dimensions) between points for the learned structure and the k-d tree, with darker colors indicating smaller distances. For the k-d tree, we arrange the points by in-order traversal. It recursively splits the points into two groups based on whether their value is smaller or larger than the median along a given dimension, alternating between dimensions at each level, starting with dimension 1. The learned data structure approximately mirrors this pattern, splitting by dimension 2 followed by dimension 1.

Beyond one dimension it is less clear how to optimally represent a collection of points as there is no canonical notion of sorting along multiple dimensions. In fact, we observe in these experiments that different data/query distributions lead to altogether different data structures. This reinforces the value in learning both the data structure and query algorithm together end-to-end.

Uniform Distribution

We use a setup similar to 1D, sampling both coordinates independently from the uniform distribution on $(-1,1)$. We set $N=100$ and $M=6$, and compare to a k-d tree baseline. A k-d tree is a binary tree for organizing points in k-dimensional space, with each node splitting the space along one of the k axes, cycling through the axes at each tree level. Here, our E2E model achieves an accuracy of 75% vs 52% for the k-d tree (Fig. 10 in App. B). The model outperforms the k-d tree as it can exploit distributional information. By studying the permutations, we find that our model learns to put points that are close together in the 2D plane next to each other in the permuted order (see Fig. 4 for an example).
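
A median-split k-d tree laid out as a flat array by in-order traversal can be sketched as follows. This is a generic textbook construction written for illustration, not the authors' baseline code; the function name `kd_inorder` and the tie-breaking by stable sort are our choices.

```python
import numpy as np

def kd_inorder(points, idx=None, depth=0):
    """Indices of points in in-order traversal of a median-split k-d tree."""
    points = np.asarray(points, dtype=float)
    if idx is None:
        idx = np.arange(len(points))
    if len(idx) == 0:
        return []
    axis = depth % points.shape[1]                      # cycle through the axes
    idx = idx[np.argsort(points[idx, axis], kind="stable")]
    mid = len(idx) // 2                                 # median along this axis
    left = kd_inorder(points, idx[:mid], depth + 1)
    right = kd_inorder(points, idx[mid + 1:], depth + 1)
    return left + [int(idx[mid])] + right

pts = [[0.5, 0.1], [-0.5, 0.9], [0.0, -0.3], [0.8, 0.7], [-0.9, -0.8]]
order = kd_inorder(pts)
order_1d = kd_inorder([[3.0], [1.0], [2.0]])
```

With a single dimension the layout reduces to the sorted order, which is one way to see the k-d tree as a multi-dimensional analogue of sort-then-binary-search.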

Hard Distribution

We also consider the case where we sample both coordinates independently from the hard distribution considered in the 1D setup (see Figure 17 for the corresponding error curve). We observe that the data structure learned by our model is surprisingly similar to a k-d tree (see Fig. 5). This is striking as a k-d tree is a non-trivial data structure: it requires recursively partitioning the data and finding the median along alternating dimensions at each level of the tree.

2.4 High-dimensional data

Figure 6: (Left) For NN search in higher dimensions (d = 30), the trained models perform comparably to (E2E) or better than (E2E (non-adaptive)) locality-sensitive hashing (LSH) baselines. (Center) When trained with a single query, the model partitions the query space based on projection onto two vectors, similar to LSH. We show the query projection onto the subspace spanned by these vectors and the lookup positions for different queries. (Right) When trained end-to-end to do nearest neighbor search over 3-Digit MNIST Images, our model learns 1D features that capture the relative ordering of the numbers in the images.

High-dimensional NN search poses a challenge for traditional low-dimensional algorithms due to the curse of dimensionality. K-d trees, for instance, can require an exponential number of queries in high dimensions (Kleinberg, 1997). This has led to the development of approximate NN search methods such as locality sensitive hashing (LSH) which have a milder dependence on $d$ (Andoni et al., 2018), relying on hash functions that map closer points in the space to the same hash bucket.

We train our model on datasets uniformly sampled from the $d$-dimensional unit hypersphere. The query is sampled to have a fixed inner-product $\rho\in[0,1]$ with a dataset point. When $\rho=1$, the query matches a data point, making regular hashing-based methods sufficient. For $\rho<1$, LSH-based solutions are competitive. We train our model for $\rho=0.8$ and compare it to an LSH baseline when $N=100$, $M=6$, and $d=30$. The LSH baseline partitions $\mathbb{R}^{30}$ using $K$ random vectors and buckets each point in the dataset according to its signed projection onto each of the $K$ vectors. To retrieve the nearest neighbor of a query point, the baseline maps the query vector to its corresponding bucket and selects the closest vector among $M$ candidates (see App. D for more details).
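
The data generation and the sign-projection LSH baseline described above can be sketched roughly as follows. This is an illustrative sketch only: the helper names, the use of Gaussian random vectors, and the rule of taking the first $M$ points stored in a bucket as candidates are assumptions, since the precise baseline is specified in App. D rather than here.

```python
import numpy as np

def sample_sphere(rng, n, d):
    """Sample n points uniformly from the unit hypersphere in R^d."""
    x = rng.standard_normal((n, d))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def sample_query(rng, x, rho):
    """Return a unit vector whose inner product with the unit vector x is rho."""
    g = rng.standard_normal(x.shape)
    g -= (g @ x) * x                      # remove the component along x
    g /= np.linalg.norm(g)
    return rho * x + np.sqrt(1.0 - rho**2) * g

def build_buckets(data, planes):
    """Bucket each data point by the sign pattern of its K projections."""
    keys = data @ planes.T > 0
    buckets = {}
    for i, key in enumerate(keys):
        buckets.setdefault(key.tobytes(), []).append(i)
    return buckets

def lsh_query(q, data, planes, buckets, m):
    """Look at up to m candidates in the query's bucket and return the closest index."""
    key = (planes @ q > 0).tobytes()
    cand = buckets.get(key, [])[:m]       # assumed: first m stored points are the candidates
    if not cand:
        return None                       # empty bucket: no prediction
    dists = np.linalg.norm(data[cand] - q, axis=1)
    return cand[int(np.argmin(dists))]
```

A larger number of random vectors $K$ gives finer buckets (fewer candidates per bucket) at the cost of a higher chance that the query and its neighbor land in different buckets.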

In Figure 6 (Left), we observe that our model performs competitively with LSH baselines (we exclude LSH baselines with larger $K$ as they under-perform). The non-adaptive model does slightly better as adaptivity is not needed to do well here (e.g., LSH is non-adaptive), and lack of adaptivity likely makes training easier. To better understand the data structure our model learns, we consider a smaller setting where $N=8$ and $M=1$. We find that the model learns an LSH-like solution, partitioning the space by projecting onto two vectors in $\mathbb{R}^{30}$ (see Figure 6 (Center)). We provide more details in App. C.3.

Learning useful representations

High-dimensional data often contains low-dimensional structure, such as data lying on a manifold, which can be leveraged to improve the efficiency of NN search. ML models are particularly well-suited to exploit these structures. Here, we explore whether our end-to-end learning framework can learn representations that capture such structures. This is a challenging task as it involves jointly optimizing the learned representation, data structure, and query algorithm.

We consider the following task: given a dataset of distinct 3-digit handwritten number images, and a query image, find its nearest neighbor in the dataset, which corresponds to the image encoding the closest number to the query image (i.e., nearest is defined over the label space).

We generate images of 3-digit numbers by concatenating digits from MNIST (see Figure 13 for image samples). To construct a nearest-neighbor dataset $D$, we sample $N=50$ labels (each label corresponds to a number) uniformly from 0 to 199. For each label, we then sample one of its associated training images from 3-digit MNIST. Additionally, we sample a query label (uniformly over $\{0,..,199\}$) and its corresponding training image and find its nearest neighbor in $D$, which corresponds to the image with the closest label. We emphasize that the model has no label supervision but rather only has access to the query’s nearest neighbor. After training, we evaluate the model using the same data generation process but with images sampled from the 3-digit MNIST test set.

As both the data-processing and query-execution networks should operate over the same low-dimensional representation, we train a CNN feature model $F_{\phi}$ as well. Our setup remains the same as before, except now the data-processing network and query-execution network operate on $\{F_{\phi}(x_1),...,F_{\phi}(x_N)\}$ and $F_{\phi}(q)$, respectively. As the underlying distance metric does not correspond to the Euclidean distance, we minimize the cross-entropy loss instead of the MSE loss. Note that the cross-entropy loss only requires supervision about the nearest neighbor of the query, and does not require the exact metric structure, so it can be used even where the exact metric structure is unknown.
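
The difference between the two losses can be illustrated with a small sketch. It assumes, purely for illustration, that the model emits one score per stored item and that the target is the index of the true nearest neighbor; the actual output parameterization of the model is not spelled out in this section, and the function names are hypothetical.

```python
import numpy as np

def cross_entropy(logits, target_index):
    """Negative log-probability of the true nearest-neighbor index under a softmax."""
    shifted = logits - np.max(logits)                  # subtract max for numerical stability
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    return float(-log_probs[target_index])

def mse(prediction, target):
    """Mean squared error between a predicted point and the true nearest neighbor."""
    return float(np.mean((prediction - target) ** 2))
```

The cross-entropy variant only needs to know which stored item is the correct answer, whereas the MSE variant needs the items to live in a space where squared Euclidean distance is meaningful.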

Figure 7: Trained on 3-digit MNIST images, our data-processing model learns to sort the images without explicit supervision for sorting. While we train our model with datasets of size $N=50$, we show a smaller instance with 5 images for better visualization.

Ideally, the feature model $F$ should learn 1D features encoding the relative ordering of the numbers, the data model should sort them, and the query model should do some form of interpolation search, using the fact that the data distribution is uniform to do better than binary search. This is almost exactly what the three models learn to do, from scratch, in an end-to-end fashion, without any explicit supervision about which image encodes which number. In Figure 6 (Right) we plot the learned features of the model. We find that the data model learns to sort the features (Figure 7) with ~98% accuracy and the query model finds the nearest neighbor with almost 100% accuracy (Figure 12).

2.5 Leveraging Extra Space

Figure 8: For NN search in 1D the model learns to use extra space and outperforms a bucketing baseline.

The previous experiments demonstrate our model’s ability to learn useful orderings for efficient querying. However, data structures can also store additional pre-computed information to speed up querying. For instance, with infinite extra space, a data structure could store the nearest neighbor for every possible query, enabling $O(1)$ search. Here, we evaluate if our model can effectively use extra space.

We run an experiment where the data and query distribution are uniform over $(-1,1)$ with $N=50$, $M=2$. We allow the data-processing network to output $T\in\{0,2^{1},2^{2},2^{3},2^{4},2^{5},2^{6},2^{7}\}$ numbers $b_1,...,b_T\in\mathbb{R}$ in addition to the $N$ rankings. We plot the NN accuracy as a function of $T$ in Figure 8 compared to a simple bucketing baseline. This baseline partitions $[-1,1]$ into $T$ evenly-sized buckets and in each bucket stores $\operatorname*{arg\,min}_{x_j\in D}||x_j-l_i||$ where $l_i$ is the midpoint of the segment corresponding to bucket $i$. The baseline maps a query to its corresponding bucket and predicts the input stored in that bucket as the nearest neighbor.
Our model’s accuracy monotonically increases with extra space, demonstrating that the data-processing network learns to pre-compute useful statistics that enable more efficient querying. We provide some insights into the learned solution in App. C.4 and show that our model can be trained to use extra space in the high-dimensional case as well (App. C.5).
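
The bucketing baseline described above admits a short sketch. The function names are hypothetical, the table stores indices into the dataset rather than the points themselves, and the sketch assumes $T\geq 1$ and queries inside $[-1,1]$.

```python
import numpy as np

def build_bucket_table(data, t, lo=-1.0, hi=1.0):
    """For each of t equal-width buckets, store the index of the data point
    closest to the bucket's midpoint l_i."""
    edges = np.linspace(lo, hi, t + 1)
    mids = 0.5 * (edges[:-1] + edges[1:])
    return np.abs(data[None, :] - mids[:, None]).argmin(axis=1)

def bucket_query(q, table, lo=-1.0, hi=1.0):
    """Map the query to its bucket and return the stored index (a single lookup)."""
    t = len(table)
    b = min(int((q - lo) / (hi - lo) * t), t - 1)   # clamp q == hi into the last bucket
    return int(table[b])
```

With few buckets, many queries share one stored answer, so accuracy is low; as $T$ grows, each bucket covers a narrower range and the stored point is more often the true nearest neighbor.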

3 Beyond Nearest Neighbor Search

Many other data structure problems beyond nearest neighbor search can be modeled by our framework. Here, we illustrate this broader applicability by applying the framework to the classical problem of frequency estimation: a memory-constrained model observes a stream of elements, and is subsequently asked to approximate the number of times a query element has appeared (Cormode & Muthukrishnan, 2005; Cormode & Hadjieleftheriou, 2010). In Section 3.2 we describe several other data structure problems that the framework can be applied to.

3.1 Frequency Estimation

Given a sequence of $T$ elements $e^{(1)},...,e^{(T)}$ drawn from some universe, the task is to estimate the frequency of a query element $e_q$ up until time-step $T$. Specifically, we aim to minimize the mean absolute error between the true count and the estimated count. (We use absolute error as this is the metric commonly used in prior work (Cormode & Muthukrishnan, 2005; Cormode & Hadjieleftheriou, 2010), but our setup works for squared error as well.) As in the nearest neighbor setup, the two constraints of interest are the size of the data structure and the number of lookups for query execution. Consequently, our framework can be easily adapted to model this problem. We also choose this problem to highlight the versatility of our framework as it can be applied to streaming settings.

Data processing Network

We model the data structure as a $k$-dimensional vector $\hat{D}$ and use an MLP as the data-processing network, which is responsible for writing to $\hat{D}$. When a new element arrives in our stream, we allow the model to update $M$ values in the data structure. Specifically, when an element arrives at time-step $t$, the data-processing network outputs $M$ $k$-dimensional one-hot update position vectors $u_1,...,u_M$ and $M$ corresponding scalar update values $v_1,...,v_M$. We then apply the update, obtaining $\hat{D}_{t+1}=\hat{D}_{t}+\sum_{i=1}^{M}u_i * v_i$. Unlike in the NN setting, where we did not constrain the construction complexity of the data structure, here we have limited each update to the data structure to a budget of $M$ lookups. We do so because in streaming settings updates typically occur often, so it is less reasonable to consider them as a one-time construction overhead cost.
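
The update rule $\hat{D}_{t+1}=\hat{D}_{t}+\sum_{i=1}^{M}u_i * v_i$ can be written out directly. In this sketch the positions and values are passed in as plain arrays, whereas in the model they are produced by the data-processing MLP; the function name `apply_update` is hypothetical.

```python
import numpy as np

def apply_update(d_hat, positions, values):
    """Add each scalar value v_i to the memory cell selected by its one-hot vector u_i."""
    k = d_hat.shape[0]
    one_hot = np.eye(k)[positions]                  # shape (M, k): rows are u_1, ..., u_M
    return d_hat + (one_hot * values[:, None]).sum(axis=0)

# Example with k = 4 memory cells and M = 2 writes.
d0 = np.zeros(4)
d1 = apply_update(d0, np.array([1, 3]), np.array([0.5, 2.0]))
```

If two of the $M$ writes select the same position, their values simply add, which follows from the sum in the update rule.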

Query processing Network

Query processing is handled in a similar fashion to NN search: we have $M$ query MLP models that output lookup positions. Finally, we also train an MLP predictor network $\psi(v_1,...,v_M)$ that takes in the $M$ values retrieved from the lookups and outputs the final prediction.

Experiments

Figure 9: When estimating frequencies of elements drawn from a randomly ordered Zipfian distribution, our model outperforms the CountMinSketch baseline given 1, 2, and 4 queries.

Zipfian Distribution

We evaluate our model in a setting where both the stream and query distributions follow a Zipfian distribution. This simulates a common feature of frequency-estimation datasets where a few “heavy hitter” elements are updated or queried more frequently than others (Hsu et al., 2019). For any fixed training instance, the rank order of the elements in the domain is consistent across both the stream and query distributions, but it is randomized across different training instances. As a result, the model cannot rely on knowing which specific elements are more frequent than others; only the overall Zipfian skew is consistent across training instances.
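
A training instance of this kind can be sketched as follows: the Zipfian shape is fixed, but the assignment of elements to ranks is re-drawn per instance and shared by the stream and the query. The function names and the exact sampling procedure are assumptions for illustration, not the authors' data pipeline.

```python
import numpy as np

def zipf_probs(n, alpha):
    """Probability of the element at rank r (1-indexed) is proportional to r^(-alpha)."""
    w = 1.0 / np.arange(1, n + 1) ** alpha
    return w / w.sum()

def sample_instance(rng, n, alpha, stream_len):
    """Draw one instance: a random rank order, a stream, and a query element."""
    perm = rng.permutation(n) + 1           # perm[r] is the element that holds rank r
    p = zipf_probs(n, alpha)
    stream = perm[rng.choice(n, size=stream_len, p=p)]
    query = perm[rng.choice(n, p=p)]        # query uses the same rank order as the stream
    return stream, query
```

Because `perm` changes between instances, a model trained on such data cannot memorize which elements are heavy hitters and can only exploit the skew itself.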

We use a data structure of size $k=32$ and train our model with $M\in\{1,2,4\}$ queries. Both the data and query distributions are Zipfian over $\{1,...,1000\}$ with a fixed skew of $\alpha=1.2$. We evaluate the mean absolute error over streams of length 100 and compare with the CountMinSketch algorithm, a hashing-based method for frequency estimation (Cormode & Muthukrishnan, 2005) (see App. E for an overview). Our model’s performance improves with more queries and outperforms CountMinSketch (Figure 9). In this case, CountMinSketch degrades with more queries because, for a fixed-size memory ($k=32$), it is more effective for this distribution to apply a single hash function over the whole memory than to split the memory into $M$ partitions of size $k/M$ and use separate hash functions. We look at the learned algorithm in more detail and find that our model learns an algorithm similar to CountMinSketch, but with an important difference: it uses an update delta of less than 1 when a new item arrives, instead of the delta of 1 used by CountMinSketch. We find that this can be particularly useful when the size of the data structure is small and collisions are frequent. We hypothesize that the better performance of the learned solution is at least partially due to the smaller delta. In Figure 22, we show that we are able to recover the performance of our learned data structure with 1 and 2 queries when we use a smaller delta.
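
For reference, a CountMinSketch with a configurable update delta can be sketched as below. The memory of size $k$ is split into $M$ rows of width $k/M$, each with its own hash function. The hash family, the class and parameter names, and the choice to return the raw minimum when delta is below 1 are assumptions for illustration; the baseline actually used is summarized in App. E.

```python
import numpy as np

class CountMinSketch:
    def __init__(self, k, m, delta=1.0, seed=0):
        assert k % m == 0, "memory size k must split evenly into m rows"
        self.width = k // m
        self.delta = delta                       # delta = 1.0 is the standard algorithm
        self.p = 2_147_483_647                   # prime modulus (2^31 - 1)
        rng = np.random.default_rng(seed)
        self.a = rng.integers(1, self.p, size=m) # one (a, b) hash pair per row
        self.b = rng.integers(0, self.p, size=m)
        self.rows = np.arange(m)
        self.table = np.zeros((m, self.width))

    def _cols(self, x):
        """Column hit by integer element x in each row."""
        return (self.a * x + self.b) % self.p % self.width

    def update(self, x):
        self.table[self.rows, self._cols(x)] += self.delta

    def estimate(self, x):
        """Minimum over the m cells that x hashes to."""
        return float(self.table[self.rows, self._cols(x)].min())
```

With `delta=1.0` the estimate never falls below the true count, since collisions can only add to a cell; a smaller delta trades this guarantee for less inflation from collisions, which is the effect discussed above for small memories.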

Learning Heavy Hitter Features

In the previous experiment, the Zipfian distribution shape was fixed across training instances but the rank ordering of elements was random. In some settings, however, it may be possible to predict which elements are more likely to occur in the stream. While the exact elements may vary between streams, frequently occurring items could share features across streams. For instance, Hsu et al. (2019) show that in frequency estimation for network flows, certain types of IP addresses receive much more traffic than others. We simulate such a setting by fixing the rank ordering of the Zipfian distribution. However, instead of using a universe of integer elements $\{1,...,K\}$, we use their corresponding 3-digit MNIST images with $K=100$ (constructed as in the MNIST NN experiment). Given a stream of integers, we map them to their corresponding MNIST labels and then for each label we sample a corresponding image from the training set. During evaluation, we use images sampled from the test set. As the distribution is skewed and the ranking is fixed, images with smaller numbers are sampled much more frequently than those with larger numbers. As in the MNIST NN experiment, we also use a feature-learning CNN model to process the images before passing them to the data-processing and query-execution networks.

We compare our model to CountMinSketch with 1 query that is given the underlying labels instead of the images. Our model has a significantly lower error than the baseline (0.15 vs 2.81 averaged over a stream of size 100; see Fig. 23) as the latter is distribution-independent. By training on the data distribution end-to-end, our framework is able to simultaneously learn features of heavy hitters (in this case, clustering images with the same label) and use this information to design an efficient frequency estimation data structure. We investigate the learned structure and find that the model has reserved separate memory positions for heavy hitters, thereby preventing collisions (Fig. 24).

3.2 Other potential applications

Here, we outline several other potential applications of our framework to facilitate future work.

Graph data structures: Many graph-related problems require an efficient representation to support connectivity or distance queries between vertices. For distance queries, one approach is to use quadratic space to store the distances between all vertex pairs, allowing $O(1)$ query time. Alternatively, one could use no extra space and simply store the graph (which may require significantly less than quadratic space) and run a shortest-path algorithm at query time. The challenge is to find a middle ground: using sub-quadratic space while still answering distance queries faster than a full shortest-path computation (Thorup & Zwick, 2005).

Sparse matrices: Another common problem that can be framed as a data structure problem is that of compressing sparse matrices. Given an $M\times N$ matrix, on one hand, one could store the full matrix and access elements in $O(1)$ time. However, depending on the number and distribution of 0s in the matrix, different data structures could be designed that use less than $O(MN)$ space. There is an inherent trade-off between how compressed the representation is and the time required to access elements of the matrix to solve various linear algebraic tasks involving the matrix such as matrix-vector multiplication (Buluç et al., 2011; Chakraborty et al., 2018).

Learning statistical models: Our framework can also handle problems such as learning statistical models like decision trees, where the input to the data-processing network is a training dataset, and the output is a model such as a decision tree. The query algorithm would then access a subset of the model at inference time, such as by doing a traversal on the nodes of the decision tree. This could be used to explore questions around optimal algorithms and heuristics for learning decision trees, which are not properly understood (Blanc et al., 2021; 2022).

4 Related Work

Learning-Augmented Algorithms

Recent work has shown that traditional data structures and algorithms can be made more efficient by learning properties of the underlying data distribution. Kraska et al. (2018) introduced the concept of learned index structures which use ML models to replace traditional index structures in databases, resulting in significant performance improvements for certain query workloads. By learning the cumulative distribution function of the data distribution the model has a stronger prior over where to start the search for a record, which can lead to provable improvements to the query time over non-learned structures (Zeighami & Shahabi, 2023). Other works augment the data structure with predictions instead of the query algorithm. For example, Lin et al. (2022a) use learned frequency estimation oracles to estimate the priority in which elements should be stored in a treap. Perhaps more relevant to the theme of our work is Dong et al. (2019), which trains neural networks to learn a partitioning of the space for efficient nearest neighbor search using locality sensitive hashing, and the body of work on learned hash functions (Wang et al., 2015; Sabek et al., 2022). While all these works focus on augmenting data structure design with learning, we explore whether data structures can be discovered entirely end-to-end using deep learning.

Neural Algorithmic Learners

There is a significant body of work on encoding algorithms into deep networks. Graves et al. (2014) introduced the Neural Turing Machine (NTM), which uses external memory to learn tasks like sorting and copying. Veličković et al. (2019) used graph neural networks (GNNs) to encode classical algorithms such as breadth-first search. These works train deep networks with a great degree of supervision with the aim of encoding known algorithms. For instance, Graves et al. (2014) use the ground truth sorted list as supervision to train the model to sort. There has also been work on learning algorithms in an end-to-end fashion. Fawzi et al. (2022) train a model using reinforcement learning to discover matrix multiplication algorithms, while Selsam et al. (2018) train neural networks to solve SAT problems. Garg et al. (2022) show that transformers can be trained to encode learning algorithms for function classes such as linear functions and decision trees. Our work adds to this line of research on end-to-end learning, focusing on discovering data structures.

5 Conclusion

We began with the question of whether deep learning models can be trained to discover data structures from scratch. This work provides initial evidence that it is possible. For both nearest neighbor search and frequency estimation, the models—trained end-to-end—discover distribution-dependent data structures that outperform worst-case baselines. We hope this research inspires further exploration into data structure and algorithm discovery.

One limitation that future research could address is scale. Due to computational constraints, most of our experiments are conducted with datasets of size $N=100$, although in App. F we scale some to $N=500$. While this scale can be sufficient for gaining insights into data structure design, practical end-to-end use would require further scaling. We believe both larger models and better inductive biases could enable scaling up further (see App. F for details).

6 Acknowledgements

OS and LC acknowledge the support of the CIFAR AI Chair program. This research was also enabled in part by compute resources provided by Mila – Quebec AI Institute (mila.quebec). GV is supported by a Simons Foundation Investigator Award, NSF award AF-2341890 and UT Austin’s Foundations of ML NSF AI Institute. VS was supported by NSF CAREER Award CCF-2239265 and an Amazon Research Award.

References

  • Akyürek et al. (2022) Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? investigations with linear models. arXiv preprint arXiv:2211.15661, 2022.
  • Andoni et al. (2018) Alexandr Andoni, Piotr Indyk, and Ilya Razenshteyn. Approximate nearest neighbor search in high dimensions. In Proceedings of the International Congress of Mathematicians: Rio de Janeiro 2018, pp.  3287–3318. World Scientific, 2018.
  • Arya et al. (1998) Sunil Arya, David M Mount, Nathan S Netanyahu, Ruth Silverman, and Angela Y Wu. An optimal algorithm for approximate nearest neighbor searching fixed dimensions. Journal of the ACM (JACM), 45(6):891–923, 1998.
  • Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Blanc et al. (2021) Guy Blanc, Jane Lange, Mingda Qiao, and Li-Yang Tan. Decision tree heuristics can fail, even in the smoothed setting. arXiv preprint arXiv:2107.00819, 2021.
  • Blanc et al. (2022) Guy Blanc, Jane Lange, Mingda Qiao, and Li-Yang Tan. Properly learning decision trees in almost polynomial time. Journal of the ACM, 69(6):1–19, 2022.
  • Buluç et al. (2011) Aydın Buluç, John Gilbert, and Viral B Shah. Implementing sparse matrices for graph algorithms. Graph Algorithms in the Language of Linear Algebra, pp.  287–313, 2011.
  • Chakraborty et al. (2018) Diptarka Chakraborty, Lior Kamma, and Kasper Green Larsen. Tight cell probe bounds for succinct boolean matrix-vector multiplication. In Proceedings of the 50th Annual ACM SIGACT Symposium on Theory of Computing, pp.  1297–1306, 2018.
  • Cormode & Hadjieleftheriou (2010) Graham Cormode and Marios Hadjieleftheriou. Methods for finding frequent items in data streams. The VLDB Journal, 19:3–20, 2010.
  • Cormode & Muthukrishnan (2005) Graham Cormode and Shan Muthukrishnan. An improved data stream summary: the count-min sketch and its applications. Journal of Algorithms, 55(1):58–75, 2005.
  • Cuturi et al. (2019) Marco Cuturi, Olivier Teboul, and Jean-Philippe Vert. Differentiable ranking and sorting using optimal transport. Advances in neural information processing systems, 32, 2019.
  • Ding et al. (2020) Jialin Ding, Umar Farooq Minhas, Jia Yu, Chi Wang, Jaeyoung Do, Yinan Li, Hantian Zhang, Badrish Chandramouli, Johannes Gehrke, Donald Kossmann, et al. Alex: an updatable adaptive learned index. In Proceedings of the 2020 ACM SIGMOD International Conference on Management of Data, pp.  969–984, 2020.
  • Dong et al. (2019) Yihe Dong, Piotr Indyk, Ilya Razenshteyn, and Tal Wagner. Learning space partitions for nearest neighbor search. arXiv preprint arXiv:1901.08544, 2019.
  • Fawzi et al. (2022) Alhussein Fawzi, Matej Balog, Aja Huang, Thomas Hubert, Bernardino Romera-Paredes, Mohammadamin Barekatain, Alexander Novikov, Francisco J R Ruiz, Julian Schrittwieser, Grzegorz Swirszcz, et al. Discovering faster matrix multiplication algorithms with reinforcement learning. Nature, 610(7930):47–53, 2022.
  • Fu et al. (2023) Deqing Fu, Tian-Qi Chen, Robin Jia, and Vatsal Sharan. Transformers learn higher-order optimization methods for in-context learning: A study with linear models. arXiv preprint arXiv:2310.17086, 2023.
  • Garg et al. (2022) Shivam Garg, Dimitris Tsipras, Percy S Liang, and Gregory Valiant. What can transformers learn in-context? a case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
  • Graves et al. (2014) Alex Graves, Greg Wayne, and Ivo Danihelka. Neural turing machines. arXiv preprint arXiv:1410.5401, 2014.
  • Grover et al. (2019) Aditya Grover, Eric Wang, Aaron Zweig, and Stefano Ermon. Stochastic optimization of sorting networks via continuous relaxations. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=H1eSS3CcKX.
  • Hsu et al. (2019) Chen-Yu Hsu, Piotr Indyk, Dina Katabi, and Ali Vakilian. Learning-based frequency estimation algorithms. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=r1lohoCqY7.
  • Jang et al. (2017) Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. 2017. URL https://arxiv.org/abs/1611.01144.
  • Karpathy (2024) Andrej Karpathy. nanogpt. https://github.com/karpathy/nanoGPT, 2024. Accessed: 2024-05-28.
  • Kingma & Ba (2017) Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2017.
  • Kleinberg (1997) Jon M Kleinberg. Two algorithms for nearest-neighbor search in high dimensions. In Proceedings of the twenty-ninth annual ACM symposium on Theory of computing, pp.  599–608, 1997.
  • Kraska et al. (2018) Tim Kraska, Alex Beutel, Ed H Chi, Jeffrey Dean, and Neoklis Polyzotis. The case for learned index structures. In Proceedings of the 2018 international conference on management of data, pp.  489–504, 2018.
  • Lin et al. (2022a) Honghao Lin, Tian Luo, and David Woodruff. Learning augmented binary search trees. In International Conference on Machine Learning, pp.  13431–13440. PMLR, 2022a.
  • Lin et al. (2022b) Honghao Lin, Tian Luo, and David P. Woodruff. Learning augmented binary search trees, 2022b. URL https://arxiv.org/abs/2206.12110.
  • Lykouris & Vassilvitskii (2018) Thodoris Lykouris and Sergei Vassilvitskii. Better caching with machine learned advice. 2018.
  • Mitzenmacher & Vassilvitskii (2022) Michael Mitzenmacher and Sergei Vassilvitskii. Algorithms with predictions. Communications of the ACM, 65(7):33–35, 2022.
  • Nair & Hinton (2010) Vinod Nair and Geoffrey E Hinton. Rectified linear units improve restricted boltzmann machines. In Proceedings of the 27th international conference on machine learning (ICML-10), pp.  807–814, 2010.
  • Paszke et al. (2017) Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. 2017.
  • Petersen et al. (2022) Felix Petersen, Christian Borgelt, Hilde Kuehne, and Oliver Deussen. Monotonic differentiable sorting networks, 2022.
  • Peterson (1957) W. W. Peterson. Addressing for random-access storage. IBM Journal of Research and Development, 1(2):130–146, 1957. doi: 10.1147/rd.12.0130.
  • Sabek et al. (2022) Ibrahim Sabek, Kapil Vaidya, Dominik Horn, Andreas Kipf, Michael Mitzenmacher, and Tim Kraska. Can learned models replace hash functions? Proceedings of the VLDB Endowment, 16(3), 2022.
  • Selsam et al. (2018) Daniel Selsam, Matthew Lamm, Benedikt Bünz, Percy Liang, Leonardo de Moura, and David L Dill. Learning a SAT solver from single-bit supervision. arXiv preprint arXiv:1802.03685, 2018.
  • Thorup & Zwick (2005) Mikkel Thorup and Uri Zwick. Approximate distance oracles. Journal of the ACM (JACM), 52(1):1–24, 2005.
  • Veličković et al. (2019) Petar Veličković, Rex Ying, Matilde Padovano, Raia Hadsell, and Charles Blundell. Neural execution of graph algorithms. arXiv preprint arXiv:1910.10593, 2019.
  • Von Oswald et al. (2023) Johannes Von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. In International Conference on Machine Learning, pp.  35151–35174. PMLR, 2023.
  • Wang et al. (2015) Jun Wang, Wei Liu, Sanjiv Kumar, and Shih-Fu Chang. Learning to hash for indexing big data—a survey. Proceedings of the IEEE, 104(1):34–57, 2015.
  • Zeighami & Shahabi (2023) Sepanta Zeighami and Cyrus Shahabi. On distribution dependent sub-logarithmic query time of learned indexing. In International Conference on Machine Learning, pp.  40669–40680. PMLR, 2023.

Appendix

Appendix A Training Details

A.1 Nearest Neighbors

The transformer in the data-processing network is based on the NanoGPT architecture (Karpathy, 2024) and has 8 layers with 8 heads each and an embedding size of 64. Each query model $Q^i$ is a 3-layer MLP with a hidden dimension of size 1024. Each hidden layer consists of a linear mapping followed by LayerNorm (Ba et al., 2016) and the ReLU activation function (Nair & Hinton, 2010). In all experiments we use a batch size of 1024, 1e-3 weight decay and the Adam optimizer (Kingma & Ba, 2017) with default PyTorch (Paszke et al., 2017) settings. We do a grid search over $\{0.0001, 0.00001, 0.00005\}$ to find the best learning rate for both models. All models are trained for at most 4 million gradient steps with early stopping. We apply the Gumbel Softmax (Jang et al., 2017) with a temperature of $2$ to the lookup vectors to encourage sparsity. All experiments are run on a single NVIDIA RTX8000 GPU.

A.2 Frequency Estimation

We follow the same setup as for nearest neighbor training, except that for frequency estimation the data-processing network is a 3-layer MLP with a hidden dimension of size 1024. We do a grid search over $\{0.0001, 0.00005, 0.00001\}$ to find the best learning rate for both models. Models are trained for 200k gradient steps with early stopping. All experiments are run on a single NVIDIA RTX8000 GPU.

Appendix B Additional Nearest Neighbor Performance Plots

Figure 10: 2D Uniform Accuracy.
Figure 11: Mean square error plots for (Left) 1D Uniform distribution, (Center) 2D Uniform distribution, (Right) 30D Uniform distribution over unit hyper-sphere.
Figure 12: 3-Digit MNIST Nearest Neighbors Accuracy. Even though binary search (over the underlying digits) is an unfair comparison, we include it as a reference to compare our model’s performance with.
Figure 13: Samples from 3-Digit MNIST

B.1 Hard Distribution

To generate data from the hard distribution, we first sample the element at the 50th percentile from the uniform distribution over a large range. We then sample the 25th and 75th percentile elements from a smaller range and so on. The intuition behind this distribution is to reduce concentration such that $p(NN \mid q)$ is roughly uniform, where $NN$ denotes the index of the nearest neighbor of $q$ in the sorted list.

Precisely, to sample $N$ points from the hard distribution we generate a random balanced binary tree of size $N$. All vertices are random variables of the form $\mathrm{Uniform}(0, a^{\log n - k})$ where $a$ is some constant and $k$ is the level in the tree that the vertex belongs to. If the $i$-th node in the tree is the left child of its parent, we generate the point $x_i$ as $x_i = x_{p(i)} - d_i$, where $p(i)$ denotes the parent of the $i$-th node and $d_i$ is a sample from node $i$ of the random binary tree. Similarly, if node $i$ is the right child of its parent, $x_i = x_{p(i)} + d_i$. For the root element, $x_0 = d_0$. In our experiments we set $a = 7$. The larger the value of $a$, the greater the degree of anti-concentration. We found it challenging to train models with $N > 16$, as the range of values that $x_i$ can take increases with $N$. Thus for larger $N$, the model needs to deal with numbers at several scales, making learning challenging.
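The sampling procedure above can be sketched in a few lines of Python. This is only an illustrative sketch, not the authors' code: the function name `sample_hard_distribution`, the heap-style indexing of the balanced tree, the base-2 logarithm, and the convention that the root sits at level $k = 0$ are all assumptions made here for concreteness.

```python
import math
import random


def sample_hard_distribution(n, a=7.0, seed=None):
    """Sample n points by walking a balanced binary tree in heap order.

    Assumptions (not specified in the text): heap indexing (node i has
    parent (i - 1) // 2), log base 2, and root level k = 0.
    """
    rng = random.Random(seed)
    log_n = math.log2(n)
    xs = [0.0] * n
    for i in range(n):
        k = (i + 1).bit_length() - 1            # level of node i in the tree
        d = rng.uniform(0.0, a ** (log_n - k))  # offset shrinks with depth
        if i == 0:
            xs[i] = d                           # root: x_0 = d_0
        else:
            parent = (i - 1) // 2
            if i % 2 == 1:                      # odd index: left child
                xs[i] = xs[parent] - d
            else:                               # even index: right child
                xs[i] = xs[parent] + d
    return xs


points = sample_hard_distribution(15, a=7.0, seed=0)
```

Because the offset range at the root scales as $a^{\log n}$, the spread of the sampled values grows quickly with the dataset size, which matches the multi-scale difficulty described above.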

Figure 14: Our model’s performance is closely aligned with binary search on the hard distribution in 1D. By design, this distribution does not have a useful prior our model can exploit and so it learns a binary search like solution.
Figure 15: The positional distribution per lookup in the 1D Hard experiment. Our model closely aligns with binary search, first looking at the middle element, then (approximately) either the 25th or 75th percentile elements, and so on.
Figure 16: Binary Search vs. our model on the hard distribution in 1D.
Figure 17: On the 2D hard distribution our model roughly tracks the performance of a k-d tree.

Appendix C Additional Experiment Findings

C.1 Noise Injection for Lookup Sparsity

We find that adding noise prior to applying the softmax on the lookup vector $m_i$ leads to sparser queries. We hypothesize that this is because the noise injection forces the model to learn a noise-robust solution which corresponds to a sparse solution. Consider a simplified setup in 1D where the query model is not conditioned on $q$ and is only allowed one lookup ($M = 1$) and $D$ is a sorted list of three elements: $D = [x_1, x_2, x_3]$. For a given query $q$ and its nearest neighbor $y$, the query-execution network is trying to find the optimal vector $\hat{m} \in \mathbb{R}^3$ that minimizes $\|y - m^T D\|_2^2$, where $m = \mathrm{softmax}(\hat{m} + \epsilon)$ and $\epsilon \sim$ Gumbel distribution (Jang et al., 2017). Given that $M = 1$, the model cannot always make enough queries to identify $y$ and so in the absence of noise the model may try to predict the ’middle’ element by setting $\hat{m}_1 = \hat{m}_2 = \hat{m}_3$. However, when noise is added to the logits $\hat{m}$ this solution is destabilized. Instead, in the presence of noise, the model can robustly select the middle element by making $\hat{m}_2$ much greater than $\hat{m}_1, \hat{m}_3$. We test this intuition by running this experiment for large values of $N$ and find that with noise the average gradient is much larger for $\hat{m}_{N/2}$.
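The three-element argument above can be checked numerically with a small Monte Carlo sketch. It is not the authors' experiment: the dataset $D = [-1, 0, 1]$, the assumption that the nearest neighbor $y$ is equally likely to be any of the three elements, and the helper names `gumbel_softmax` and `expected_loss` are all illustrative choices. Under Gumbel noise, equal logits give a randomly fluctuating soft lookup, whereas logits peaked on the middle element give a stable prediction, so the peaked (sparse) solution should have lower expected squared error.

```python
import math
import random


def gumbel_softmax(logits, rng, tau=1.0):
    """Return softmax((logits + Gumbel noise) / tau) as a list of weights."""
    noisy = []
    for logit in logits:
        u = max(rng.random(), 1e-300)        # guard against log(0)
        gumbel = -math.log(-math.log(u))     # Gumbel(0, 1) sample
        noisy.append((logit + gumbel) / tau)
    top = max(noisy)                         # subtract max for stability
    exps = [math.exp(v - top) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def expected_loss(logits, data, trials=20000, seed=0, tau=1.0):
    """Monte Carlo estimate of E[(y - m^T D)^2] under Gumbel noise.

    Assumption: y is drawn uniformly from the stored elements.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        y = rng.choice(data)
        m = gumbel_softmax(logits, rng, tau)
        pred = sum(w * x for w, x in zip(m, data))
        total += (y - pred) ** 2
    return total / trials


data = [-1.0, 0.0, 1.0]
dense = expected_loss([0.0, 0.0, 0.0], data)    # equal logits
sparse = expected_loss([0.0, 10.0, 0.0], data)  # peaked on the middle element
```

Without noise both logit settings predict the middle value, so they incur the same loss; the gap between `dense` and `sparse` appears only once the noise is injected, which is consistent with the hypothesis stated above.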

C.2 2D Uniform Distribution

Figure 18: k-d search vs. our model on the uniform distribution in 2D. Unlike the k-d tree, our model has a stronger prior over where to begin its search.

C.3 N=8, M=1 30D Experiment

To determine if our model has learned an LSH-like solution, we try to reverse engineer the query model in a simple setting where $N = 8$ and $M = 1$. The query-execution model is only allowed one lookup. We fit 8 one-vs-rest logistic regression classifiers using queries sampled from the query distribution and the output of the query model (lookup position) as features and labels, respectively. We then do PCA on the set of 8 classifier coefficients. We find that the top 2 principal components explain all of the variance, which suggests that the query model’s mapping can be explained by the projection onto these two components. In Figure 19 we plot the projection of queries onto these components and color them based on the position they were assigned by the query model. We do the same for inputs $x_i \in D$ and color them by the position they were permuted to. The plot on the right suggests that the data-processing network permutes the input vectors based on their projection onto these two components. This assignment is noisy because there may be multiple inputs in a dataset that map to the same bucket and because the model can only store a permutation, some buckets experience overflow. Similarly, the query model does a lookup in the position that corresponds to the query vector’s bucket. This behaviour suggests the model has learned a locality-sensitive hashing type solution!

Figure 19: (Left) Projection of queries onto top two PCA components of the decision boundaries of the query model, colored by the lookup position the query is mapped to. (Right) Projection of inputs onto the same PCA components colored by the position the data-processing model places them in. Both the data-processing and query models map similar regions to the same positions, suggesting an LSH-like bucketing solution has been learned.

C.4 1D Extra Space

Figure 20: (Top) Decision boundary of the first query model. (Bottom) The regression coefficients of the values stored in extra positions as a linear function of the (sorted) inputs.

C.4.1 Bucket Baseline

We create a simple bucket baseline that partitions $[-1, 1]$ into $T$ evenly sized buckets. In each bucket $b_i$ we store $\arg\min_{x_j \in D} \|x_j - l_i\|$, where $l_i$ is the midpoint of the segment covered by $b_i$. This baseline maps a query to its corresponding bucket and predicts the input stored in that bucket as the nearest neighbor. As $T \to \infty$ this becomes an optimal hashing-like solution.
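A minimal Python sketch of this bucket baseline follows. The function names `build_buckets` and `query_bucket` and the clamping of out-of-range queries to the first or last bucket are assumptions made here for illustration; the text does not specify how boundary queries are handled.

```python
def build_buckets(data, num_buckets, lo=-1.0, hi=1.0):
    """For each of the T buckets, store the data point nearest its midpoint."""
    width = (hi - lo) / num_buckets
    table = []
    for i in range(num_buckets):
        mid = lo + (i + 0.5) * width
        table.append(min(data, key=lambda x: abs(x - mid)))
    return table


def query_bucket(table, q, lo=-1.0, hi=1.0):
    """Map the query to its bucket and return the point stored there."""
    num_buckets = len(table)
    width = (hi - lo) / num_buckets
    idx = int((q - lo) / width)
    idx = min(max(idx, 0), num_buckets - 1)  # assumption: clamp at the edges
    return table[idx]


data = [-0.9, -0.2, 0.3, 0.8]
table = build_buckets(data, num_buckets=4)
prediction = query_bucket(table, 0.1)
```

Each query costs a single lookup regardless of the dataset size, and the prediction error is governed by the bucket width, which is why the baseline improves as the number of buckets grows.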

C.4.2 Understanding Extra Space Usage

By analyzing the lookup patterns of the first query model, we can better understand how the model uses extra space. In Figure 20 we plot the decision boundary of the first query model. The plot demonstrates that the model chunks the query space ($[-1, 1]$) into different buckets. To get a sense of what the model stores in the extra space, we regress the value stored in each extra space token $b_i$ on the sorted inputs using a linear function, and plot the coefficients for several of the extra spaces in Figure 20. For a given subset of the query range, the value stored at its corresponding extra space is approximately a weighted sum of the values stored at the indices that correspond to the percentile of that query range subset. This is useful information as it tells the model, for a given query percentile, how ’shifted’ the values stored at the corresponding indices of the current dataset are from the model’s prior.

C.5 30D Extra Space

Figure 21: Our unconstrained model (E2E) and a more interpretable version (E2E (Coefficients)) both learn to effectively leverage an increasing amount of extra space in 30D, with the unconstrained model outperforming an LSH baseline.

In high dimensions it is less clear what solutions there are to effectively leverage extra space, and in fact understanding optimal tradeoffs in this case is open theoretically (Arya et al., 1998).

We follow a similar setup to the 1D extra space experiments but use the data and query distributions from section 2.4. We experiment with two versions of extra space: unrestricted and coefficients. For the unrestricted version, the data model can store whatever 30-dimensional vector it chooses in each of the extra spaces. For the coefficient version, instead of outputting a 30-dimensional vector for each extra space, the model outputs a separate $N$-dimensional vector of coefficients. We then take a linear combination of the (permuted) input dataset using these coefficients and store the resulting vector in the corresponding extra position. While the unrestricted version is more expressive, the coefficient version is more interpretable. We include both versions to demonstrate the versatility of our framework. If one is only interested in identifying a strong lower bound on how well a fixed budget of extra space can be used, one may use the unrestricted model. However, if one is more concerned with investigating specific classes of solutions or would like a greater degree of interpretability, one can easily augment the model with additional inductive biases such as linear coefficients.

We plot the performance of both models along with an LSH baseline in Figure 21. While both models perform competitively with an LSH baseline and can effectively leverage an increasing amount of extra space, the unrestricted model outperforms the coefficient model at a certain point.

C.6 Frequency Estimation

Figure 22: We apply the insight that we learned from our E2E model to improve CountMinSketch on the Zipfian distribution. By changing the CountMinSketch update delta of 1 to our model’s learned delta ($\Delta = 0.87$ for $M = 1$ and $\Delta = 0.93$ for $M = 2$), we can improve the performance of CountMinSketch on $M \in \{1, 2\}$ queries.
Figure 23: On the MNIST heavy-hitters frequency estimation experiment, our model significantly outperforms CountMinSketch. This is because our model can learn features predictive of heavy hitters, as opposed to the distribution-agnostic CountMinSketch.
Figure 24: We show the decision boundary learned by the query/data-processing network in the MNIST heavy hitters experiment. As images with smaller numbers occur more frequently in the stream, the memory-constrained model learns to reserve separate memory positions for these items in order to prevent collisions among them.

Appendix D LSH Baseline

Our LSH baseline samples $K$ random vectors $\mathbf{r}_1, \dots, \mathbf{r}_K$ from the standard normal distribution in $\mathbb{R}^d$. For a given vector $\mathbf{v} \in \mathbb{R}^d$, its hash code is computed as $\mathrm{hash}(\mathbf{v}) = [\mathrm{sign}(\mathbf{v}^T \mathbf{r}_1), \dots, \mathrm{sign}(\mathbf{v}^T \mathbf{r}_K)]$. In total, there are $2^K$ possible hash codes. To create a hash table, we assign each hash code a bucket of size $N / 2^K$. For a given dataset $D = \{x_1, \dots, x_N\}$, we place each input in its corresponding bucket (determined by its hash code $\mathrm{hash}(x_i)$). If the bucket is full, we place $x_i$ in a vacant bucket chosen at random. Given a query $q$ and a budget of $M$ lookups, the baseline retrieves the first $M$ vectors in the bucket corresponding to $\mathrm{hash}(q)$. If there are fewer than $M$ vectors in the bucket, we choose the remaining vectors at random from other buckets. We design the setup this way to closely align with the constraints of our model (i.e. only learning a permutation).
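The baseline described above can be sketched as follows. This is a hedged illustration rather than the authors' implementation: the helper names (`make_hasher`, `build_table`, `lookup`), the treatment of a zero dot product as a positive sign, and the assumption that $2^K$ divides $N$ are choices made here.

```python
import random


def make_hasher(dim, k, rng):
    """Sample k Gaussian vectors and return a sign-based hash function."""
    planes = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(k)]

    def hash_code(v):
        code = 0
        for r in planes:
            dot = sum(a * b for a, b in zip(v, r))
            code = (code << 1) | (1 if dot >= 0 else 0)  # assumption: sign(0) = +1
        return code

    return hash_code


def build_table(data, k, hash_code, rng):
    """Place each point in its bucket; overflow goes to a random vacant bucket."""
    capacity = len(data) // (2 ** k)  # assumes 2^k divides N
    buckets = [[] for _ in range(2 ** k)]
    for x in data:
        b = hash_code(x)
        if len(buckets[b]) >= capacity:
            vacant = [j for j, bkt in enumerate(buckets) if len(bkt) < capacity]
            b = rng.choice(vacant)
        buckets[b].append(x)
    return buckets


def lookup(buckets, q, m, hash_code, rng):
    """Return the first m vectors in q's bucket, padded from other buckets."""
    b = hash_code(q)
    found = list(buckets[b][:m])
    if len(found) < m:
        others = [x for j, bkt in enumerate(buckets) if j != b for x in bkt]
        found += rng.sample(others, min(m - len(found), len(others)))
    return found


rng = random.Random(0)
dim, k, n = 4, 2, 16
hash_code = make_hasher(dim, k, rng)
data = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
buckets = build_table(data, k, hash_code, rng)
candidates = lookup(buckets, data[0], 3, hash_code, rng)
```

A nearest-neighbor prediction would then be the candidate closest to the query. Fixing every bucket at capacity $N / 2^K$ means the table is a rearrangement of the dataset, mirroring the permutation-only constraint mentioned above.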

Appendix E CountMinSketch

CountMinSketch (Cormode & Muthukrishnan, 2005) is a probabilistic data structure used for estimating the frequency of items in a data stream with sublinear space. It uses a two-dimensional array of counters and multiple independent hash functions to map each item to several buckets. When a new item $x$ arrives, the algorithm computes $d$ hash functions $h_1(x), h_2(x), \dots, h_d(x)$, each of which maps the item to one of $w$ buckets in a different row of the array. The counters in the corresponding buckets are incremented by $1$. To estimate the frequency of an item $x$, the minimum value across the counters $C[1, h_1(x)], C[2, h_2(x)], \dots, C[d, h_d(x)]$ is returned. The sketch guarantees that the estimated frequency $\hat{f}(x)$ of an item $x$ is at least its true frequency $f(x)$ and, with high probability, at most $f(x) + \epsilon N$, where $N$ is the total number of items processed, $\epsilon = e/w$, and $w$ is the width of the sketch. The probability that the estimate exceeds this bound is at most $\delta = e^{-d}$, where $d$ is the depth of the sketch (i.e., the number of hash functions). These guarantees hold even in the presence of hash collisions, providing strong worst-case accuracy with $\mathcal{O}(w \cdot d)$ space.
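The update and query rules above can be sketched in Python as follows. This is an illustrative sketch only: salting Python's built-in `hash` is a stand-in for the pairwise-independent hash families used in the formal analysis, and the `increment` argument is a hypothetical hook for non-unit updates such as the learned delta discussed in Figure 22.

```python
import random


class CountMinSketch:
    """A d x w array of counters with one salted hash function per row."""

    def __init__(self, width, depth, seed=0):
        rng = random.Random(seed)
        self.width = width
        # Assumption: salted built-in hashing stands in for a
        # pairwise-independent hash family.
        self.salts = [rng.getrandbits(64) for _ in range(depth)]
        self.counters = [[0] * width for _ in range(depth)]

    def _bucket(self, row, item):
        return hash((self.salts[row], item)) % self.width

    def update(self, item, increment=1):
        """Add `increment` to one counter in every row."""
        for row in range(len(self.salts)):
            self.counters[row][self._bucket(row, item)] += increment

    def estimate(self, item):
        """Return the minimum counter value across rows."""
        return min(
            self.counters[row][self._bucket(row, item)]
            for row in range(len(self.salts))
        )


stream = [1] * 50 + [2] * 30 + [3] * 20 + list(range(100, 150))
sketch = CountMinSketch(width=16, depth=4)
for item in stream:
    sketch.update(item)
```

With unit increments, every counter an item touches is incremented at least once per occurrence, so the minimum can overestimate but never underestimate the true count.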

Appendix F Limitations and Future Work

Figure 25: We scale both the 1D (Left) and 30D (Center) experiments to datasets of size $N = 500$. (Right) We compare our E2E model with a version where the query-execution network is only composed of one query-model (E2E (shared)) that is used in a loop for $M = 7$ queries during training on the 1D Uniform distribution, thereby conserving parameters by reusing weights. This could be a promising direction for problem settings where there is a recursive structure to the query algorithm.

One limitation of our work is the scale at which we learn data structures. Most of our nearest neighbor search experiments are done with input dataset sizes around $N = 100$. However, we are also able to scale up to $N = 500$ (Figure 25 (Left/Center)), though with fewer than $\log(N)$ queries. While we demonstrate that useful data structures can still be learned at this scale, it is possible that other classes of structures only emerge for larger datasets. We also believe that many of the insights that can be derived from our models’ learned solutions would scale to larger $N$, for instance, sorting in 1D and locality-sensitive hashing in higher dimensions. We limit ourselves to datasets of these sizes due to computational constraints, and because our primary goal was to understand whether end-to-end data structure design is feasible at any reasonable scale. However, we believe our framework could scale to datasets with thousands of points by increasing the parameter counts of the data-processing and query-execution models. Moreover, as transformers become increasingly efficient at handling larger context sizes in language modeling settings, some of these modeling advancements may also be used for scaling models in the context of data structure discovery.

Complementary to our work, it could also be valuable to explore better inductive biases for the query and data-processing networks, and other methods to ensure sparse lookups, enabling smaller models to scale to larger datasets. For instance, using shared weights among query models can be helpful in scaling up the number of queries. As a first step in this direction we show that a single query model can be used in a loop for NN search in 1D (Figure 25 (Right)). However, we leave further investigation for future work.