Support next brian release #249

Merged: 37 commits, Oct 26, 2022
Commits
cae6205
Update brian2 submodule to PR brian-team/brian2#1297
denisalevi Aug 19, 2021
fa9a61d
WIP: Add const variables instead of literal replacement
denisalevi Aug 19, 2021
454e939
Move `HOST_CONSTANTS` before `define_N` block
denisalevi Aug 19, 2021
3377404
Add `N` declaration in spikegenerator reset kernel
denisalevi Aug 19, 2021
b821bee
Add `N` declaration to thresholder kernel
denisalevi Aug 20, 2021
f12573a
Remove an unnecessary define for `inf`
denisalevi Aug 20, 2021
b8118a9
Add test for uneven group size RNG bug
denisalevi Aug 26, 2021
28e722e
Fix number of generated random numbers in single run codeobjects
denisalevi Aug 26, 2021
b9fe96a
Comment out developmental profiling in `spatialstateupdate`
denisalevi Aug 26, 2021
f0b8d54
Update brian2 submodule to PR brian-team/brian2#1280
denisalevi Aug 19, 2021
aa32b33
Fix printing brian2 preferences in debug mode
denisalevi Aug 26, 2021
ba4f5e3
Update brian2 submodule to PR brian-team/brian2#1324
denisalevi Aug 26, 2021
e48e0ff
Update brian2 submodule to PR brian-team/brian2#1294
denisalevi Aug 26, 2021
2ef726c
Support fixed size synapses
denisalevi Aug 19, 2021
015ccd0
Merge pull request #245 from brian-team/remove-constants-replacement
denisalevi Aug 26, 2021
28cef8c
Reset thresholder at end of run
denisalevi Aug 26, 2021
3ac9296
Support synapses generator expressions `i="..."`
denisalevi Aug 26, 2021
8cb8789
Merge pull request #246 from brian-team/fixed_size_synapses
denisalevi Aug 26, 2021
cf12c08
Use updated `exc_isinstance` from brian2
denisalevi Aug 27, 2021
806be14
Merge pull request #247 from brian-team/connect_j_generator
denisalevi Aug 26, 2021
e040aa0
Merge pull request #248 from brian-team/fix-deactivating-spiking-objects
denisalevi Aug 27, 2021
0779378
Merge remote-tracking branch 'origin/master' into support-next-brian-…
mstimberg Oct 19, 2022
b0f5509
Update Brian2 submodule to PR brian-team/brian2#1338
mstimberg Oct 17, 2022
4ef97a1
Fix SpikeMonitor support for subgroups
mstimberg Oct 21, 2022
7cd2380
Update Brian2 submodule to PR brian-team/brian2#1352
mstimberg Oct 21, 2022
564a359
Update 'brian2.diff' after updating Brian2 to PR brian-team/brian2#1352
mstimberg Oct 21, 2022
384ba24
Update Brian2 submodule to PR brian-team/brian2#1343
mstimberg Oct 21, 2022
6c1b7cc
Update 'brian2.diff' after updating Brian2 to PR brian-team/brian2#1343
mstimberg Oct 21, 2022
0522322
Update Brian2 submodule to PR brian-team/brian2#1364
mstimberg Oct 21, 2022
6a44f19
Update 'brian2.diff' after updating Brian2 to PR brian-team/brian2#1364
mstimberg Oct 21, 2022
52231d0
Update Brian2 submodule to 2.5.0.1
mstimberg Oct 21, 2022
223c855
Update Brian2 submodule to 2.5.0.3
mstimberg Oct 21, 2022
d552da4
Update Brian2 submodule to 2.5.1
mstimberg Oct 21, 2022
d972fe7
Brian2 change: Avoid casting timesteps to 32bit in C++ standalone
mstimberg Oct 24, 2022
21688d9
Make Clock::i_end public again
mstimberg Oct 24, 2022
92706a1
Use minimal compute capability in test to pass on older GPUs
denisalevi Oct 26, 2022
13cb9d9
Update brian2 version dependency
denisalevi Oct 26, 2022
12 changes: 6 additions & 6 deletions brian2cuda/brianlib/clocks.h
@@ -6,9 +6,9 @@
#include<math.h>

namespace {
inline int fround(double x)
inline int64_t fround(double x)
{
return (int)(x+0.5);
return (int64_t)(x+0.5);
};
};

@@ -18,8 +18,8 @@ class Clock
double epsilon;
double *dt;
int64_t *timestep;
double *t;
int64_t i_end;
double *t;
Clock(double _epsilon=1e-14) : epsilon(_epsilon) { i_end = 0;};
inline void tick()
{
@@ -29,20 +29,20 @@
inline bool running() { return timestep[0]<i_end; };
void set_interval(double start, double end)
{
int i_start = fround(start/dt[0]);
int64_t i_start = fround(start/dt[0]);
double t_start = i_start*dt[0];
if(t_start==start || fabs(t_start-start)<=epsilon*fabs(t_start))
{
timestep[0] = i_start;
} else
{
timestep[0] = (int)ceil(start/dt[0]);
timestep[0] = (int64_t)ceil(start/dt[0]);
}
i_end = fround(end/dt[0]);
double t_end = i_end*dt[0];
if(!(t_end==end || fabs(t_end-end)<=epsilon*fabs(t_end)))
{
i_end = (int)ceil(end/dt[0]);
i_end = (int64_t)ceil(end/dt[0]);
}
}
};
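A note on the clocks.h change above: fround and the step indices now use int64_t, matching the Brian2-side commit "Avoid casting timesteps to 32bit in C++ standalone". A minimal host-side sketch of why a 32-bit counter is too small; the simulation length and dt below are made up for illustration:

```cuda
#include <cstdint>
#include <cstdio>
#include <climits>

int main()
{
    // Hypothetical long run: 5000 s of simulated time at dt = 1 us.
    const double t_end = 5000.0;  // seconds
    const double dt    = 1e-6;    // seconds

    const double steps = t_end / dt;  // 5e9 steps, beyond INT_MAX (~2.1e9)
    std::printf("steps needed: %.0f, fits in int32: %s\n",
                steps, steps <= INT_MAX ? "yes" : "no");

    // The rounding used in clocks.h, now done in 64 bit, stays exact:
    const int64_t i_end = (int64_t)(t_end / dt + 0.5);
    std::printf("i_end as int64_t: %lld\n", (long long)i_end);
    return 0;
}
```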
1 change: 0 additions & 1 deletion brian2cuda/brianlib/common_math.h
@@ -4,7 +4,6 @@
#include<limits>
#include<stdlib.h>

#define inf (std::numeric_limits<double>::infinity())
#ifdef _MSC_VER
#define INFINITY (std::numeric_limits<double>::infinity())
#define NAN (std::numeric_limits<double>::quiet_NaN())
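Dropping the file-wide `#define inf` removes a textual substitution that can collide with any identifier called `inf` in generated code, while INFINITY and std::numeric_limits stay available. A hypothetical illustration of the kind of clash such a macro can cause (not a case taken from this PR):

```cuda
#include <limits>

// The removed macro looked like this:
#define inf (std::numeric_limits<double>::infinity())

// Hypothetical generated code: with the macro active, the parameter name
// `inf` below would be textually replaced and fail to compile.
//   __device__ double add_offset(double value, double inf) { return value + inf; }

#undef inf

// Without the macro, `inf` is an ordinary identifier again, and infinity
// is still reachable through the INFINITY macro from the math headers:
__device__ double positive_infinity()
{
    return INFINITY;
}
```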
168 changes: 103 additions & 65 deletions brian2cuda/device.py

Large diffs are not rendered by default.

19 changes: 11 additions & 8 deletions brian2cuda/templates/common_group.cu
@@ -119,8 +119,7 @@ _run_kernel_{{codeobj_name}}(

int tid = threadIdx.x;
int bid = blockIdx.x;

int _idx = bid * THREADS_PER_BLOCK + tid;
int _idx = bid * blockDim.x + tid;
int _vectorisation_idx = _idx;

///// KERNEL_CONSTANTS /////
@@ -129,8 +128,6 @@ _run_kernel_{{codeobj_name}}(
///// kernel_lines /////
{{kernel_lines|autoindent}}

assert(THREADS_PER_BLOCK == blockDim.x);

{% block additional_variables %}
{% endblock %}

@@ -158,6 +155,9 @@ _run_kernel_{{codeobj_name}}(
}
{% endblock kernel %}

{% block extra_kernel_definitions %}
{% endblock %}


void _run_{{codeobj_name}}()
{
@@ -169,15 +169,18 @@ void _run_{{codeobj_name}}()
{% endif %}
{% endblock %}

///// HOST_CONSTANTS ///////////
%HOST_CONSTANTS%

{% block define_N %}
{# N is a constant in most cases (NeuronGroup, etc.), but a scalar array for
synapses, we therefore have to take care to get its value in the right
way. #}
synapses, we therefore have to take care to get its value in the right
way. #}
const int _N = {{constant_or_scalar('N', variables['N'])}};
{% endblock %}

///// HOST_CONSTANTS ///////////
%HOST_CONSTANTS%
///// ADDITIONAL_HOST_CODE /////
%ADDITIONAL_HOST_CODE%

{% block host_maincode %}
{% endblock %}
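The common_group.cu change drops the THREADS_PER_BLOCK kernel argument (and its assert) in favour of blockDim.x, and moves HOST_CONSTANTS above the define_N block so _N can be read from a host constant. The indexing pattern itself, as a reduced, hypothetical kernel outside the template machinery:

```cuda
__global__ void scale_kernel(int n, double a, double* x)
{
    // blockDim.x already carries the launch configuration, so the block
    // size never needs to be passed (or asserted) as a kernel parameter.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n)
        return;
    x[idx] *= a;
}

void run_scale(int n, double a, double* d_x)
{
    const int threads = 1024;
    const int blocks  = (n + threads - 1) / threads;
    scale_kernel<<<blocks, threads>>>(n, a, d_x);
}
```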
3 changes: 0 additions & 3 deletions brian2cuda/templates/group_variable_set.cu
@@ -44,9 +44,6 @@

{# _num_group_idx is defined in HOST_CONSTANTS, so we can't set _N before #}
{% block define_N %}
{% endblock %}

{% block host_maincode %}
const int _N = _num_group_idx;
{% endblock %}

38 changes: 14 additions & 24 deletions brian2cuda/templates/spatialstateupdate.cu
@@ -20,6 +20,12 @@
///// HOST_CONSTANTS ///////////
%HOST_CONSTANTS%

// Number of group units in stateupdate is always N (no subgroups)
const int _N = N;

///// ADDITIONAL_HOST_CODE /////
%ADDITIONAL_HOST_CODE%

{# needed to translate _array... to _ptr_array... #}
///// pointers_lines /////
{{pointers_lines|autoindent}}
@@ -31,7 +37,7 @@

// Inverse axial resistance
{# {{ openmp_pragma('parallel-static') }} #}
for (int _i=1; _i<N; _i++)
for (int _i=1; _i<_N; _i++)
{{_invr}}[_i] = 1.0/(_Ri*(1/{{r_length_2}}[_i-1] + 1/{{r_length_1}}[_i]));
// Cut sections
{# {{ openmp_pragma('parallel-static') }} #}
@@ -42,10 +48,10 @@
// The particular solution
// a[i,j]=ab[u+i-j,j] -- u is the number of upper diagonals = 1
{# {{ openmp_pragma('parallel-static') }} #}
for (int _i=0; _i<N; _i++)
for (int _i=0; _i<_N; _i++)
{{_ab_star1}}[_i] = (-({{Cm}}[_i] / {{dt}}) - {{_invr}}[_i] / {{area}}[_i]);
{# {{ openmp_pragma('parallel-static') }} #}
for (int _i=1; _i<N; _i++)
for (int _i=1; _i<_N; _i++)
{
{{_ab_star0}}[_i] = {{_invr}}[_i] / {{area}}[_i-1];
{{_ab_star2}}[_i-1] = {{_invr}}[_i] / {{area}}[_i];
@@ -121,7 +127,6 @@

__global__ void _tridiagsolve_kernel_{{codeobj_name}}(
int _N,
int THREADS_PER_BLOCK,
///// KERNEL_PARAMETERS /////
%KERNEL_PARAMETERS%
)
@@ -130,16 +135,14 @@ __global__ void _tridiagsolve_kernel_{{codeobj_name}}(

int tid = threadIdx.x;
int bid = blockIdx.x;
int _idx = bid * THREADS_PER_BLOCK + tid;
int _idx = bid * blockDim.x + tid;

///// KERNEL_CONSTANTS /////
%KERNEL_CONSTANTS%

///// kernel_lines /////
{{kernel_lines|autoindent}}

assert(THREADS_PER_BLOCK == blockDim.x);

// we need to run the kernel with 1 thread per block (to be changed by optimization)
assert(tid == 0 && bid == _idx);

@@ -161,7 +164,7 @@ __global__ void _tridiagsolve_kernel_{{codeobj_name}}(
{{_u_plus}}[_j]={{_b_plus}}[_j]; // RHS -> _u_plus (solution)
{{_u_minus}}[_j]={{_b_minus}}[_j]; // RHS -> _u_minus (solution)
_bi={{_ab_star1}}[_j]-{{_gtot_all}}[_j]; // main diagonal
if (_j<N-1)
if (_j<_N-1)
{{_c}}[_j]={{_ab_star0}}[_j+1]; // superdiagonal
if (_j>0)
{
@@ -198,7 +201,6 @@ __global__ void _tridiagsolve_kernel_{{codeobj_name}}(

__global__ void _coupling_kernel_{{codeobj_name}}(
int _N,
int THREADS_PER_BLOCK,
///// KERNEL_PARAMETERS /////
%KERNEL_PARAMETERS%
)
@@ -207,16 +209,14 @@ __global__ void _coupling_kernel_{{codeobj_name}}(

int tid = threadIdx.x;
int bid = blockIdx.x;
int _idx = bid * THREADS_PER_BLOCK + tid;
int _idx = bid * blockDim.x + tid;

///// KERNEL_CONSTANTS /////
%KERNEL_CONSTANTS%

///// kernel_lines /////
{{kernel_lines|autoindent}}

assert(THREADS_PER_BLOCK == blockDim.x);

// we need to run the kernel with 1 thread, 1 block
assert(_idx == 0);

@@ -315,7 +315,6 @@ __global__ void _coupling_kernel_{{codeobj_name}}(

__global__ void _combine_kernel_{{codeobj_name}}(
int _N,
int THREADS_PER_BLOCK,
///// KERNEL_PARAMETERS /////
%KERNEL_PARAMETERS%
)
@@ -324,16 +323,14 @@ __global__ void _combine_kernel_{{codeobj_name}}(

int tid = threadIdx.x;
int bid = blockIdx.x;
int _idx = bid * THREADS_PER_BLOCK + tid;
int _idx = bid * blockDim.x + tid;

///// KERNEL_CONSTANTS /////
%KERNEL_CONSTANTS%

///// kernel_lines /////
{{kernel_lines|autoindent}}

assert(THREADS_PER_BLOCK == blockDim.x);

// we need to run the kernel with 1 thread per block (to be changed by optimization)
assert(tid == 0 && bid == _idx);

@@ -361,7 +358,6 @@ __global__ void _combine_kernel_{{codeobj_name}}(

__global__ void _currents_kernel_{{codeobj_name}}(
int _N,
int THREADS_PER_BLOCK,
///// KERNEL_PARAMETERS /////
%KERNEL_PARAMETERS%
)
@@ -370,16 +366,14 @@ __global__ void _currents_kernel_{{codeobj_name}}(

int tid = threadIdx.x;
int bid = blockIdx.x;
int _idx = bid * THREADS_PER_BLOCK + tid;
int _idx = bid * blockDim.x + tid;

///// KERNEL_CONSTANTS /////
%KERNEL_CONSTANTS%

///// kernel_lines /////
{{kernel_lines|autoindent}}

assert(THREADS_PER_BLOCK == blockDim.x);

if(_idx >= _N)
{
return;
@@ -423,7 +417,6 @@ __global__ void _currents_kernel_{{codeobj_name}}(
int num_threads_tridiagsolve = 1;
_tridiagsolve_kernel_{{codeobj_name}}<<<num_blocks_tridiagsolve, num_threads_tridiagsolve>>>(
_N,
num_threads_tridiagsolve,
///// HOST_PARAMETERS /////
%HOST_PARAMETERS%
);
@@ -445,7 +438,6 @@ __global__ void _currents_kernel_{{codeobj_name}}(
int num_threads_coupling = 1;
_coupling_kernel_{{codeobj_name}}<<<num_blocks_coupling, num_threads_coupling>>>(
_N,
num_threads_coupling,
///// HOST_PARAMETERS /////
%HOST_PARAMETERS%
);
@@ -467,7 +459,6 @@ __global__ void _currents_kernel_{{codeobj_name}}(
int num_threads_combine = 1;
_combine_kernel_{{codeobj_name}}<<<num_blocks_combine, num_threads_combine>>>(
_N,
num_threads_combine,
///// HOST_PARAMETERS /////
%HOST_PARAMETERS%
);
@@ -544,7 +535,6 @@ __global__ void _currents_kernel_{{codeobj_name}}(
// run kernel 5
_currents_kernel_{{codeobj_name}}<<<num_blocks_currents, num_threads_currents>>>(
_N,
num_threads_currents,
///// HOST_PARAMETERS /////
%HOST_PARAMETERS%
);
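The spatialstateupdate kernels keep their intentionally serial structure (one thread per block, as the asserts note); only the redundant thread-count parameters are removed. For orientation, a minimal single-thread Thomas solve of a tridiagonal system, independent of the Brian-specific arrays and names used in the template:

```cuda
// Sketch: solve a tridiagonal system in place with a single thread.
// sub = subdiagonal (sub[0] unused), diag = main diagonal,
// sup = superdiagonal (sup[n-1] unused), rhs = right-hand side -> solution.
__global__ void thomas_solve_kernel(int n, const double* sub, const double* diag,
                                    double* sup, double* rhs)
{
    if (blockIdx.x != 0 || threadIdx.x != 0)
        return;  // serial solve: one thread does all the work

    // forward elimination
    sup[0] /= diag[0];
    rhs[0] /= diag[0];
    for (int i = 1; i < n; i++)
    {
        const double m = 1.0 / (diag[i] - sub[i] * sup[i - 1]);
        sup[i] *= m;  // the last entry is never read back
        rhs[i] = (rhs[i] - sub[i] * rhs[i - 1]) * m;
    }

    // back substitution
    for (int i = n - 2; i >= 0; i--)
        rhs[i] -= sup[i] * rhs[i + 1];
}
```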
2 changes: 2 additions & 0 deletions brian2cuda/templates/spikegenerator.cu
@@ -91,6 +91,8 @@

int _idx = blockIdx.x * blockDim.x + threadIdx.x;

const int N = {{owner.N}};

// We need kernel_lines for time variables
///// kernel_lines /////
{{kernel_lines|autoindent}}
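The spikegenerator template now declares N directly from the owning group ({{owner.N}}) instead of relying on a substituted constant. A reduced, hypothetical version of such a reset kernel, assuming the usual Brian spikespace layout of N neuron slots followed by one spike counter:

```cuda
__global__ void reset_spikespace_kernel(int32_t* spikespace)
{
    const int N = 1000;  // stands in for the template-expanded {{owner.N}}
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N + 1)
        return;
    // clear the neuron slots (-1 = no spike) and zero the trailing counter
    spikespace[idx] = (idx == N) ? 0 : -1;
}
```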
2 changes: 1 addition & 1 deletion brian2cuda/templates/spikemonitor.cu
@@ -17,7 +17,7 @@ struct is_in_subgroup
__device__
bool operator()(const int32_t &neuron)
{
return (_source_start <= neuron && neuron < _source_stop);
return ({{_source_start}} <= neuron && neuron < {{_source_stop}});
}
};
{% endif %}{# Subgroup #}
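The spikemonitor fix replaces the hard-coded _source_start / _source_stop identifiers with their template placeholders, so the predicate is built with the actual subgroup bounds. As a standalone sketch of how such a predicate is typically used; the constructor, bounds, and host helper below are illustrative, not the template's actual code:

```cuda
#include <thrust/count.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>

// Predicate selecting spikes whose neuron index lies inside a subgroup.
struct is_in_subgroup
{
    int32_t start, stop;
    is_in_subgroup(int32_t s, int32_t e) : start(s), stop(e) {}

    __host__ __device__
    bool operator()(const int32_t& neuron) const
    {
        return (start <= neuron && neuron < stop);
    }
};

// Count how many of the recorded spikes belong to neurons [128, 256).
int count_subgroup_spikes(const thrust::device_vector<int32_t>& spiking_neurons)
{
    return thrust::count_if(thrust::device,
                            spiking_neurons.begin(), spiking_neurons.end(),
                            is_in_subgroup(128, 256));
}
```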
14 changes: 5 additions & 9 deletions brian2cuda/templates/synapses.cu
@@ -16,11 +16,9 @@ __global__ void
__launch_bounds__(1024, {{sm_multiplier}})
{% endif %}
_run_kernel_{{codeobj_name}}(
{# TODO: we only need _N if we have random numbers per synapse, add a if test here #}
int _N,
int bid_offset,
int timestep,
int THREADS_PER_BLOCK,
{% if bundle_mode %}
int threads_per_bundle,
{% endif %}
@@ -36,14 +34,13 @@
{
using namespace brian;

assert(THREADS_PER_BLOCK == blockDim.x);

int tid = threadIdx.x;
int bid = blockIdx.x + bid_offset;
//TODO: do we need _idx here? if no, get also rid of scoping after scalar code
// scalar_code can depend on _idx (e.g. if the state update depends on a
// subexpression that is the same for all synapses, ?)
int _idx = bid * THREADS_PER_BLOCK + tid;
int _threads_per_block = blockDim.x;
int _idx = bid * _threads_per_block + tid;
int _vectorisation_idx = _idx;

///// KERNEL_CONSTANTS /////
@@ -102,7 +99,7 @@ _run_kernel_{{codeobj_name}}(
int pre_post_block_id = (spiking_neuron - spikes_start) * num_parallel_blocks + post_block_idx;
int num_synapses = {{pathway.name}}_num_synapses_by_pre[pre_post_block_id];
int32_t* propagating_synapses = {{pathway.name}}_synapse_ids_by_pre[pre_post_block_id];
for(int j = tid; j < num_synapses; j+=THREADS_PER_BLOCK)
for(int j = tid; j < num_synapses; j+=_threads_per_block)
{
// _idx is the synapse id
int32_t _idx = propagating_synapses[j];
@@ -127,7 +124,7 @@ _run_kernel_{{codeobj_name}}(
{% if bundle_mode %}
// use a fixed number of threads per bundle, i runs through all those threads of all bundles
// for threads_per_bundle == 1, we have one thread per bundle (parallel)
for (int i = tid; i < queue_size*threads_per_bundle; i+=THREADS_PER_BLOCK)
for (int i = tid; i < queue_size*threads_per_bundle; i+=_threads_per_block)
{
// bundle_idx runs through all bundles
int bundle_idx = i / threads_per_bundle;
@@ -149,7 +146,7 @@ _run_kernel_{{codeobj_name}}(
{% else %}{# no bundle_mode #}

// use one thread per synapse
for(int j = tid; j < queue_size; j+=THREADS_PER_BLOCK)
for(int j = tid; j < queue_size; j+=_threads_per_block)
{
int32_t _idx = synapses_queue[bid].at(j);
{
@@ -303,7 +300,6 @@ if ({{pathway.name}}_max_size > 0)
_N,
bid_offset,
{{owner.clock.name}}.timestep[0],
num_threads,
{% if bundle_mode %}
num_threads_per_bundle,
{% endif %}
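In synapses.cu the block size is likewise read from blockDim.x (stored in _threads_per_block) instead of being passed in. The bundle-mode indexing the propagation loop implements, as a reduced sketch with hypothetical names and without the spike-queue machinery:

```cuda
// Sketch: distribute `queue_size` bundles over the threads of one block,
// giving each bundle `threads_per_bundle` worker slots.
__global__ void process_bundles_kernel(int queue_size, int threads_per_bundle,
                                       const int* bundle_sizes)
{
    const int tid = threadIdx.x;
    const int threads_per_block = blockDim.x;

    for (int i = tid; i < queue_size * threads_per_bundle; i += threads_per_block)
    {
        const int bundle_idx    = i / threads_per_bundle;  // which bundle
        const int idx_in_bundle = i % threads_per_bundle;  // this thread's slot in it
        const int bundle_size   = bundle_sizes[bundle_idx];

        // Each slot then walks its share of the synapses in this bundle.
        for (int j = idx_in_bundle; j < bundle_size; j += threads_per_bundle)
        {
            // ... apply the synaptic effect for synapse j of bundle bundle_idx ...
        }
    }
}
```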