Merge branch 'master' into gb_cpu_cached_feature
mfbalin authored Jul 19, 2024
2 parents 4f2b098 + 41bb3b6 commit 8749a25
Showing 3 changed files with 16 additions and 5 deletions.
4 changes: 3 additions & 1 deletion graphbolt/src/cache_policy.cc
@@ -77,7 +77,9 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
       const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
       positions_ptr[i] = pos;
       TORCH_CHECK(
-          std::get<1>(position_set.insert(pos)),
+          // If there are duplicate values and the key was just inserted,
+          // we do not have to check for the uniqueness of the positions.
+          pos_optional.has_value() || std::get<1>(position_set.insert(pos)),
           "Can't insert all, larger cache capacity is needed.");
     }
   }));
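
The short-circuit on pos_optional matters because a batch of keys can contain duplicates, and a duplicate legitimately resolves to a position that is already taken; only freshly inserted keys must land on unique positions. A minimal Python sketch of that check, with a toy DictPolicy standing in for the C++ policy internals (the lookup/insert names are hypothetical, not graphbolt API):

class DictPolicy:
    """Toy stand-in for the cache policy: first-come, first-served slots."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.table = {}

    def lookup(self, key):
        # Plays the role of pos_optional: a hit means the key already has a slot.
        return self.table.get(key)

    def insert(self, key):
        pos = len(self.table) % self.capacity
        self.table[key] = pos
        return pos

def replace(policy, keys):
    """Map keys to cache positions, mimicking ReplaceImpl's uniqueness check."""
    positions, seen = [], set()
    for key in keys:
        pos = policy.lookup(key)
        if pos is None:
            pos = policy.insert(key)
            # Only newly inserted keys are checked; a duplicate key reusing
            # an existing position is fine and skips this branch entirely.
            assert pos not in seen, (
                "Can't insert all, larger cache capacity is needed."
            )
            seen.add(pos)
        positions.append(pos)
    return positions

# replace(DictPolicy(2), [7, 7, 8]) -> [0, 0, 1]   duplicates are fine
# replace(DictPolicy(2), [7, 8, 9]) -> AssertionError: capacity exceeded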
15 changes: 12 additions & 3 deletions python/dgl/graphbolt/impl/feature_cache.py
@@ -1,7 +1,7 @@
 """CPU Feature Cache implementation wrapper for graphbolt."""
 import torch
 
-__all__ = ["FeatureCache"]
+__all__ = ["CPUFeatureCache"]
 
 caching_policies = {
     "s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy,
@@ -11,7 +11,7 @@
 }
 
 
-class FeatureCache(object):
+class CPUFeatureCache(object):
     r"""High level wrapper for the CPU feature cache.
 
     Parameters
@@ -34,15 +34,24 @@ def __init__(
         self,
         cache_shape,
         dtype,
-        policy="sieve",
+        policy=None,
         num_parts=None,
         pin_memory=False,
     ):
+        if policy is None:
+            policy = "sieve"
         assert (
             policy in caching_policies
         ), f"{list(caching_policies.keys())} are the available caching policies."
         if num_parts is None:
             num_parts = torch.get_num_threads()
+        min_num_cache_items = num_parts * (10 if policy == "s3-fifo" else 1)
+        # Since we partition the cache, each partition needs to have a positive
+        # number of slots. In addition, each "s3-fifo" partition needs at least
+        # 10 slots since the small queue is 10% and the small queue needs a
+        # positive size.
+        if cache_shape[0] < min_num_cache_items:
+            cache_shape = (min_num_cache_items,) + cache_shape[1:]
         self._policy = caching_policies[policy](cache_shape[0], num_parts)
         self._cache = torch.ops.graphbolt.feature_cache(
             cache_shape, dtype, pin_memory
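
To put the sizing guard in numbers: with 8 partitions and the "s3-fifo" policy, the minimum capacity is 8 * 10 = 80 slots, so a smaller request gets bumped up. A worked example with illustrative values (not from the commit):

num_parts = 8                 # e.g. torch.get_num_threads()
policy = "s3-fifo"
cache_shape = (50, 128)       # 50 requested slots, feature dimension 128

min_num_cache_items = num_parts * (10 if policy == "s3-fifo" else 1)  # 80
if cache_shape[0] < min_num_cache_items:
    cache_shape = (min_num_cache_items,) + cache_shape[1:]

print(cache_shape)  # (80, 128): 10 slots per partition, so each partition's
                    # 10% small queue gets exactly one slot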
2 changes: 1 addition & 1 deletion tests/python/pytorch/graphbolt/impl/test_feature_cache.py
@@ -29,7 +29,7 @@ def test_feature_cache(dtype, feature_size, num_parts, policy):
         torch.get_num_threads() if num_parts is None else num_parts
     )
     a = torch.randint(0, 2, [1024, feature_size], dtype=dtype)
-    cache = gb.impl.FeatureCache(
+    cache = gb.impl.CPUFeatureCache(
         (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
     )
 
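For completeness, a construction-and-use sketch for the renamed class, mirroring the test setup; the query/replace calls are an assumption about the cache API, since the rest of the test body is not shown in this diff:

import torch
import dgl.graphbolt as gb

features = torch.rand(1024, 64)
# A capacity below num_parts * 10 would be bumped up under "s3-fifo",
# per the __init__ guard shown above.
cache = gb.impl.CPUFeatureCache((256,) + features.shape[1:], features.dtype)

keys = torch.arange(32)
# Assumed API: query returns the cached rows plus bookkeeping for misses;
# replace then fills the cache with the rows that were missing.
values, missing_index, missing_keys = cache.query(keys)
cache.replace(missing_keys, features[missing_keys])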
