Skip to content

Commit

Permalink
more flexible code generation functions
Browse files Browse the repository at this point in the history
Now the code-gen allows generating reduction functions that may contain more than one input arrays.
  • Loading branch information
lindahua committed Jan 4, 2014
1 parent 9ecc321 commit da2d02c
Showing 1 changed file with 43 additions and 36 deletions.
79 changes: 43 additions & 36 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,13 @@ function rcompress_dims{N}(siz::NTuple{N,Int}, region)
return (isrd[1], sdims)
end

function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!)
function generate_reducedim_funcs(fname, params, args, sizexpr, comb, sker, ker0!, ker1!)
# Parameters:
#
# - fname: the interface function name (e.g. sum, maximum)
# - fname: the interface function name (e.g. sum, maximum)
# - params: a list of input parameters in function signatures
# - args: a list of argument symbols
# - sizexpr: the expression that calculates the input size
# - comb: the combination operation (e.g. +)
# - sker: a kernel function that reduces a vector (or a range of it) to a scalar
# - ker0!: a kernel that initializes an accumulator array using the first column of terms
Expand All @@ -179,77 +182,77 @@ function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!)

quote
global $(fname!)
function $(fname!)(dst::Array, a::Array, dim::Integer)
nd = ndims(a)
siz = size(a)
function $(fname!)(dst::Array, $(params...), dim::Integer)
siz = $(sizexpr)
nd = length(siz)
if 1 <= dim <= nd
if dim == 1
$(fa!)(true, dst, 0, a, 0, prod(siz[2:nd]), siz[1])
$(fa!)(true, dst, 0, $(args...), 0, prod(siz[2:nd]), siz[1])
elseif dim == nd
$(fb!)(true, dst, 0, a, 0, siz[nd], prod(siz[1:nd-1]))
$(fb!)(true, dst, 0, $(args...), 0, siz[nd], prod(siz[1:nd-1]))
else
$(fb!)(true, dst, 0, a, 0, prod(siz[dim+1:nd]), siz[dim], prod(siz[1:dim-1]))
$(fb!)(true, dst, 0, $(args...), 0, prod(siz[dim+1:nd]), siz[dim], prod(siz[1:dim-1]))
end
else
$(ker0!)(dst, 1, a, 1, length(a))
$(ker0!)(dst, 1, $(args...), 1, prod(siz))
end
dst
end

function $(fname!)(dst::Array, a::Array, region)
function $(fname!)(dst::Array, $(params...), region)
if length(region) == 1
$(fname!)(dst, a, region[1])
$(fname!)(dst, $(args...), region[1])
else
isrd1, secs = rcompress_dims(size(a), region)
isrd1, secs = rcompress_dims($(sizexpr), region)
if isrd1
$(fa!)(true, dst, 0, a, 0, secs[end:-1:1]...)
$(fa!)(true, dst, 0, $(args...), 0, secs[end:-1:1]...)
else
$(fb!)(true, dst, 0, a, 0, secs[end:-1:1]...)
$(fb!)(true, dst, 0, $(args...), 0, secs[end:-1:1]...)
end
end
dst
end

# $(fa!)
global $(fa!)
function $(fa!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int)
function $(fa!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int)
if isinit
dst[od+1] = $(sker)(a, oa+1, oa+n1)
dst[od+1] = $(sker)($(args...), oa+1, oa+n1)
else
dst[od+1] = $(comb)(dst[od+1], $(sker)(a, oa+1, oa+n1))
dst[od+1] = $(comb)(dst[od+1], $(sker)($(args...), oa+1, oa+n1))
end
end

function $(fa!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int)
function $(fa!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int)
if isinit
for j = 1:n1
alast = oa + n2
dst[od+j] = $(sker)(a, oa+1, alast)
dst[od+j] = $(sker)($(args...), oa+1, alast)
oa = alast
end
else
for j = 1:n1
alast = oa + n2
dst[od+j] = $(comb)(dst[od+j], $(sker)(a, oa+1, alast))
dst[od+j] = $(comb)(dst[od+j], $(sker)($(args...), oa+1, alast))
oa = alast
end
end
end

function $(fa!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...)
function $(fa!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...)
as::Int = *(n2, n3, ns...)
if length(ns) & 1 == 0
$(fa!)(isinit, dst, od, a, oa, n2, n3, ns...)
$(fa!)(isinit, dst, od, $(args...), oa, n2, n3, ns...)
oa += as

for j = 2:n1
$(fa!)(false, dst, od, a, oa, n2, n3, ns...)
$(fa!)(false, dst, od, $(args...), oa, n2, n3, ns...)
oa += as
end
else
ds::Int = *(n3, ns[2:2:end]...)
for j = 1:n1
$(fa!)(isinit, dst, od, a, oa, n2, n3, ns...)
$(fa!)(isinit, dst, od, $(args...), oa, n2, n3, ns...)
od += ds
oa += as
end
Expand All @@ -258,50 +261,54 @@ function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!)

# $(fb!)
global $(fb!)
function $(fb!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int)
function $(fb!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int)
if isinit
$(ker0!)(dst, od+1, a, oa+1, n1)
$(ker0!)(dst, od+1, $(args...), oa+1, n1)
else
$(ker1!)(dst, od+1, a, oa+1, n1)
$(ker1!)(dst, od+1, $(args...), oa+1, n1)
end
end

function $(fb!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int)
function $(fb!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int)
if isinit
$(ker0!)(dst, od+1, a, oa+1, n2)
$(ker0!)(dst, od+1, $(args...), oa+1, n2)
else
$(ker1!)(dst, od+1, a, oa+1, n2)
$(ker1!)(dst, od+1, $(args...), oa+1, n2)
end
oa += n2

for j = 2:n1
$(ker1!)(dst, od+1, a, oa+1, n2)
$(ker1!)(dst, od+1, $(args...), oa+1, n2)
oa += n2
end
end

function $(fb!)(isinit::Bool, dst::Array, od::Int, a::Array, oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...)
function $(fb!)(isinit::Bool, dst::Array, od::Int, $(params...), oa::Int, n1::Int, n2::Int, n3::Int, ns::Int...)
as = *(n2, n3, ns...)
if length(ns) & 1 == 0
ds::Int = *(n3, ns[2:2:end]...)
for j = 1:n1
$(fb!)(isinit, dst, od, a, oa, n2, n3, ns...)
$(fb!)(isinit, dst, od, $(args...), oa, n2, n3, ns...)
od += ds
oa += as
end
else
$(fb!)(isinit, dst, od, a, oa, n2, n3, ns...)
$(fb!)(isinit, dst, od, $(args...), oa, n2, n3, ns...)
oa += as

for j = 2:n1
$(fb!)(false, dst, od, a, oa, n2, n3, ns...)
$(fb!)(false, dst, od, $(args...), oa, n2, n3, ns...)
oa += as
end
end
end
end
end

function generate_reducedim_funcs(fname, comb, sker, ker0!, ker1!)
# specialized method to generate functions with single input arguments
generate_reducedim_funcs(fname, [:(a::Array)], [:a], :(size(a)), comb, sker, ker0!, ker1!)
end

macro code_reducedim(fname, comb, sker, ker0, ker1)
esc(generate_reducedim_funcs(fname, comb, sker, ker0, ker1))
end
Expand Down

0 comments on commit da2d02c

Please sign in to comment.