Skip to content

Commit

Permalink
Fused sampling
Browse files Browse the repository at this point in the history
Co-authored-by: Hesham Mostafa <[email protected]>
  • Loading branch information
agrabow and hesham-mostafa committed Jun 28, 2023
1 parent d22049e commit ab09bed
Show file tree
Hide file tree
Showing 13 changed files with 1,163 additions and 94 deletions.
39 changes: 39 additions & 0 deletions benchmarks/benchmarks/api/bench_fused_sample_neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import time

import dgl
import dgl.function as fn

import numpy as np
import torch

from .. import utils


@utils.benchmark("time")
@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"])
@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"])
@utils.parametrize("format", ["csr", "csc"])
@utils.parametrize("seed_nodes_num", [200, 5000, 20000])
@utils.parametrize("fanout", [5, 20, 40])
def track_time(graph_name, format, seed_nodes_num, fanout):
device = utils.get_bench_device()
graph = utils.get_graph(graph_name, format).to(device)

edge_dir = "in" if format == "csc" else "out"
seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
seed_nodes = torch.from_numpy(seed_nodes).to(device)

# dry run
for i in range(3):
dgl.sampling.sample_neighbors(
graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True
)

# timing
with utils.Timer() as t:
for i in range(50):
dgl.sampling.sample_neighbors(
graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True
)

return t.elapsed_secs / 50
66 changes: 66 additions & 0 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,72 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples,
NDArray prob_or_mask = NDArray(), bool replace = true);

/*!
* @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a CSR matrix, with
* additional IdArray that is an extended version of CSR's index pointers.
*
* With template parameter set to True rows are also saved as new seed nodes and
* mapped
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [1, 3]
* IdArray seed_mapping = [-1, -1, -1, -1];
* std::vector<IdType> new_seed_nodes = {};
*
* std::pair<CSRMatrix, IdArray> sampled = CSRRowWiseSamplingFused<
* typename IdType, True>(
* csr, rows, seed_mapping,
* new_seed_nodes, 2,
* FloatArray(), false);
* // possible sampled csr matrix:
* // sampled.first.num_rows = 2
* // sampled.first.num_cols = 3
* // sampled.first.indptr = [0, 1, 3]
* // sampled.first.indices = [1, 2, 3]
* // sampled.first.data = [0, 1, 4]
* // sampled.second = [0, 1, 1]
* // seed_mapping = [-1, 0, -1, 1];
* // new_seed_nodes = {1, 3};
*
* @tparam IdType Graph's index data type, can be int32_t or int64_t
* @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes
* @param mat Input CSR matrix.
* @param rows Rows to sample from.
* @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row
* from rows will be set to its position e.g. mapping[rows[i]] = i.
* @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will
* contain rows.
* @param rows Rows to sample from.
* @param num_samples Number of samples
* @param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* @param replace True if sample with replacement
* @return A CSRMatrix storing the picked row, col and data indices,
* COO version of picked rows
* @note The edges of the entire graph must be ordered by their edge types,
* rows must be unique
*/
template <typename IdType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
std::vector<IdType>& new_seed_nodes, int64_t num_samples,
NDArray prob_or_mask = NDArray(), bool replace = true);

/**
* @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently.
Expand Down
49 changes: 49 additions & 0 deletions include/dgl/sampling/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,55 @@ HeteroSubgraph SampleNeighbors(
const std::vector<FloatArray>& probability,
const std::vector<IdArray>& exclude_edges, bool replace = true);

/**
* @brief Sample from the neighbors of the given nodes and convert a graph into
* a bipartite-structured graph for message passing.
*
* Specifically, we create one node type \c ntype_l on the "left" side and
* another node type \c ntype_r on the "right" side for each node type \c ntype.
* The nodes of type \c ntype_r would contain the nodes designated by the
* caller, and node type \c ntype_l would contain the nodes that has an edge
* connecting to one of the designated nodes.
*
* The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
* When sampling with replacement, the sampled subgraph could have parallel
* edges.
*
* For sampling without replace, if fanout > the number of neighbors, all the
* neighbors will be sampled.
*
* Non-deterministic algorithm, requires nodes parameter to store unique Node
* IDs.
*
* @tparam IdType Graph's index data type, can be int32_t or int64_t
* @param hg The input graph.
* @param nodes Node IDs of each type. The vector length must be equal to the
* number of node types. Empty array is allowed.
* @param mapping External parameter that should be set to a vector of IdArrays
* filled with -1, required for mapping of nodes in returned
* graph
* @param fanouts Number of sampled neighbors for each edge type. The vector
* length should be equal to the number of edge types, or one if they all have
* the same fanout.
* @param dir Edge direction.
* @param probability A vector of 1D float arrays, indicating the transition
* probability of each edge by edge type. An empty float array assumes uniform
* transition.
* @param exclude_edges Edges IDs of each type which will be excluded during
* sampling. The vector length must be equal to the number of edges types. Empty
* array is allowed.
* @param replace If true, sample with replacement.
* @return Sampled neighborhoods as a graph. The return graph has the same
* schema as the original one.
*/
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
SampleNeighborsFused(
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,
EdgeDir dir, const std::vector<NDArray>& prob_or_mask,
const std::vector<IdArray>& exclude_edges, bool replace = true);

/**
* Select the neighbors with k-largest weights on the connecting edges for each
* given node.
Expand Down
57 changes: 42 additions & 15 deletions python/dgl/dataloading/neighbor_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data loading components for neighbor sampling"""
from .. import backend as F
from ..base import EID, NID
from ..transforms import to_block
from .base import BlockSampler
Expand Down Expand Up @@ -54,6 +55,9 @@ class NeighborSampler(BlockSampler):
output_device : device, optional
The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes.
fused : bool, default True
If True and device is CPU fused sample neighbors is invoked. This version
requires seed_nodes to be unique
Examples
--------
Expand Down Expand Up @@ -120,6 +124,7 @@ def __init__(
prefetch_labels=None,
prefetch_edge_feats=None,
output_device=None,
fused=True,
):
super().__init__(
prefetch_node_feats=prefetch_node_feats,
Expand All @@ -137,25 +142,47 @@ def __init__(
)
self.prob = prob or mask
self.replace = replace
self.fused = fused
self.mapping = {}
self.g = None

def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_nodes = seed_nodes
blocks = []
for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors(
seed_nodes,
fanout,
edge_dir=self.edge_dir,
prob=self.prob,
replace=self.replace,
output_device=self.output_device,
exclude_edges=exclude_eids,
)
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes)
block.edata[EID] = eid
seed_nodes = block.srcdata[NID]
blocks.insert(0, block)
if F.device_type(g.device) == "cpu" and self.fused:
if self.g != g:
self.mapping = {}
self.g = g
for fanout in reversed(self.fanouts):
block = g.sample_neighbors(
seed_nodes,
fanout,
edge_dir=self.edge_dir,
prob=self.prob,
replace=self.replace,
output_device=self.output_device,
fused=True,
exclude_edges=exclude_eids,
mapping=self.mapping,
)
seed_nodes = block.srcdata[NID]
blocks.insert(0, block)
else:
for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors(
seed_nodes,
fanout,
edge_dir=self.edge_dir,
prob=self.prob,
replace=self.replace,
output_device=self.output_device,
exclude_edges=exclude_eids,
)
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes)
block.edata[EID] = eid
seed_nodes = block.srcdata[NID]
blocks.insert(0, block)

return seed_nodes, output_nodes, blocks

Expand Down
Loading

0 comments on commit ab09bed

Please sign in to comment.