-
Notifications
You must be signed in to change notification settings - Fork 0
/
dispatch.lua
310 lines (292 loc) · 10.2 KB
/
dispatch.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
-----------------------------------------------------------------------------
-- A hacked dispatcher module
-- LuaSocket sample files
-- Author: Diego Nehab
-- RCS ID: $$
-----------------------------------------------------------------------------
local base = _G
local table = require("table")
local socket = require("socket")
local coroutine = require("coroutine")
local print = print
local debug = debug
local assert = assert
module("dispatch")
-- Idle limit for client sockets. The coroutine dispatcher's step() compares
-- (socket.gettime() - last scheduling stamp) against this value and, when it
-- is exceeded, resumes the coroutines waiting on that socket with "timeout"
-- so their wrapped I/O call returns nil, "timeout".
-- NOTE(review): presumably expressed in seconds, like socket.gettime() --
-- confirm against the LuaSocket documentation.
TIMEOUT = 60
-----------------------------------------------------------------------------
-- We describe 3 types of dispatchers:
-- sequential
-- coroutine
-- threaded
-- The user can choose whichever one is needed.
-- Note: only the sequential and coroutine handlers are defined in this file.
-----------------------------------------------------------------------------
-- registry of handler constructors, indexed by mode name
local handlert = {}
-- Creates a new handler object of the requested kind.
--   mode: name of a constructor registered in handlert; defaults to
--         "coroutine" when nil or false
-- Returns whatever the selected constructor returns.
-- Raises an error naming the offending mode when no constructor is
-- registered under that name (instead of an opaque "attempt to call a nil
-- value").
function newhandler(mode)
    mode = mode or "coroutine"
    local constructor = handlert[mode]
    if not constructor then
        -- level 2 blames the caller, who supplied the bad mode
        base.error("unknown dispatcher mode: " .. base.tostring(mode), 2)
    end
    return constructor()
end
-- start method of the sequential handler: runs func immediately, in the
-- caller's own thread, and hands back everything func returns.
--   self: the handler object (unused)
--   func: function to run, called with no arguments
local function seqstart(self, func)
    return func()
end
-- sequential handler simply calls the functions and doesn't wrap I/O
-- Builds the sequential handler: sockets are plain socket.tcp objects (no
-- I/O wrapping) and start simply calls the given function.
-- Returns a table with the fields tcp and start.
function handlert.sequential()
    local handler = {}
    handler.tcp = socket.tcp
    handler.start = seqstart
    return handler
end
-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- we can't yield across calls to protect, so we rewrite it with coxpcall
-- make sure you don't require any module that uses socket.protect before
-- loading our hack
-- Replacement for socket.protect that runs f inside its own coroutine, so
-- that yields performed inside f are forwarded to whoever resumed us.
--   f: function to protect
-- Returns a function that calls f with the arguments it receives and:
--   * returns f's results when f finishes normally;
--   * returns nil followed by the first element of the error value when f
--     raises a table (presumably the wrapped errors produced by socket.try
--     -- confirm against the LuaSocket documentation);
--   * re-raises any other error value.
-- NOTE: relies on the implicit `arg` table of vararg functions (Lua 5.0
-- style, available in Lua 5.1 through its vararg compatibility mode).
function socket.protect(f)
    return function(...)
        local co = coroutine.create(f)
        while true do
            local results = {coroutine.resume(co, base.unpack(arg))}
            local status = table.remove(results, 1)
            if not status then
                -- globals are not visible after module("dispatch"), so the
                -- type function must be reached through base (_G)
                if base.type(results[1]) == 'table' then
                    return nil, results[1][1]
                else base.error(results[1]) end
            end
            if coroutine.status(co) == "suspended" then
                -- f yielded: pass the yielded values up to our own resumer
                -- and feed whatever we are resumed with back into f
                arg = {coroutine.yield(base.unpack(results))}
            else
                -- f returned: hand its results to the caller
                return base.unpack(results)
            end
        end
    end
end
-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-----------------------------------------------------------------------------
-- Creates an empty set. The members are stored in the array part of the
-- returned table (which is what gets passed to socket.select), while a
-- private value -> position map lets insert and remove avoid scanning.
-- The returned table gets two methods through its metatable:
--   set:insert(value)  adds value unless it is already a member
--   set:remove(value)  removes value if it is a member
local function newset()
    -- position of every member inside the array part
    local position = {}
    local methods = {}

    function methods.insert(set, value)
        if position[value] then return end
        table.insert(set, value)
        position[value] = table.getn(set)
    end

    function methods.remove(set, value)
        local index = position[value]
        if not index then return end
        position[value] = nil
        -- pop the last member and, unless it is the one being removed,
        -- move it into the hole left behind
        local top = table.remove(set)
        if top ~= value then
            position[top] = index
            set[index] = top
        end
    end

    return base.setmetatable({}, {__index = methods})
end
-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-----------------------------------------------------------------------------
-- Wraps a socket object so that its blocking operations cooperate with the
-- coroutine dispatcher: instead of blocking, they yield the queue they are
-- waiting on (dispatcher.sending or dispatcher.receiving) together with the
-- socket, and are retried when the dispatcher resumes the coroutine.
--   dispatcher: dispatcher table; its sending, receiving and stamp fields
--               are used
--   tcp:        socket object to wrap, or nil
--   error:      error message returned when tcp is nil
-- Returns the wrap object, or nil followed by error when tcp is nil.
local function cowrap(dispatcher, tcp, error)
    if not tcp then return nil, error end
    -- put it in non-blocking mode right away
    tcp:settimeout(0)
    -- metatable for wrap produces new methods on demand for those that we
    -- don't override explicitly. the generated method replaces its first
    -- argument (the wrap object) by the real socket and forwards the call.
    local metat = { __index = function(table, key)
        table[key] = function(...)
            arg[1] = tcp
            return tcp[key](base.unpack(arg))
        end
        return table[key]
    end}
    -- does our user want to do his own non-blocking I/O?
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = { }
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O
    function wrap:settimeout(value, mode)
        zero = (value == 0)
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        first = (first or 1) - 1
        local result, error
        while true do
            -- return control to dispatcher and tell it we want to send
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending, resuming right after what has already been sent
            result, error, first = tcp:send(data, first+1, last)
            -- if we are done, or there was an unexpected error,
            -- break away from loop
            if error ~= "timeout" then return result, error, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local error = "timeout"
        local value
        while true do
            -- return control to dispatcher and tell it we want to receive
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving
            value, error, partial = tcp:receive(pattern, partial)
            -- if we are done, or there was an unexpected error,
            -- break away from loop. also, if the user requested
            -- zero timeout, return all we got
            if (error ~= "timeout") or zero then
                return value, error, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, error = tcp:connect(host, port)
        if error == "timeout" then
            -- return control to dispatcher. we will be writable when
            -- connection succeeds.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- when we come back, check if connection was successful
            result, error = tcp:connect(host, port)
            if result or error == "already connected" then return 1
            else return nil, "non-blocking connect failed" end
        else return result, error end
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while 1 do
            -- return control to dispatcher. we will be readable when a
            -- connection arrives.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, error = tcp:accept()
            -- anything but a timeout ends the wait: either we got a client,
            -- which is wrapped just like its server socket, or cowrap hands
            -- back nil plus the error
            if error ~= "timeout" then
                return cowrap(dispatcher, client, error)
            end
        end
    end
    -- remove cortn from context
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end
-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-----------------------------------------------------------------------------
-- metatable given to every coroutine dispatcher by handlert.coroutine; the
-- step and start methods are stored in its __index table
local cometat = { __index = {} }
-- Registers a coroutine as waiting on a socket, based on the outcome of
-- resuming it.
--   cortn:     the coroutine that was resumed (nil when there was none)
--   status:    first result of coroutine.resume
--   operation: queue table yielded by the coroutine (dispatcher.sending or
--              dispatcher.receiving); holds the error value when status is
--              false
--   tcp:       socket yielded by the coroutine
-- When status is false, operation is raised as an error. When the coroutine
-- yielded a queue, the socket is added to it and its activity stamp is
-- refreshed. Otherwise nothing happens.
function schedule(cortn, status, operation, tcp)
    if not status then
        base.error(operation)
    end
    if cortn and operation then
        operation.set:insert(tcp)
        operation.cortn[tcp] = cortn
        operation.stamp[tcp] = socket.gettime()
    end
end
-- Takes a socket out of a queue: forgets the coroutine waiting on it and
-- removes the socket from the queue's set.
--   operation: queue table with the fields cortn and set
--   tcp:       socket to remove
function kick(operation, tcp)
    local waiting = operation.cortn
    waiting[tcp] = nil
    local members = operation.set
    members:remove(tcp)
end
-- Resumes the coroutine waiting on a socket in the given queue.
--   operation: queue table (fields cortn and set)
--   tcp:       socket that became ready
-- Returns the coroutine followed by all results of coroutine.resume, in the
-- shape expected by schedule. When no coroutine is waiting on the socket,
-- returns nil, true so that the scheduler does nothing.
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    -- otherwise, just get scheduler not to do anything
    if not cortn then
        return nil, true
    end
    -- cortn is still valid: take it out of the queue and wake it up
    kick(operation, tcp)
    return cortn, coroutine.resume(cortn)
end
-- Tells the coroutine waiting on a socket in the given queue that it timed
-- out: the socket is taken out of the queue and the coroutine is resumed
-- with "timeout". The results of the resume are discarded. Does nothing
-- when no coroutine is waiting on the socket.
--   operation: queue table (fields cortn and set)
--   tcp:       socket whose waiting coroutine should be aborted
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then return end
    kick(operation, tcp)
    coroutine.resume(cortn, "timeout")
end
-- step through all active cortns
-- Runs one round of the dispatcher:
--   1. waits on socket.select for sockets in the receiving and sending
--      queues (the third argument, 1, is the select timeout);
--   2. resumes the coroutine waiting on each ready socket and reschedules
--      it according to what it yields back;
--   3. for every stamped socket whose class is "tcp{client}" and whose stamp
--      is older than TIMEOUT, resumes the coroutines waiting on it with
--      "timeout".
-- Errors raised inside a resumed coroutine are re-raised by schedule.
function cometat.__index:step()
    -- check which sockets are interesting and act on them
    local readable, writable = socket.select(self.receiving.set,
        self.sending.set, 1)
    -- for all readable connections, resume their cortns and reschedule
    -- when they yield back to us
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(self.receiving, tcp))
    end
    -- for all writable connections, do the same
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(self.sending, tcp))
    end
    -- politely ask replacement I/O functions in idle cortns to
    -- return reporting a timeout
    local now = socket.gettime()
    for tcp, stamp in base.pairs(self.stamp) do
        if tcp.class == "tcp{client}" and now - stamp > TIMEOUT then
            abort(self.sending, tcp)
            abort(self.receiving, tcp)
        end
    end
end
-- Starts func as a new coroutine: it is resumed once right away and, if it
-- yields a queue and a socket (as the wrapped I/O methods do), it is
-- scheduled to be resumed by step. An error raised by func during this
-- first run is re-raised by schedule.
--   func: function to run inside the new coroutine, called with no arguments
function cometat.__index:start(func)
    local thread = coroutine.create(func)
    schedule(thread, coroutine.resume(thread))
end
-- Builds a coroutine dispatcher. The dispatcher holds two queues, sending
-- and receiving, each with its own socket set and socket -> coroutine map;
-- both queues and the dispatcher itself share a single stamp table. Its tcp
-- field creates wrapped sockets bound to this dispatcher, and the step and
-- start methods come from cometat.
function handlert.coroutine()
    local stamp = {}
    -- creates one waiting queue that records activity in the shared stamps
    local function newqueue(name)
        return {
            name = name,
            set = newset(),
            cortn = {},
            stamp = stamp
        }
    end
    local dispatcher = {
        stamp = stamp,
        sending = newqueue("sending"),
        receiving = newqueue("receiving"),
    }
    -- all results of socket.tcp() are forwarded, so a creation failure
    -- comes back from cowrap as nil plus the error message
    function dispatcher.tcp()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end