Skip to content

Commit

Permalink
Merge pull request graphql-python#1 from colanconnon/gevent-websocket
Browse files Browse the repository at this point in the history
add a gevent websocket server (WIP)
  • Loading branch information
syrusakbary committed Nov 10, 2017
2 parents 3df5301 + ae60b25 commit b3e5459
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 20 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ target/
# pyenv python configuration file
.python-version
.DS_Store

.mypy_cache/
.vscode/
Empty file.
67 changes: 67 additions & 0 deletions examples/flask_gevent/flask_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import json

import graphene
from flask import Flask, make_response
from flask_graphql import GraphQLView
from flask_sockets import Sockets
from rx import Observable

from graphql_ws import GeventSubscriptionServer
from template import render_graphiql


class Query(graphene.ObjectType):
base = graphene.String()


class RandomType(graphene.ObjectType):
seconds = graphene.Int()
random_int = graphene.Int()


class Subscription(graphene.ObjectType):

count_seconds = graphene.Int(up_to=graphene.Int())

random_int = graphene.Field(RandomType)


def resolve_count_seconds(root, info, up_to):
return Observable.interval(1000)\
.map(lambda i: "{0}".format(i))\
.take_while(lambda i: int(i) <= up_to)

def resolve_random_int(root, info):
import random
return Observable.interval(1000).map(lambda i: RandomType(seconds=i, random_int=random.randint(0, 500)))

schema = graphene.Schema(query=Query, subscription=Subscription)



app = Flask(__name__)
app.debug = True
sockets = Sockets(app)


@app.route('/graphiql')
def graphql_view():
return make_response(render_graphiql())

app.add_url_rule(
'/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=False))

subscription_server = GeventSubscriptionServer(schema)
app.app_protocol = lambda environ_path_info: 'graphql-ws'

@sockets.route('/subscriptions')
def echo_socket(ws):
subscription_server.handle(ws)
return []


if __name__ == "__main__":
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
server = pywsgi.WSGIServer(('', 5000), app, handler_class=WebSocketHandler)
server.serve_forever()
125 changes: 125 additions & 0 deletions examples/flask_gevent/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@

from string import Template


def render_graphiql():
return Template('''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>GraphiQL</title>
<meta name="robots" content="noindex" />
<style>
html, body {
height: 100%;
margin: 0;
overflow: hidden;
width: 100%;
}
</style>
<link href="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.0/react.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.0/react-dom.min.js"></script>
<script src="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.min.js"></script>
<script src="//unpkg.com/subscriptions-transport-ws@${SUBSCRIPTIONS_TRANSPORT_VERSION}/browser/client.js"></script>
<script src="//unpkg.com/[email protected]/browser/client.js"></script>
</head>
<body>
<script>
// Collect the URL parameters
var parameters = {};
window.location.search.substr(1).split('&').forEach(function (entry) {
var eq = entry.indexOf('=');
if (eq >= 0) {
parameters[decodeURIComponent(entry.slice(0, eq))] =
decodeURIComponent(entry.slice(eq + 1));
}
});
// Produce a Location query string from a parameter object.
function locationQuery(params, location) {
return (location ? location: '') + '?' + Object.keys(params).map(function (key) {
return encodeURIComponent(key) + '=' +
encodeURIComponent(params[key]);
}).join('&');
}
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
var graphqlParamNames = {
query: true,
variables: true,
operationName: true
};
var otherParams = {};
for (var k in parameters) {
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
otherParams[k] = parameters[k];
}
}
var fetcher;
if (true) {
var subscriptionsClient = new window.SubscriptionsTransportWs.SubscriptionClient('${subscriptionsEndpoint}', {
reconnect: true
});
fetcher = window.GraphiQLSubscriptionsFetcher.graphQLFetcher(subscriptionsClient, graphQLFetcher);
} else {
fetcher = graphQLFetcher;
}
// We don't use safe-serialize for location, because it's not client input.
var fetchURL = locationQuery(otherParams, '${endpointURL}');
// Defines a GraphQL fetcher using the fetch API.
function graphQLFetcher(graphQLParams) {
return fetch(fetchURL, {
method: 'post',
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
},
body: JSON.stringify(graphQLParams),
credentials: 'include',
}).then(function (response) {
return response.text();
}).then(function (responseBody) {
try {
return JSON.parse(responseBody);
} catch (error) {
return responseBody;
}
});
}
// When the query and variables string is edited, update the URL bar so
// that it can be easily shared.
function onEditQuery(newQuery) {
parameters.query = newQuery;
updateURL();
}
function onEditVariables(newVariables) {
parameters.variables = newVariables;
updateURL();
}
function onEditOperationName(newOperationName) {
parameters.operationName = newOperationName;
updateURL();
}
function updateURL() {
history.replaceState(null, null, locationQuery(parameters) + window.location.hash);
}
// Render <GraphiQL /> into the body.
ReactDOM.render(
React.createElement(GraphiQL, {
fetcher: fetcher,
onEditQuery: onEditQuery,
onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName,
}),
document.body
);
</script>
</body>
</html>''').substitute(
GRAPHIQL_VERSION='0.10.2',
SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0',
subscriptionsEndpoint='ws:https://localhost:5000/subscriptions',
# subscriptionsEndpoint='ws:https://localhost:5000/',
endpointURL='/graphql',
)
1 change: 1 addition & 0 deletions examples/src/graphql
Submodule graphql added at ebcd7f
2 changes: 1 addition & 1 deletion graphql_ws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@

from .observable_aiter import setup_observable_extension
from .server import WebSocketSubscriptionServer

from .gevent_server import GeventSubscriptionServer

setup_observable_extension()
15 changes: 15 additions & 0 deletions graphql_ws/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
GRAPHQL_WS = 'graphql-ws'
WS_PROTOCOL = GRAPHQL_WS

GQL_CONNECTION_INIT = 'connection_init' # Client -> Server
GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client
GQL_CONNECTION_ERROR = 'connection_error' # Server -> Client

# NOTE: This one here don't follow the standard due to connection optimization
GQL_CONNECTION_TERMINATE = 'connection_terminate' # Client -> Server
GQL_CONNECTION_KEEP_ALIVE = 'ka' # Server -> Client
GQL_START = 'start' # Client -> Server
GQL_DATA = 'data' # Server -> Client
GQL_ERROR = 'error' # Server -> Client
GQL_COMPLETE = 'complete' # Server -> Client
GQL_STOP = 'stop' # Client -> Server
120 changes: 120 additions & 0 deletions graphql_ws/gevent_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json

from graphql import format_error, graphql
from graphql.execution.executors.sync import SyncExecutor
from rx import Observer, Observable
from .server import BaseWebSocketSubscriptionServer, ConnectionContext, ConnectionClosedException
from .constants import *


class GEventConnectionContext(ConnectionContext):

def receive(self):
msg = self.ws.receive()
return msg

def send(self, data):
if self.closed:
return
self.ws.send(data)

@property
def closed(self):
return self.ws.closed

def close(self, code):
self.ws.close(code)

class GeventSubscriptionServer(BaseWebSocketSubscriptionServer):

def get_graphql_params(self, *args, **kwargs):
params = super(GeventSubscriptionServer, self).get_graphql_params(*args, **kwargs)
return dict(params, executor=SyncExecutor())

def handle(self, ws):
connection_context = GEventConnectionContext(ws)
self.on_open(connection_context)
while True:
try:
if connection_context.closed:
raise ConnectionClosedException()
message = connection_context.receive()
except ConnectionClosedException:
self.on_close(connection_context)
return
self.on_message(connection_context, message)

def on_message(self, connection_context, message):
try:
parsed_message = json.loads(message)
assert isinstance(
parsed_message, dict), "Payload must be an object."
except Exception as e:
self.send_error(connection_context, None, e)
return

self.process_message(connection_context, parsed_message)

def on_open(self, connection_context):
pass

def on_connect(self, connection_context, payload):
pass

def on_close(self, connection_context):
remove_operations = list(connection_context.operations.keys())
for op_id in remove_operations:
self.unsubscribe(connection_context, op_id)

def on_connection_init(self, connection_context, op_id, payload):
try:
self.on_connect(connection_context, payload)
self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)

except Exception as e:
self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
connection_context.close(1011)

def on_connection_terminate(self, connection_context, op_id):
connection_context.close(1011)


def on_start(self, connection_context, op_id, params):
try:
execution_result = graphql(
self.schema, **params, allow_subscriptions=True
)
assert isinstance(
execution_result, Observable), "A subscription must return an observable"
execution_result.subscribe(SubscriptionObserver(
connection_context,
op_id,
self.send_execution_result,
self.send_error,
self.on_close
)
)
except Exception as e:
self.send_error(connection_context, op_id, str(e))

def on_stop(self, connection_context, op_id):
self.unsubscribe(connection_context, op_id)


class SubscriptionObserver(Observer):

def __init__(self, connection_context, op_id, send_execution_result, send_error, on_close):
self.connection_context = connection_context
self.op_id = op_id
self.send_execution_result = send_execution_result
self.send_error = send_error
self.on_close = on_close

def on_next(self, value):
self.send_execution_result(self.connection_context, self.op_id, value)

def on_completed(self):
self.on_close(self.connection_context)

def on_error(self, error):
self.send_error(self.connection_context, self.op_id, error)
20 changes: 4 additions & 16 deletions graphql_ws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,12 @@
from websockets.protocol import CONNECTING, OPEN
from inspect import isawaitable, isasyncgen
from graphql import graphql, format_error
from graphql.execution import ExecutionResult
from collections import OrderedDict
import json
from .constants import *


GRAPHQL_WS = 'graphql-ws'
WS_PROTOCOL = GRAPHQL_WS

GQL_CONNECTION_INIT = 'connection_init' # Client -> Server
GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client
GQL_CONNECTION_ERROR = 'connection_error' # Server -> Client

# NOTE: This one here don't follow the standard due to connection optimization
GQL_CONNECTION_TERMINATE = 'connection_terminate' # Client -> Server
GQL_CONNECTION_KEEP_ALIVE = 'ka' # Server -> Client
GQL_START = 'start' # Client -> Server
GQL_DATA = 'data' # Server -> Client
GQL_ERROR = 'error' # Server -> Client
GQL_COMPLETE = 'complete' # Server -> Client
GQL_STOP = 'stop' # Client -> Server


class ConnectionClosedException(Exception):
pass
Expand All @@ -48,6 +34,7 @@ def get_operation(self, op_id):
def remove_operation(self, op_id):
del self.operations[op_id]



class AioHTTPConnectionContext(ConnectionContext):
async def receive(self):
Expand Down Expand Up @@ -186,6 +173,7 @@ def on_operation_complete(self, connection_context, op_id):
pass



class WebSocketSubscriptionServer(BaseWebSocketSubscriptionServer):

def get_graphql_params(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit b3e5459

Please sign in to comment.