Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt] Add CPUCachedFeature. [2] #7537

Merged
merged 10 commits into from
Jul 19, 2024
Merged
Prev Previous commit
Next Next commit
Move these changes to another PR.
  • Loading branch information
mfbalin committed Jul 18, 2024
commit 4f2b0986082663ab87da7e264e87384ee18c1205
8 changes: 3 additions & 5 deletions python/dgl/graphbolt/impl/feature_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""CPU Feature Cache implementation wrapper for graphbolt."""
import torch

__all__ = ["CPUFeatureCache"]
__all__ = ["FeatureCache"]

caching_policies = {
"s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy,
Expand All @@ -11,7 +11,7 @@
}


class CPUFeatureCache(object):
class FeatureCache(object):
r"""High level wrapper for the CPU feature cache.

Parameters
Expand All @@ -34,12 +34,10 @@ def __init__(
self,
cache_shape,
dtype,
policy=None,
policy="sieve",
num_parts=None,
pin_memory=False,
):
if policy is None:
policy = "sieve"
assert (
policy in caching_policies
), f"{list(caching_policies.keys())} are the available caching policies."
Expand Down
2 changes: 1 addition & 1 deletion tests/python/pytorch/graphbolt/impl/test_feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_feature_cache(dtype, feature_size, num_parts, policy):
torch.get_num_threads() if num_parts is None else num_parts
)
a = torch.randint(0, 2, [1024, feature_size], dtype=dtype)
cache = gb.impl.CPUFeatureCache(
cache = gb.impl.FeatureCache(
(cache_size,) + a.shape[1:], a.dtype, policy, num_parts
)

Expand Down
Loading