forked from torch/threads
-
Notifications
You must be signed in to change notification settings - Fork 0
/
threads.lua
319 lines (278 loc) · 7.94 KB
/
threads.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
local Queue = require 'threads.queue'
local clib = require 'libthreads'
local _unpack = unpack or table.unpack

-- Class table shared by every pool instance (Threads.new installs it as the
-- instance metatable).
local Threads = {}
Threads.__index = Threads
-- Name of the module handed to Queue() for (de)serializing jobs; change it
-- with Threads.serialization(name).
Threads.__serialize = "threads.serialize"

-- Public module value: a proxy that forwards field reads and writes to
-- Threads, and whose call syntax Threads_ctor(...) means Threads.new(...).
local Threads_ctor = setmetatable({}, {
   __index = Threads,
   __newindex = Threads,
   __call = function(_, ...)
      return Threads.new(...)
   end,
})

-- GC: lua 5.2 (table finalizer); under Lua 5.1 Threads.new installs a
-- newproxy sentinel that does the same job.
function Threads.__gc(self)
   self:terminate()
end
--- Get or set the serialization module name used for new queues.
-- @param name (optional) module name; when given it must be a string and
--   replaces the current setting (nothing is returned).
-- @return the current module name when called without an argument.
function Threads.serialization(name)
   if not name then
      return Threads.__serialize
   end
   assert(type(name) == 'string')
   Threads.__serialize = name
end
--- Create a pool of N worker threads.
-- Each worker loops pulling jobs either from its own thread-specific queue
-- (while its `__queue_specific` global is true) or from the shared thread
-- queue, and posts every result back on the main queue.
-- @param N number of worker threads
-- @param ... optional initialization functions; each one is run once on every
--   worker and receives that worker's index as its argument.
-- @return the pool, and a table holding, for each worker, a table of the values
--   returned by the LAST initialization function (in the order the results
--   were received on the main thread).
function Threads.new(N, ...)
   local self = {N=N, endcallbacks={n=0}, errors=false, __specific=true, __running=true}
   local funcs = {...}

   -- load the serialization module now so an invalid name fails immediately
   require(Threads.__serialize)

   if #funcs == 0 then
      funcs = {function() end}
   end

   setmetatable(self, Threads)

   self.mainqueue = Queue(N, Threads.__serialize)
   self.threadqueue = Queue(N, Threads.__serialize)
   self.threadspecificqueues = {}
   self.mainqueue:retain() -- terminate will free it
   self.threadqueue:retain() -- terminate will free it
   self.threads = {}

   for i=1,N do
      self.threadspecificqueues[i] = Queue(N, Threads.__serialize)
      self.threadspecificqueues[i]:retain() -- terminate will free it
      -- the worker bootstrap code below runs in the new thread's own Lua state
      local thread = clib.Thread(
         string.format(
[[
local Queue = require 'threads.queue'
__threadid = %d
local mainqueue = Queue(%d)
local threadqueue = Queue(%d)
local threadspecificqueue = Queue(%d)
local threadid = __threadid
__queue_running = true
__queue_specific = true
while __queue_running do
local status, res, endcallbackid
if __queue_specific then
status, res, endcallbackid = threadspecificqueue:dojob()
else
status, res, endcallbackid = threadqueue:dojob()
end
mainqueue:addjob(function()
return status, res, endcallbackid, threadid
end)
end
]],
            i,
            self.mainqueue:id(),
            self.threadqueue:id(),
            self.threadspecificqueues[i]:id()
      ))
      assert(thread, string.format('%d-th thread creation failed', i))
      table.insert(self.threads, thread)
   end

   -- GC: lua 5.1
   if newproxy then
      self.__gc__ = newproxy(true)
      getmetatable(self.__gc__).__gc =
         function()
            self:terminate() -- all the queues must be alive (hence the retains above)
         end
   end

   -- Run every init function on every worker. The pool is still in specific
   -- mode here, so job i goes to worker i; only the results of the last init
   -- function are collected.
   local initres = {}
   local nfuncs = #funcs
   local function ignore() end
   local function collect(...)
      table.insert(initres, {...})
   end
   for j=1,nfuncs do
      local endcallback = (j == nfuncs) and collect or ignore
      for i=1,self.N do
         self:addjob(i, funcs[j], endcallback, i) -- i is passed to funcs[j]
      end
   end

   -- specific() synchronizes first (draining the init jobs), then switches
   -- every worker over to the shared queue
   self:specific(false)
   return self, initres
end
--- True until terminate() has been called on this pool.
function Threads.isrunning(self)
   local running = self.__running
   return running
end
-- Guard used by the public methods: raises 'thread system is not running'
-- once the pool has been terminated.
local function checkrunning(pool)
   local running = pool:isrunning()
   assert(running, 'thread system is not running')
end
--- Get or set the job-routing mode.
-- Without an argument, returns true when jobs are addressed to one worker
-- (addjob(idx, callback, endcallback, ...)) and false when they go to the
-- shared queue (addjob(callback, endcallback, ...)).
-- With a boolean, drains pending jobs, tells every worker to switch queues,
-- records the new mode and drains again.
function Threads:specific(flag)
   checkrunning(self)
   if flag == nil then
      return self.__specific
   end

   assert(type(flag) == 'boolean', 'boolean expected')
   self:synchronize() -- finish jobs first
   if self.__specific == flag then
      return
   end

   if self.__specific then
      -- currently specific: deliver the switch job on each worker's own queue
      local leave = function()
         __queue_specific = false
      end
      for i = 1, self.N do
         self:addjob(i, leave)
      end
   else
      -- currently shared: each worker takes exactly one of these jobs, since
      -- after running it the worker reads only its specific queue
      local enter = function()
         __queue_specific = true
      end
      for _ = 1, self.N do
         self:addjob(enter)
      end
   end

   self.__specific = flag
   self:synchronize() -- finish jobs
end
--- Consume one finished job from the main queue and run its end callback.
-- The result is taken from the main queue (presumably blocking until a worker
-- posts one). The job's end callback is unregistered and then called on this
-- thread with the job's return values.
-- Raises '[thread N callback] ...' when the job failed on the worker, or
-- '[thread N endcallback] ...' when the end callback failed; self.errors is
-- set to true in both cases.
function Threads:dojob()
   checkrunning(self)
   self.errors = false
   local callstatus, args, endcallbackid, threadid = self.mainqueue:dojob()
   local endcallback = self.endcallbacks[endcallbackid]
   self.endcallbacks[endcallbackid] = nil
   self.endcallbacks.n = self.endcallbacks.n - 1

   if not callstatus then
      -- on failure args[1] holds the worker-side traceback
      self.errors = true
      error(string.format('[thread %d callback] %s', threadid, args[1]))
   end

   -- Prefer an explicit result count when the job recorded one: `#` is
   -- unspecified on tables with nil holes, so nil return values (and anything
   -- after them) could otherwise be dropped before reaching the end callback.
   local nres = args.n or #args
   local endcallstatus, msg = xpcall(
      function() return endcallback(_unpack(args, 1, nres)) end,
      debug.traceback)
   if not endcallstatus then
      self.errors = true
      error(string.format('[thread %d endcallback] %s', threadid, msg))
   end
end
--- Whether the queue a new job would go to has room.
-- addjob keeps draining results while this returns false.
-- @param idx worker index (1..N); required in specific mode, ignored otherwise.
function Threads:acceptsjob(idx)
   checkrunning(self)
   local queue = self.threadqueue
   if self:specific() then
      assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected')
      queue = self.threadspecificqueues[idx]
   end
   local full = queue.isfull
   return full ~= 1
end
--- Submit a job to the pool.
-- Specific mode:     addjob(idx, callback, endcallback, ...)
-- Non-specific mode: addjob(callback, endcallback, ...)
-- `callback(...)` runs on a worker with the trailing arguments; its return
-- values are later passed to `endcallback` on the main thread by dojob().
-- When the target queue is full, finished jobs are drained first.
function Threads:addjob(...) -- endcallback is passed with returned values of callback
   checkrunning(self)
   self.errors = false
   local endcallbacks = self.endcallbacks
   local idx, threadqueue, r, callback, endcallback
   if self:specific() then
      idx = select(1, ...)
      assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected')
      threadqueue = self.threadspecificqueues[idx]
      callback = select(2, ...)
      endcallback = select(3, ...)
      r = 4
   else
      callback = select(1, ...)
      endcallback = select(2, ...)
      threadqueue = self.threadqueue
      r = 3
   end
   assert(type(callback) == 'function', 'function callback expected')
   assert(type(endcallback) == 'function' or type(endcallback) == 'nil', 'function (or nil) endcallback expected')

   -- finish running jobs if no space available
   while not self:acceptsjob(idx) do
      self:dojob()
   end

   -- Register the end callback. `#endcallbacks` always returns a border, and
   -- the slot right after a border is nil, so this never overwrites a
   -- pending entry even when completed jobs left holes in the list.
   local endcallbackid = #endcallbacks+1
   endcallbacks[endcallbackid] = endcallback or function() end
   endcallbacks.n = endcallbacks.n + 1

   -- `func` is what the queue ships to a worker. Its helpers are defined inside
   -- it, so the only upvalues it carries are `callback` and `endcallbackid`.
   local func = function(...)
      local _unpack = unpack or table.unpack
      -- count arguments explicitly so nil arguments keep their positions
      local nargs = select('#', ...)
      local args = {...}
      -- split xpcall's status from the results, recording the true result
      -- count in `n` (table.remove/`#` would misbehave on nil results)
      local function pack(status, ...)
         return status, {n = select('#', ...), ...}
      end
      local status, res = pack(xpcall(
         function()
            return callback(_unpack(args, 1, nargs))
         end,
         debug.traceback))
      return status, res, endcallbackid
   end
   threadqueue:addjob(func, select(r, ...))
end
--- DEPRECATED: always returns false.
-- Job and end-callback failures are raised straight away by dojob (and so by
-- addjob/synchronize), so there is never a pending error to query.
function Threads.haserror(_)
   return false
end
--- True while some submitted job has not yet had its end callback run.
function Threads:hasjob()
   checkrunning(self)
   local pending = self.endcallbacks.n
   return pending > 0
end
--- Block until every submitted job has finished and its end callback has run.
-- Silently does nothing once the pool has been terminated.
function Threads:synchronize()
   if self:isrunning() then
      self.errors = false
      while self:hasjob() do
         self:dojob()
      end
   end
end
--- Shut the pool down.
-- Sends every worker a job that ends its loop, waits for outstanding jobs,
-- then frees the threads and all queues.
-- Does nothing when the pool is already stopped, or when self.errors is set
-- by a failed dojob.
-- The pool is marked as not running even if shutdown fails; the failure is
-- then re-raised.
function Threads:terminate()
   if not self:isrunning() or self.errors then
      return
   end

   -- job executed on a worker: clears the flag its main loop polls
   local function stopworker()
      __queue_running = false
   end

   local function shutdown()
      -- one stop job per worker, routed according to the current mode
      for i = 1, self.N do
         if self:specific() then
            self:addjob(i, stopworker)
         else
            self:addjob(stopworker)
         end
      end

      -- terminate all jobs
      self:synchronize()

      -- wait for threads to exit (and free them)
      for i = 1, self.N do
         self.threads[i]:free()
      end

      -- release the queues
      self.mainqueue:free()
      self.threadqueue:free()
      for i = 1, self.N do
         self.threadspecificqueues[i]:free()
      end
   end

   local ok, err = pcall(shutdown)
   -- make sure you won't run anything
   self.__running = false
   if not ok then
      error(err)
   end
end
-- Module export: the callable constructor proxy (Threads_ctor(...) calls
-- Threads.new(...); field reads and writes are forwarded to Threads).
return Threads_ctor