-
Notifications
You must be signed in to change notification settings - Fork 9
/
util.lua
71 lines (65 loc) · 2.19 KB
/
util.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
-- Source: https://github.com/facebook/fbcunn/blob/master/examples/imagenet/util.lua
local ffi=require 'ffi'
------ Some FFI stuff used to pass storages between threads ------------------
-- Declare the TH storage release functions so setFloatStorage and
-- setLongStorage can call them through ffi.C.
-- NOTE(review): the THFloatStorage / THLongStorage types themselves are not
-- declared in this cdef; they are presumably already declared by torch's own
-- FFI bindings by the time this file is loaded — confirm.
ffi.cdef[[
void THFloatStorage_free(THFloatStorage *self);
void THLongStorage_free(THLongStorage *self);
]]
--- Make a FloatTensor use a THFloatStorage received as a raw address.
-- This is the receiving half of the thread hand-off: receiveTensor calls it
-- with obj[1], the address that sendTensor produced.
-- Defined as a global (no `local`), presumably so code loaded in other
-- threads can call it — confirm before making it local.
-- @param tensor     FloatTensor whose storage is replaced in place.
--                   NOTE(review): torch.pointer is applied to
--                   tensor:storage(); confirm this is safe when the tensor
--                   has no storage yet.
-- @param storage_p  address of a THFloatStorage (a Lua number, as produced by
--                   sendTensor); must be non-nil and non-zero.
-- Only the C-level `storage` field is rewritten; the tensor's size, stride and
-- storage offset are left as they were.
-- NOTE(review): no retain is performed on storage_p, so the reference the
-- sender gave up is presumably adopted by `tensor` — confirm against TH
-- refcounting.
function setFloatStorage(tensor, storage_p)
assert(storage_p and storage_p ~= 0, "FloatStorage is NULL pointer");
-- Drop the tensor's reference to its current storage before swapping.
-- NOTE(review): THFloatStorage_free is presumably TH's refcounted release,
-- deallocating only when no other reference (e.g. a Lua wrapper) remains.
local cstorage = ffi.cast('THFloatStorage*', torch.pointer(tensor:storage()))
if cstorage ~= nil then
ffi.C['THFloatStorage_free'](cstorage)
end
-- Install the received storage directly into the tensor's C struct.
local storage = ffi.cast('THFloatStorage*', storage_p)
tensor:cdata().storage = storage
end
--- Make a LongTensor use a THLongStorage received as a raw address.
-- This is the receiving half of the thread hand-off: receiveTensor calls it
-- with obj[1], the address that sendTensor produced.
-- Defined as a global (no `local`), presumably so code loaded in other
-- threads can call it — confirm before making it local.
-- @param tensor     LongTensor whose storage is replaced in place.
--                   NOTE(review): torch.pointer is applied to
--                   tensor:storage(); confirm this is safe when the tensor
--                   has no storage yet.
-- @param storage_p  address of a THLongStorage (a Lua number, as produced by
--                   sendTensor); must be non-nil and non-zero.
-- Only the C-level `storage` field is rewritten; the tensor's size, stride and
-- storage offset are left as they were.
-- NOTE(review): no retain is performed on storage_p, so the reference the
-- sender gave up is presumably adopted by `tensor` — confirm against TH
-- refcounting.
function setLongStorage(tensor, storage_p)
assert(storage_p and storage_p ~= 0, "LongStorage is NULL pointer");
-- Drop the tensor's reference to its current storage before swapping.
-- NOTE(review): THLongStorage_free is presumably TH's refcounted release,
-- deallocating only when no other reference (e.g. a Lua wrapper) remains.
local cstorage = ffi.cast('THLongStorage*', torch.pointer(tensor:storage()))
if cstorage ~= nil then
ffi.C['THLongStorage_free'](cstorage)
end
-- Install the received storage directly into the tensor's C struct.
local storage = ffi.cast('THLongStorage*', storage_p)
tensor:cdata().storage = storage
end
--- Detach a tensor's storage and describe it so another thread can adopt it.
-- @param inputs  a torch tensor. After this call its C-level `storage` field
--                is nil, so it should not be used to read data any more.
-- @return {storage_address, size, type_name}: the storage address as a Lua
--         number, inputs:size() and inputs:type(). This is the `obj`
--         argument expected by receiveTensor.
-- NOTE(review): the storage offset and strides of `inputs` are not sent, and
-- the receiver only resizes to `size`. Presumably only tensors laid out
-- contiguously from the start of their storage round-trip correctly.
function sendTensor(inputs)
local size = inputs:size()
local ttype = inputs:type()
-- Read the address BEFORE detaching; the order of these two statements matters.
local i_stg = tonumber(ffi.cast('intptr_t', torch.pointer(inputs:storage())))
-- Clear the field without calling any free, so the tensor's reference is
-- handed over to the receiver (setFloatStorage/setLongStorage install it
-- without a retain).
inputs:cdata().storage = nil
return {i_stg, size, ttype}
end
--- Rebuild a tensor from the descriptor produced by sendTensor.
-- @param obj     table {storage_address, size, type_name}, i.e. the value
--                returned by sendTensor.
-- @param buffer  optional tensor to reuse. When given it must already be of
--                type obj[3]; when nil a new tensor of that type is created.
-- @return the tensor, resized to `size`, whose storage is now the received one.
-- Raises 'Unknown type' for types other than torch.FloatTensor and
-- torch.LongTensor, and 'Buffer is wrong type' on a buffer type mismatch.
-- Both checks run before any buffer is resized or allocated.
function receiveTensor(obj, buffer)
   local pointer, size, ttype = obj[1], obj[2], obj[3]

   -- Choose the storage setter first so an unsupported type fails fast,
   -- without allocating or mutating anything.
   local setStorage
   if ttype == 'torch.FloatTensor' then
      setStorage = setFloatStorage
   elseif ttype == 'torch.LongTensor' then
      setStorage = setLongStorage
   else
      error('Unknown type')
   end

   if buffer then
      -- Validate before resizing so a mismatched buffer is left untouched.
      assert(buffer:type() == ttype, 'Buffer is wrong type')
      buffer:resize(size)
   else
      -- type() names are fully qualified ('torch.FloatTensor'), but the class
      -- table lives at torch.FloatTensor, i.e. under the key 'FloatTensor'.
      -- Try the name as given, then with the 'torch.' prefix stripped.
      local tensorClass = torch[ttype] or torch[(ttype:gsub('^torch%.', ''))]
      buffer = tensorClass.new():resize(size)
   end

   setStorage(buffer, pointer)
   return buffer
end
--- Replicate a model across GPUs 1..nGPU inside an nn.DataParallelTable.
-- When nGPU <= 1 the model is returned unchanged.
-- @param model  module to replicate; each replica is model:clone():cuda().
-- @tparam number nGPU  number of GPUs to use; must not exceed
--   cutorch.getDeviceCount().
-- @return the new DataParallelTable when nGPU > 1, otherwise `model`.
-- On return the active device is the global opt.GPU when that is set, and
-- otherwise the device that was active when the function was called.
function makeDataParallel(model, nGPU)
   if nGPU > 1 then
      print('converting module to nn.DataParallelTable')
      assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
      -- Remember the active device for the restore step below.
      local previousDevice = cutorch.getDevice()
      local model_single = model
      model = nn.DataParallelTable(1)
      -- Iterate over the nGPU argument (not the global opt.nGPU) so the
      -- replica count always matches the value validated by the assert.
      for i = 1, nGPU do
         cutorch.setDevice(i)
         model:add(model_single:clone():cuda(), i)
      end
      -- Keep the historical behaviour of restoring opt.GPU when configured;
      -- fall back to the entry device so a missing global `opt` is not fatal.
      local restoreDevice = previousDevice
      if type(opt) == 'table' and opt.GPU ~= nil then
         restoreDevice = opt.GPU
      end
      cutorch.setDevice(restoreDevice)
   end
   return model
end