rpc : resource management rework (llama/7562)
* rpc : resource management rework

* address review comments
rgerganov authored and ggerganov committed May 29, 2024
1 parent ef03773 commit 0a0070b
Showing 1 changed file with 75 additions and 58 deletions.
133 changes: 75 additions & 58 deletions src/ggml-rpc.cpp
@@ -6,6 +6,7 @@
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
sockfd_t fd;
socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() {
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
closesocket(this->fd);
#else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
}

struct ggml_backend_rpc_buffer_type_context {
std::shared_ptr<socket_t> sock;
std::string endpoint;
std::string name;
size_t alignment;
size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
struct ggml_backend_rpc_context {
std::string endpoint;
std::string name;
std::shared_ptr<socket_t> sock;
ggml_backend_buffer_type_t buft;
};

struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true;
}

static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
std::string str(endpoint);
size_t pos = str.find(':');
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
size_t pos = endpoint.find(':');
if (pos == std::string::npos) {
return false;
}
host = str.substr(0, pos);
port = std::stoi(str.substr(pos + 1));
host = endpoint.substr(0, pos);
port = std::stoi(endpoint.substr(pos + 1));
return true;
}

@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

// RPC client-side implementation

static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
static bool initialized = false;

auto it = sockets.find(endpoint);
if (it != sockets.end()) {
if (auto sock = it->second.lock()) {
return sock;
}
}
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
return nullptr;
}
#ifdef _WIN32
if (!initialized) {
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
}
initialized = true;
}
#else
UNUSED(initialized);
#endif
auto sock = socket_connect(host.c_str(), port);
if (sock == nullptr) {
return nullptr;
}
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
}

GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
return ctx->name.c_str();
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
if (remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
remote_size);
return buffer;
} else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
return buft_ctx->sock == rpc_ctx->sock;
return buft_ctx->endpoint == rpc_ctx->endpoint;
}

static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
/* .is_host = */ NULL,
};


GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;

@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {

GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
delete buft_ctx;
delete rpc_ctx->buft;
delete rpc_ctx;
delete backend;
}

GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
return ctx->buft;
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
}

GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
std::vector<uint8_t> input;
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .event_synchronize = */ NULL,
};

static std::unordered_map<std::string, ggml_backend_t> instances;

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
std::string endpoint_str(endpoint);
if (instances.find(endpoint_str) != instances.end()) {
return instances[endpoint_str];
}
#ifdef _WIN32
{
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
}
}
#endif
fprintf(stderr, "Connecting to %s\n", endpoint);
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
return nullptr;
}
auto sock = socket_connect(host.c_str(), port);
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
// NOTE: buffer types are allocated and never freed; this is by design
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
auto it = buft_map.find(endpoint);
if (it != buft_map.end()) {
return it->second;
}
auto sock = get_socket(endpoint);
if (sock == nullptr) {
return nullptr;
}
size_t alignment = get_alignment(sock);
size_t max_size = get_max_size(sock);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .sock = */ sock,
/* .name = */ "RPC" + std::to_string(sock->fd),
/* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]",
/* .alignment = */ alignment,
/* .max_size = */ max_size
/* .max_size = */ max_size
};

ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
/* .context = */ buft_ctx
};
buft_map[endpoint] = buft;
return buft;
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC" + std::to_string(sock->fd),
/* .sock = */ sock,
/* .buft = */ buft
/* .endpoint = */ endpoint,
/* .name = */ "RPC",
};

instances[endpoint] = new ggml_backend {
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .interface = */ ggml_backend_rpc_interface,
/* .context = */ ctx
};

return instances[endpoint];
return backend;
}

GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
}

GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
if (backend == nullptr) {
auto sock = get_socket(endpoint);
if (sock == nullptr) {
*free = 0;
*total = 0;
return;
}
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
get_device_memory(ctx->sock, free, total);
get_device_memory(sock, free, total);
}

// RPC server-side implementation
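
The heart of this rework is the new get_socket() helper shown in the diff: client objects now store only the endpoint string and look up a shared socket on demand in a registry of std::weak_ptr<socket_t> keyed by endpoint. All buffer types, buffers and backends that target the same endpoint therefore share a single connection, and the socket is closed automatically once the last shared_ptr owner is destroyed. Below is a minimal, self-contained sketch of that caching pattern only; the connection type, get_connection() and the endpoint string are illustrative stand-ins, not ggml API.

#include <cstdio>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Stand-in for a real connection (socket_t in the diff): opened in the
// constructor, closed in the destructor when the last owner releases it.
struct connection {
    std::string endpoint;
    explicit connection(std::string ep) : endpoint(std::move(ep)) {
        std::printf("open %s\n", endpoint.c_str());
    }
    ~connection() {
        std::printf("close %s\n", endpoint.c_str());
    }
};

// Return a shared connection for the endpoint, reusing a live one if any.
// The registry holds weak_ptr only, so it never keeps a connection alive.
static std::shared_ptr<connection> get_connection(const std::string & endpoint) {
    static std::mutex mutex;
    static std::unordered_map<std::string, std::weak_ptr<connection>> cache;
    std::lock_guard<std::mutex> lock(mutex);
    auto it = cache.find(endpoint);
    if (it != cache.end()) {
        if (auto conn = it->second.lock()) {
            return conn; // an existing user is still holding this connection
        }
    }
    auto conn = std::make_shared<connection>(endpoint); // hypothetical "connect"
    cache[endpoint] = conn; // store a non-owning reference
    return conn;
}

int main() {
    auto a = get_connection("localhost:50052");
    auto b = get_connection("localhost:50052"); // same object as `a`
    std::printf("shared: %s\n", a.get() == b.get() ? "yes" : "no");
    a.reset();
    b.reset(); // last owner gone -> "close" printed here
    auto c = get_connection("localhost:50052"); // reconnects after expiry
    return 0;
}

Note the asymmetry in the diff itself: sockets come and go with their users via this weak_ptr scheme, while ggml_backend_rpc_buffer_type() keeps its buffer-type objects in a static map and never frees them, which the in-diff comment marks as intentional.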