Skip to content

Commit

Permalink
Merge branch 'master'
Browse files Browse the repository at this point in the history
  • Loading branch information
agrabow committed Jul 19, 2023
2 parents fc383d8 + 4ceb0bf commit 2f35f23
Show file tree
Hide file tree
Showing 25 changed files with 1,023 additions and 174 deletions.
1 change: 1 addition & 0 deletions graphbolt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
cmake_minimum_required(VERSION 3.5)
project(graphbolt C CXX)
set (CMAKE_CXX_STANDARD 17)

# Find PyTorch cmake files and PyTorch versions with the python interpreter
# $PYTHON_INTERP ("python3" or "python" if empty)
Expand Down
40 changes: 37 additions & 3 deletions graphbolt/include/graphbolt/csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@
namespace graphbolt {
namespace sampling {

enum SamplerType { NEIGHBOR, LABOR };

template <SamplerType S>
struct SamplerArgs;

template <>
struct SamplerArgs<SamplerType::NEIGHBOR> {};

template <>
struct SamplerArgs<SamplerType::LABOR> {
const torch::Tensor& indices;
int64_t random_seed;
int64_t num_nodes;
};

/**
* @brief A sampling oriented csc format graph.
*
Expand Down Expand Up @@ -143,6 +158,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
Expand All @@ -156,7 +174,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const;

/**
Expand Down Expand Up @@ -204,6 +222,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
const std::string& shared_memory_name);

private:
template <SamplerType S>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<S> args) const;

/**
* @brief Build a CSCSamplingGraph from shared memory tensors.
*
Expand Down Expand Up @@ -298,10 +323,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*
* @return A tensor containing the picked neighbors.
*/
template <SamplerType S>
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);

/**
* @brief Picks a specified number of neighbors for a node per edge type,
Expand Down Expand Up @@ -330,11 +356,19 @@ torch::Tensor Pick(
*
* @return A tensor containing the picked neighbors.
*/
template <SamplerType S>
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);

template <bool NonUniform, bool Replace, typename T = float>
torch::Tensor LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args);

} // namespace sampling
} // namespace graphbolt
Expand Down
Loading

0 comments on commit 2f35f23

Please sign in to comment.