Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Optimize normalize_reductions_dense #2311

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,15 @@ def pull_indexeds(expr, subs, mapper, parent=None):
return cluster.rebuild(processed)


@cluster_pass(mode='dense')
def normalize_reductions_dense(cluster, sregistry, options):
"""
Extract the right-hand sides of reduction Eq's in to temporaries.
"""
return _normalize_reductions_dense(cluster, sregistry, options, {})


@cluster_pass(mode='dense')
def _normalize_reductions_dense(cluster, sregistry, options, mapper):
opt_mapify_reduce = options['mapify-reduce']

dims = [d for d in cluster.ispace.itdims
Expand All @@ -499,20 +503,25 @@ def normalize_reductions_dense(cluster, sregistry, options):
# `s += r[x]`
# This makes it much easier to parallelize the map part regardless
# of the target backend
lhs, rhs = e.args

if e.lhs.function.is_Array:
if lhs.function.is_Array:
# Probably a compiler-generated reduction, e.g. via
# recursive compilation; it's an Array already, so nothing to do
processed.append(e)
elif rhs in mapper:
mloubout marked this conversation as resolved.
Show resolved Hide resolved
# Seen this RHS already, so reuse the Array that was created for it
processed.append(e.func(lhs, mapper[rhs].indexify()))
else:
# Here the LHS could be a Symbol or a user-level Function
# In the latter case we copy the data into a temporary Array
# because the Function might be padded, and reduction operations
# require, in general, the data values to be contiguous
name = sregistry.make_name()
a = Array(name=name, dtype=e.dtype, dimensions=dims)
processed.extend([Eq(a.indexify(), e.rhs),
e.func(e.lhs, a.indexify())])
a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims)

processed.extend([Eq(a.indexify(), rhs),
e.func(lhs, a.indexify())])
else:
processed.append(e)

Expand Down
Loading