Skip to content

Commit

Permalink
ProtocolServer: Attach downloads and their lifecycles to clients
Browse files Browse the repository at this point in the history
Previously a download lived independently of the client connection it came
from. This was the source of several undesirable behaviours, including the
potential for clients to influence downloads they didn't start, and
downloads living longer than their associated client connections. Now we
attach downloads to client connections, which means they're cleaned up
automatically when the client goes away, and there's significantly less
risk of clients interfering with each other.
  • Loading branch information
deoxxa authored and awesomekling committed May 17, 2020
1 parent 184ee8a commit f2621f3
Show file tree
Hide file tree
Showing 17 changed files with 48 additions and 57 deletions.
32 changes: 5 additions & 27 deletions Services/ProtocolServer/Download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,10 @@
// FIXME: What about rollover?
static i32 s_next_id = 1;

static HashMap<i32, RefPtr<Download>>& all_downloads()
{
static HashMap<i32, RefPtr<Download>> map;
return map;
}

Download* Download::find_by_id(i32 id)
{
return const_cast<Download*>(all_downloads().get(id).value_or(nullptr));
}

Download::Download(PSClientConnection& client)
: m_id(s_next_id++)
, m_client(client.make_weak_ptr())
: m_client(client)
, m_id(s_next_id++)
{
all_downloads().set(m_id, this);
}

Download::~Download()
Expand All @@ -55,7 +43,7 @@ Download::~Download()

void Download::stop()
{
all_downloads().remove(m_id);
m_client.did_finish_download({}, *this, false);
}

void Download::set_payload(const ByteBuffer& payload)
Expand All @@ -71,22 +59,12 @@ void Download::set_response_headers(const HashMap<String, String, CaseInsensitiv

void Download::did_finish(bool success)
{
if (!m_client) {
dbg() << "Download::did_finish() after the client already disconnected.";
return;
}
m_client->did_finish_download({}, *this, success);
all_downloads().remove(m_id);
m_client.did_finish_download({}, *this, success);
}

void Download::did_progress(Optional<u32> total_size, u32 downloaded_size)
{
if (!m_client) {
// FIXME: We should also abort the download in this situation, I guess!
dbg() << "Download::did_progress() after the client already disconnected.";
return;
}
m_total_size = total_size;
m_downloaded_size = downloaded_size;
m_client->did_progress_download({}, *this);
m_client.did_progress_download({}, *this);
}
9 changes: 3 additions & 6 deletions Services/ProtocolServer/Download.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,13 @@
#include <AK/Optional.h>
#include <AK/RefCounted.h>
#include <AK/URL.h>
#include <AK/WeakPtr.h>

class PSClientConnection;

class Download : public RefCounted<Download> {
class Download {
public:
virtual ~Download();

static Download* find_by_id(i32);

i32 id() const { return m_id; }
URL url() const { return m_url; }

Expand All @@ -60,11 +57,11 @@ class Download : public RefCounted<Download> {
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);

private:
i32 m_id;
PSClientConnection& m_client;
i32 m_id { 0 };
URL m_url;
Optional<u32> m_total_size {};
size_t m_downloaded_size { 0 };
ByteBuffer m_payload;
HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers;
WeakPtr<PSClientConnection> m_client;
};
9 changes: 6 additions & 3 deletions Services/ProtocolServer/GeminiDownload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <LibGemini/GeminiJob.h>
#include <ProtocolServer/GeminiDownload.h>

GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob>&& job)
GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
: Download(client)
, m_job(job)
{
Expand All @@ -55,9 +55,12 @@ GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini:

GeminiDownload::~GeminiDownload()
{
m_job->on_finish = nullptr;
m_job->on_progress = nullptr;
m_job->shutdown();
}

NonnullRefPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob>&& job)
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
{
return adopt(*new GeminiDownload(client, move(job)));
return adopt_own(*new GeminiDownload(client, move(job)));
}
4 changes: 2 additions & 2 deletions Services/ProtocolServer/GeminiDownload.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class GeminiProtocol;
class GeminiDownload final : public Download {
public:
virtual ~GeminiDownload() override;
static NonnullRefPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>&&);
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);

private:
explicit GeminiDownload(PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>&&);
explicit GeminiDownload(PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);

NonnullRefPtr<Gemini::GeminiJob> m_job;
};
2 changes: 1 addition & 1 deletion Services/ProtocolServer/GeminiProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ GeminiProtocol::~GeminiProtocol()
{
}

RefPtr<Download> GeminiProtocol::start_download(PSClientConnection& client, const URL& url)
OwnPtr<Download> GeminiProtocol::start_download(PSClientConnection& client, const URL& url)
{
Gemini::GeminiRequest request;
request.set_url(url);
Expand Down
2 changes: 1 addition & 1 deletion Services/ProtocolServer/GeminiProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ class GeminiProtocol final : public Protocol {
GeminiProtocol();
virtual ~GeminiProtocol() override;

virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
};
9 changes: 6 additions & 3 deletions Services/ProtocolServer/HttpDownload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <LibHTTP/HttpResponse.h>
#include <ProtocolServer/HttpDownload.h>

HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob>&& job)
HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
: Download(client)
, m_job(job)
{
Expand All @@ -52,9 +52,12 @@ HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJ

HttpDownload::~HttpDownload()
{
m_job->on_finish = nullptr;
m_job->on_progress = nullptr;
m_job->shutdown();
}

NonnullRefPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob>&& job)
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
{
return adopt(*new HttpDownload(client, move(job)));
return adopt_own(*new HttpDownload(client, move(job)));
}
4 changes: 2 additions & 2 deletions Services/ProtocolServer/HttpDownload.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class HttpProtocol;
class HttpDownload final : public Download {
public:
virtual ~HttpDownload() override;
static NonnullRefPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>&&);
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>);

private:
explicit HttpDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>&&);
explicit HttpDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>);

NonnullRefPtr<HTTP::HttpJob> m_job;
};
2 changes: 1 addition & 1 deletion Services/ProtocolServer/HttpProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ HttpProtocol::~HttpProtocol()
{
}

RefPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
OwnPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
{
HTTP::HttpRequest request;
request.set_method(HTTP::HttpRequest::Method::GET);
Expand Down
2 changes: 1 addition & 1 deletion Services/ProtocolServer/HttpProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ class HttpProtocol final : public Protocol {
HttpProtocol();
virtual ~HttpProtocol() override;

virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
};
9 changes: 6 additions & 3 deletions Services/ProtocolServer/HttpsDownload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <LibHTTP/HttpsJob.h>
#include <ProtocolServer/HttpsDownload.h>

HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob>&& job)
HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
: Download(client)
, m_job(job)
{
Expand All @@ -52,9 +52,12 @@ HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::Htt

HttpsDownload::~HttpsDownload()
{
m_job->on_finish = nullptr;
m_job->on_progress = nullptr;
m_job->shutdown();
}

NonnullRefPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob>&& job)
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
{
return adopt(*new HttpsDownload(client, move(job)));
return adopt_own(*new HttpsDownload(client, move(job)));
}
4 changes: 2 additions & 2 deletions Services/ProtocolServer/HttpsDownload.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class HttpsProtocol;
class HttpsDownload final : public Download {
public:
virtual ~HttpsDownload() override;
static NonnullRefPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>&&);
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);

private:
explicit HttpsDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>&&);
explicit HttpsDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);

NonnullRefPtr<HTTP::HttpsJob> m_job;
};
2 changes: 1 addition & 1 deletion Services/ProtocolServer/HttpsProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ HttpsProtocol::~HttpsProtocol()
{
}

RefPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const URL& url)
OwnPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const URL& url)
{
HTTP::HttpRequest request;
request.set_method(HTTP::HttpRequest::Method::GET);
Expand Down
2 changes: 1 addition & 1 deletion Services/ProtocolServer/HttpsProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ class HttpsProtocol final : public Protocol {
HttpsProtocol();
virtual ~HttpsProtocol() override;

virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
};
10 changes: 8 additions & 2 deletions Services/ProtocolServer/PSClientConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> PSClientConnection::hand
if (!protocol)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
auto download = protocol->start_download(*this, url);
return make<Messages::ProtocolServer::StartDownloadResponse>(download->id());
if (!download)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
auto id = download->id();
m_downloads.set(id, move(download));
return make<Messages::ProtocolServer::StartDownloadResponse>(id);
}

OwnPtr<Messages::ProtocolServer::StopDownloadResponse> PSClientConnection::handle(const Messages::ProtocolServer::StopDownload& message)
{
auto* download = Download::find_by_id(message.download_id());
auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));
bool success = false;
if (download) {
download->stop();
Expand All @@ -93,6 +97,8 @@ void PSClientConnection::did_finish_download(Badge<Download>, Download& download
for (auto& it : download.response_headers())
response_headers.add(it.key, it.value);
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers));

m_downloads.remove(download.id());
}

void PSClientConnection::did_progress_download(Badge<Download>, Download& download)
Expand Down
1 change: 1 addition & 0 deletions Services/ProtocolServer/PSClientConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ class PSClientConnection final : public IPC::ClientConnection<ProtocolServerEndp
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;

HashMap<i32, OwnPtr<Download>> m_downloads;
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
};
2 changes: 1 addition & 1 deletion Services/ProtocolServer/Protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Protocol {
virtual ~Protocol();

const String& name() const { return m_name; }
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) = 0;

static Protocol* find_by_name(const String&);

Expand Down

0 comments on commit f2621f3

Please sign in to comment.