Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt] Add CPUCachedFeature. [2] #7537

Merged
merged 10 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .uniform_negative_sampler import *
from .gpu_graph_cache import *
from .feature_cache import *
from .cpu_cached_feature import *
119 changes: 119 additions & 0 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""CPU cached feature for GraphBolt."""

import torch

from ..feature_store import Feature

from .feature_cache import CPUFeatureCache

__all__ = ["CPUCachedFeature"]


def num_cache_items(cache_capacity_in_bytes, single_item):
mfbalin marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the number of rows to be cached."""
item_bytes = single_item.nbytes
# Round up so that we never get a size of 0, unless bytes is 0.
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes


class CPUCachedFeature(Feature):
r"""CPU cached feature wrapping a fallback feature.

Parameters
----------
fallback_feature : Feature
The fallback feature.
max_cache_size_in_bytes : int
The capacity of the cache in bytes.
policy : str
The cache eviction policy algorithm name. See gb.impl.CPUFeatureCache
for the list of available policies.
pin_memory : bool
Whether the cache storage should be allocated on system pinned memory.
Default is False.
"""

def __init__(
self,
fallback_feature: Feature,
max_cache_size_in_bytes: int,
policy: str = None,
pin_memory: bool = False,
):
super(CPUCachedFeature, self).__init__()
assert isinstance(fallback_feature, Feature), (
f"The fallback_feature must be an instance of Feature, but got "
f"{type(fallback_feature)}."
)
self._fallback_feature = fallback_feature
self.max_cache_size_in_bytes = max_cache_size_in_bytes
# Fetching the feature dimension from the underlying feature.
feat0 = fallback_feature.read(torch.tensor([0]))
cache_size = num_cache_items(max_cache_size_in_bytes, feat0)
self._feature = CPUFeatureCache(
(cache_size,) + feat0.shape[1:],
feat0.dtype,
policy=policy,
pin_memory=pin_memory,
)

def read(self, ids: torch.Tensor = None):
"""Read the feature by index.

Parameters
----------
ids : torch.Tensor, optional
The index of the feature. If specified, only the specified indices
of the feature are read. If None, the entire feature is returned.

Returns
-------
torch.Tensor
The read feature.
"""
if ids is None:
return self._fallback_feature.read()
values, missing_index, missing_keys = self._feature.query(ids)
missing_values = self._fallback_feature.read(missing_keys)
values[missing_index] = missing_values
self._feature.replace(missing_keys, missing_values)
return values

def size(self):
"""Get the size of the feature.

Returns
-------
torch.Size
The size of the feature.
"""
return self._fallback_feature.size()

def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature.

Parameters
----------
value : torch.Tensor
The updated value of the feature.
ids : torch.Tensor, optional
The indices of the feature to update. If specified, only the
specified indices of the feature will be updated. For the feature,
the `ids[i]` row is updated to `value[i]`. So the indices and value
must have the same length. If None, the entire feature will be
updated.
"""
if ids is None:
feat0 = value[:1]
self._fallback_feature.update(value)
cache_size = min(
num_cache_items(self.max_cache_size_in_bytes, feat0),
value.shape[0],
)
self._feature = None # Destroy the existing cache first.
self._feature = CPUFeatureCache(
(cache_size,) + feat0.shape[1:], feat0.dtype
)
else:
self._fallback_feature.update(value, ids)
self._feature.replace(ids, value)
Loading