This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Allow clearing gpu cache #14252

Merged
merged 9 commits on May 25, 2019
Changes from 5 commits
6 changes: 6 additions & 0 deletions include/mxnet/c_api.h
@@ -2736,6 +2736,12 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);

/*!
* \brief Release all pooled memory from the device's storage manager
* \param dev_type device type of the device whose memory pool should be released
* \param dev_id the device id of the specific device
*/
MXNET_DLL int MXStorageReleaseAll(int dev_type, int dev_id);

/*!
* \brief Reconstruct NDArray from shared memory handle
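For reference, a minimal sketch of how this new entry point could be exercised from a standalone C++ program. It assumes libmxnet is linked and its headers are on the include path, and that dev_type 2 is the GPU device type (Context::kGPU); failures are read back through MXGetLastError() as with any other C API call.

#include <cstdio>
#include <mxnet/c_api.h>

int main() {
  // Release every block cached by the pooled storage manager of GPU 0.
  // dev_type 2 is assumed to be the GPU device type; dev_id selects the device.
  if (MXStorageReleaseAll(/*dev_type=*/2, /*dev_id=*/0) != 0) {
    std::printf("MXStorageReleaseAll failed: %s\n", MXGetLastError());
    return 1;
  }
  return 0;
}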
8 changes: 8 additions & 0 deletions include/mxnet/storage.h
@@ -95,6 +95,14 @@ class Storage {
* \param handle Handle struct.
*/
virtual void DirectFree(Handle handle) = 0;
/*!
* \brief Release all memory for the given context if using a pooled storage manager
*
* This releases all memory held by pooled storage managers such as
* GPUPooledStorageManager and GPUPooledRoundedStorageManager.
* For non-pooled storage managers this has no effect.
*/
virtual void ReleaseAll(Context ctx) = 0;
/*!
* \brief Destructor.
*/
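Inside the backend this is reached through the Storage singleton, just as the c_api.cc change below does. A minimal sketch, assuming the usual Context::GPU factory:

#include <mxnet/base.h>
#include <mxnet/storage.h>

void DropGpuCache(int dev_id) {
  // Build a GPU context for the requested device and ask its storage
  // manager to drop all cached blocks; non-pooled managers ignore this.
  mxnet::Context ctx = mxnet::Context::GPU(dev_id);
  mxnet::Storage::Get()->ReleaseAll(ctx);
}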
18 changes: 18 additions & 0 deletions python/mxnet/context.py
@@ -145,6 +145,24 @@ def default_ctx(cls, val):
cls._default_ctx.value = val
#pylint: enable=no-self-argument

def empty_cache(self):
"""Empties the memory cache for the current contexts device.

MXNet utilizes a memory pool to avoid excessive allocations.
Calling empty_cache will empty the memory pool of the context's
device. This will only free the memory of unreferenced data.

Examples
--------
>>> ctx = mx.gpu(0)
>>> arr = mx.nd.ones((200,200), ctx=ctx)
>>> del arr
>>> ctx.empty_cache() # forces release of memory allocated for arr
"""
dev_type = ctypes.c_int(self.device_typeid)
dev_id = ctypes.c_int(self.device_id)
check_call(_LIB.MXStorageReleaseAll(dev_type, dev_id))

# initialize the default context in Context
Context._default_ctx.value = Context('cpu', 0)

7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
@@ -1526,3 +1526,10 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,

API_END();
}

int MXStorageReleaseAll(int dev_type, int dev_id) {
API_BEGIN();
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
Storage::Get()->ReleaseAll(ctx);
API_END();
}
6 changes: 4 additions & 2 deletions src/storage/pooled_storage_manager.h
@@ -85,6 +85,8 @@ class GPUPooledStorageManager final : public StorageManager {
DirectFreeNoLock(handle);
}

void ReleaseAll() override;

private:
void DirectFreeNoLock(Storage::Handle handle) {
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
@@ -115,7 +117,6 @@
}

private:
void ReleaseAll();
// used memory
size_t used_memory_ = 0;
// page size
@@ -250,6 +251,8 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
DirectFreeNoLock(handle);
}

void ReleaseAll() override;

private:
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
@@ -284,7 +287,6 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
}

private:
void ReleaseAll();
// number of devices
const int NDEV = 32;
// log2 of maximum page size. 16GB
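The diff above only changes the visibility of ReleaseAll() on the two pooled managers (from a private helper to a public override); its body is not shown here. For orientation, a simplified and purely illustrative sketch of the pattern such a pooled manager follows: walk the per-size free lists, hand every cached block back to CUDA, and clear the pool. The class and member names below are hypothetical, not MXNet's.

#include <cuda_runtime.h>
#include <cstddef>
#include <unordered_map>
#include <vector>

class ToyPooledGpuManager {
 public:
  void ReleaseAll() {
    for (auto&& bucket : memory_pool_) {   // one free list per block size
      for (void* ptr : bucket.second) {
        cudaFree(ptr);                     // return the cached block to the driver
      }
    }
    memory_pool_.clear();                  // the pool no longer owns any memory
  }

 private:
  // block size in bytes -> blocks of that size currently cached for reuse
  std::unordered_map<std::size_t, std::vector<void*>> memory_pool_;
};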
12 changes: 12 additions & 0 deletions src/storage/storage.cc
@@ -26,6 +26,7 @@
#include "./pooled_storage_manager.h"
#include "./cpu_shared_storage_manager.h"
#include "./cpu_device_storage.h"
#include "./gpu_device_storage.h"
#include "./pinned_memory_storage.h"
#include "../common/lazy_alloc_array.h"
#include "../profiler/storage_profiler.h"
@@ -38,6 +39,7 @@ class StorageImpl : public Storage {
void Alloc(Handle* handle) override;
void Free(Handle handle) override;
void DirectFree(Handle handle) override;
void ReleaseAll(Context ctx) override;
void SharedIncrementRefCount(Handle handle) override;
StorageImpl() {}
virtual ~StorageImpl() = default;
@@ -160,6 +162,16 @@ void StorageImpl::DirectFree(Storage::Handle handle) {
profiler_.OnFree(handle);
}

void StorageImpl::ReleaseAll(Context ctx) {
auto&& device = storage_managers_.at(ctx.dev_type);
std::shared_ptr<storage::StorageManager> manager = device.Get(
ctx.real_dev_id(), []() {
LOG(FATAL) << "Cannot Free space to a device you have not allocated";
return nullptr;
});
manager->ReleaseAll();
}

void StorageImpl::SharedIncrementRefCount(Storage::Handle handle) {
CHECK_EQ(handle.ctx.dev_type, Context::kCPUShared);
auto&& device = storage_managers_.at(Context::kCPUShared);
8 changes: 8 additions & 0 deletions src/storage/storage_manager.h
@@ -52,6 +52,14 @@
* \param handle Handle struct.
*/
virtual void DirectFree(Storage::Handle handle) = 0;
/*!
* \brief Release all memory if using a pooled storage manager
*
* This releases all memory held by pooled storage managers such as
* GPUPooledStorageManager and GPUPooledRoundedStorageManager.
* For non-pooled storage managers this has no effect.
*/
virtual void ReleaseAll() {}
/*!
* \brief Destructor.
*/
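Because the base class supplies an empty default, only the pooled managers need to override ReleaseAll(); any manager that leaves it untouched inherits the no-op, so a release request against it is simply ignored. A toy illustration of that override pattern (all names hypothetical):

struct ToyStorageManager {
  virtual ~ToyStorageManager() = default;
  virtual void ReleaseAll() {}                 // default: nothing is cached, nothing to do
};

struct ToyNaiveManager : ToyStorageManager {}; // keeps the no-op default

struct ToyPooledManager : ToyStorageManager {
  void ReleaseAll() override { /* walk the pool and free cached blocks */ }
};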