
run_distributed method does not consider batch size #22

Open
maxdreyer opened this issue Apr 27, 2023 · 0 comments

Hi @rachtibat,

the run_distributed method of the FeatureVisualization class does not take the actual batch_size into account in the multi-target case.

Maybe include something like the following (assuming math is imported at the top of the file):

if n_samples > batch_size:
    batches_ = math.ceil(len(conditions) / batch_size)
else:
    batches_ = 1

for b_ in range(batches_):
    data_broadcast_ = data_broadcast[b_ * batch_size: (b_ + 1) * batch_size]
    conditions_ = conditions[b_ * batch_size: (b_ + 1) * batch_size]
    # dict_inputs is linked to FeatHooks
    dict_inputs["sample_indices"] = sample_indices[b_ * batch_size: (b_ + 1) * batch_size]
    dict_inputs["targets"] = targets[b_ * batch_size: (b_ + 1) * batch_size]

    # composites are already registered before
    self.attribution(data_broadcast_, conditions_, None, exclude_parallel=False)
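
For illustration, here is the chunking pattern in isolation as a minimal, runnable sketch (iterate_in_chunks is just a placeholder name for this issue, not part of the crp API):

import math

def iterate_in_chunks(items, batch_size):
    # Yield consecutive slices of at most batch_size elements.
    for b in range(math.ceil(len(items) / batch_size)):
        yield items[b * batch_size: (b + 1) * batch_size]

# Each attribution call then only sees one chunk at a time:
for chunk in iterate_in_chunks(list(range(10)), 4):
    print(chunk)  # [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]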

This would fix a GPU memory issue of mine: without the batching, all conditions are attributed in a single pass, so peak GPU memory grows with the total number of targets rather than with batch_size.

Best,
Max
