Save cell vars to nwb #79

Draft
wants to merge 7 commits into develop
Changes from 1 commit
working serial version of nwb output
VBaratham committed May 14, 2019
commit f5b808fdba0398e49c9eededd3d858cc28362601
34 changes: 21 additions & 13 deletions bmtk/simulator/bionet/modules/record_cellvars.py
@@ -28,17 +28,6 @@
from bmtk.simulator.bionet.io_tools import io

from bmtk.utils.io import cell_vars
try:
# Check to see if h5py is built to run in parallel
if h5py.get_config().mpi:
MembraneRecorder = cell_vars.CellVarRecorderParallel
else:
MembraneRecorder = cell_vars.CellVarRecorder

except Exception as e:
MembraneRecorder = cell_vars.CellVarRecorder

MembraneRecorder._io = io

pc = h.ParallelContext()
MPI_RANK = int(pc.id())
@@ -86,13 +75,31 @@ def __init__(self, tmp_dir, file_name, variable_name, cells, sections='all', buf
self._local_gids = []
self._sections = sections

self._var_recorder = MembraneRecorder(self._file_name, self._tmp_dir, self._all_variables,
buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS)
recorder_cls = self._get_var_recorder_cls()
recorder_cls._io = io
self._var_recorder = recorder_cls(self._file_name, self._tmp_dir, self._all_variables,
buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS)

self._gid_list = [] # list of all gids that will have their variables saved
self._data_block = {} # table of variable data indexed by [gid][variable]
self._block_step = 0 # time step within a given block

def _get_var_recorder_cls(self):
try:
in_mpi = h5py.get_config().mpi
except Exception as e:
in_mpi = False

if self._file_name.endswith('.nwb'):
if in_mpi:
raise NotImplementedError("BMTK does not yet support parallel I/O with NWB")
return cell_vars.CellVarRecorderNWB
else:
if in_mpi:
return cell_vars.CellVarRecorderParallel
else:
return cell_vars.CellVarRecorder

def _get_gids(self, sim):
# get list of gids to save. Will only work for biophysical cells saved on the current MPI rank
selected_gids = set(sim.net.get_node_set(self._all_gids).gids())
@@ -151,6 +158,7 @@ def finalize(self, sim):
self._var_recorder.merge()



class SomaReport(MembraneReport):
"""Special case for when only needing to save the soma variable"""
def __init__(self, tmp_dir, file_name, variable_name, cells, sections='soma', buffer_data=True, transform={}):
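For reference, here is a minimal standalone sketch of the dispatch that the new `MembraneReport._get_var_recorder_cls` performs above. It is illustrative only, not code from this PR, and assumes h5py plus the recorder classes added in `bmtk/utils/io/cell_vars.py` (below) are importable.

```python
# Illustrative sketch only (not part of this PR): how the recorder class is chosen.
import h5py
from bmtk.utils.io import cell_vars

def pick_recorder_cls(file_name):
    try:
        # True only when h5py was built against parallel HDF5 / MPI
        in_mpi = h5py.get_config().mpi
    except Exception:
        in_mpi = False

    if file_name.endswith('.nwb'):
        if in_mpi:
            raise NotImplementedError("BMTK does not yet support parallel I/O with NWB")
        return cell_vars.CellVarRecorderNWB
    return cell_vars.CellVarRecorderParallel if in_mpi else cell_vars.CellVarRecorder
```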
165 changes: 119 additions & 46 deletions bmtk/utils/io/cell_vars.py
@@ -1,7 +1,12 @@
import os
from datetime import datetime
from collections import defaultdict
import h5py
import numpy as np

from pynwb import NWBFile, NWBHDF5IO
from nwbext_simulation_output import Compartments, CompartmentSeries

from bmtk.utils import io
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version

@@ -36,7 +41,7 @@ def __init__(self, var_name):

def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1):
self._file_name = file_name
self._h5_handle = None
self._file_handle = None
self._tmp_dir = tmp_dir
self._variables = variables if isinstance(variables, list) else [variables]
self._n_vars = len(self._variables) # Used later to keep track if more than one var is saved to the same file.
@@ -46,7 +51,8 @@ def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0,
self._tmp_files = []
self._saved_file = file_name

if mpi_size > 1:
if mpi_size > 1 and not isinstance(self, CellVarRecorderParallel):

self._io.log_warning('Was unable to run h5py in parallel (mpi) mode.' +
' Saving of membrane variable(s) may slow down.')
tmp_fname = os.path.basename(file_name) # make sure file names don't clash if there are multiple reports
@@ -56,8 +62,8 @@ def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0,

self._mapping_gids = [] # list of gids in the order they appear in the data
self._gid_map = {} # table for looking up the gid offsets
self._map_attrs = {} # Used for additional attributes in /mapping

self._map_attrs = defaultdict(list) # Used for additional attributes in /mapping
self._mapping_element_ids = [] # sections
self._mapping_element_pos = [] # segments
self._mapping_index = [0] # index_pointer
@@ -123,10 +129,10 @@ def _calc_offset(self):
self._gids_beg = 0
self._gids_end = self._n_gids_local

def _create_h5_file(self):
self._h5_handle = h5py.File(self._file_name, 'w')
add_hdf5_version(self._h5_handle)
add_hdf5_magic(self._h5_handle)
def _create_file(self):
self._file_handle = h5py.File(self._file_name, 'w')
add_hdf5_version(self._file_handle)
add_hdf5_magic(self._file_handle)

def add_cell(self, gid, sec_list, seg_list, **map_attrs):
assert(len(sec_list) == len(seg_list))
@@ -140,16 +146,26 @@ def add_cell(self, gid, sec_list, seg_list, **map_attrs):
self._n_segments_local += n_segs
self._n_gids_local += 1
for k, v in map_attrs.items():
if k not in self._map_attrs:
self._map_attrs[k] = v
else:
self._map_attrs[k].extend(v)
self._map_attrs[k].extend(v)

def initialize(self, n_steps, buffer_size=0):
self._calc_offset()
self._create_h5_file()
self._create_file()
self._init_mapping()
self._total_steps = n_steps
self._buffer_block_size = buffer_size
self._init_buffers()

if not self._buffer_data:
# If data is not being buffered and instead written to the main block, we have to add a rank offset
# to the gid offset
for gid, gid_offset in self._gid_map.items():
self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg)

self._is_initialized = True

var_grp = self._h5_handle.create_group('/mapping')
def _init_mapping(self):
var_grp = self._file_handle.create_group('/mapping')
var_grp.create_dataset('gids', shape=(self._n_gids_all,), dtype=np.uint)
var_grp.create_dataset('element_id', shape=(self._n_segments_all,), dtype=np.uint)
var_grp.create_dataset('element_pos', shape=(self._n_segments_all,), dtype=np.float)
@@ -164,32 +180,25 @@ def initialize(self, n_steps, buffer_size=0):
var_grp['index_pointer'][self._gids_beg:(self._gids_end+1)] = self._mapping_index
for k, v in self._map_attrs.items():
var_grp[k][self._seg_offset_beg:self._seg_offset_end] = v

self._total_steps = n_steps
self._buffer_block_size = buffer_size
if not self._buffer_data:
# If data is not being buffered and instead written to the main block, we have to add a rank offset
# to the gid offset
for gid, gid_offset in self._gid_map.items():
self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg)


def _init_buffers(self):

for var_name, data_tables in self._data_blocks.items():
# If users are trying to save multiple variables in the same file, put each data table in its own /{var} group
# (not sonata compliant). Otherwise the data table is located at the root
data_grp = self._h5_handle if self._n_vars == 1 else self._h5_handle.create_group('/{}'.format(var_name))
data_grp = self._file_handle if self._n_vars == 1 else self._file_handle.create_group('/{}'.format(var_name))
if self._buffer_data:
# Set up in-memory block to buffer recorded variables before writing to the dataset
data_tables.buffer_block = np.zeros((buffer_size, self._n_segments_local), dtype=np.float)
data_tables.data_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all),
data_tables.buffer_block = np.zeros((self._buffer_block_size, self._n_segments_local), dtype=np.float)
data_tables.data_block = data_grp.create_dataset('data', shape=(self._total_steps, self._n_segments_all),
dtype=np.float, chunks=True)
data_tables.data_block.attrs['variable_name'] = var_name
else:
# Since we are not buffering data, we just write directly to the on-disk dataset
data_tables.buffer_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all),
data_tables.buffer_block = data_grp.create_dataset('data', shape=(self._total_steps, self._n_segments_all),
dtype=np.float, chunks=True)
data_tables.buffer_block.attrs['variable_name'] = var_name

self._is_initialized = True

def record_cell(self, gid, var_name, seg_vals, tstep):
"""Record cell parameters.
@@ -226,6 +235,7 @@ def flush(self):
if blk_end > self._total_steps:
# Need to handle the case that simulation doesn't end on a block step
blk_end = blk_beg + self._total_steps - blk_beg
seg_beg, seg_end = self._seg_offset_beg, self._seg_offset_end

block_size = blk_end - blk_beg
self._last_save_indx += block_size
@@ -234,7 +244,7 @@
data_table.data_block[blk_beg:blk_end, :] = data_table.buffer_block[:block_size, :]

def close(self):
self._h5_handle.close()
self._file_handle.close()

def merge(self):
if self._mpi_size > 1 and self._mpi_rank == 0:
@@ -282,6 +292,7 @@ def merge(self):
for k, v in self._map_attrs.items():
mapping_grp[k][beg:end] = v


# shift the index pointer values
index_pointer = np.array(tmp_mapping_grp['index_pointer'])
update_index = beg + index_pointer
@@ -290,7 +301,6 @@
gids_ds[beg:end] = tmp_mapping_grp['gids']
index_pointer_ds[beg:(end+1)] = update_index


# combine the /var/data datasets
for var_name in self._variables:
data_name = '/data' if self._n_vars == 1 else '/{}/data'.format(var_name)
@@ -305,33 +315,96 @@ def merge(self):
os.remove(tmp_file)


class CellVarRecorderNWB(CellVarRecorder):
def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1):
super(CellVarRecorderNWB, self).__init__(
file_name, tmp_dir, variables, buffer_data=buffer_data,
mpi_rank=mpi_rank, mpi_size=mpi_size
)
self._compartments = Compartments('compartments')
self._compartmentseries = {}
if self._mpi_size > 1:
self._nwbio = NWBHDF5IO(self._file_name, 'w', comm=comm)
else:
self._nwbio = NWBHDF5IO(self._file_name, 'w')

def _create_file(self):
self._file_handle = NWBFile('description', 'id', datetime.now().astimezone()) # TODO: pass in descr, id

def initialize(self, n_steps, buffer_size=0):
super(CellVarRecorderNWB, self).initialize(n_steps, buffer_size=buffer_size)


def add_cell(self, gid, sec_list, seg_list, **map_attrs):
self._compartments.add_row(number=sec_list, position=seg_list, id=gid)
super(CellVarRecorderNWB, self).add_cell(gid, sec_list, seg_list, **map_attrs)

def _init_mapping(self):
pass # TODO: add timing info, cell ids?

def _init_buffers(self):
if not self._buffer_data:
raise NotImplementedError('Must buffer data with CellVarRecorderNWB')

self._file_handle.add_acquisition(self._compartments)
for var_name, data_tables in self._data_blocks.items():
cs = CompartmentSeries(
var_name, data=np.zeros((self._total_steps, self._n_segments_all)),
compartments=self._compartments, unit='mV', rate=1000.0/self.dt
)
self._compartmentseries[var_name] = cs
self._file_handle.add_acquisition(cs)
data_tables.buffer_block = np.zeros((self._buffer_block_size, self._n_segments_local), dtype=np.float)
data_tables.data_block = self._compartmentseries[var_name].data

self._nwbio.write(self._file_handle)

# Re-read data sets to make them NWB objects, not numpy arrays
# (this way, they are immediately written to disk when modified)
self._nwbio.close()
if self._mpi_size > 1:
self._nwbio = NWBHDF5IO(self._file_name, 'a', comm=comm)
else:
self._nwbio = NWBHDF5IO(self._file_name, 'a')
self._file_handle = self._nwbio.read()
for var_name, data_tables in self._data_blocks.items():
self._data_blocks[var_name].data_block = self._file_handle.acquisition[var_name].data

def close(self):
self._nwbio.close()



class CellVarRecorderParallel(CellVarRecorder):
"""
Unlike the parent class, this takes advantage of parallel h5py to write to the results file across different ranks.

"""
def __init__(self, file_name, tmp_dir, variables, buffer_data=True):
super(CellVarRecorder, self).__init__(file_name, tmp_dir, variables, buffer_data=buffer_data, mpi_rank=0,
mpi_size=1)
def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1):
super(CellVarRecorderParallel, self).__init__(
file_name, tmp_dir, variables, buffer_data=buffer_data,
mpi_rank=mpi_rank, mpi_size=mpi_size
)

def _calc_offset(self):
# iterate through the ranks, letting rank r determine its offset from rank r-1
for r in range(comm.Get_size()):
if rank == r:
if rank < (nhosts - 1):
# pass the num of segments and num of gids to the next rank
offsets = np.array([self._n_segments_local, self._n_gids_local], dtype=np.uint)
comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1))

if rank > 0:
# get num of segments and gids from prev. rank and calculate offsets
offset = np.empty(2, dtype=np.uint)
offsets = np.empty(2, dtype=np.uint)
comm.Recv([offsets, MPI.UNSIGNED_INT], source=(r-1))
self._seg_offset_beg = offsets[0]
self._seg_offset_end = self._seg_offset_beg + self._n_segments_local
self._gids_beg = offsets[1]

self._seg_offset_end = int(self._seg_offset_beg) \
+ int(self._n_segments_local)
self._gids_end = int(self._gids_beg) + int(self._n_gids_local)

self._gids_beg = offset[1]
self._gids_end = self._gids_beg + self._n_gids_local
if rank < (nhosts - 1):
# pass the next rank its offset
offsets = np.array([self._seg_offset_end, self._gids_end], dtype=np.uint)
comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1))

comm.Barrier()

@@ -345,10 +418,10 @@ def _calc_offset(self):
self._n_segments_all = total_counts[0]
self._n_gids_all = total_counts[1]

def _create_h5_file(self):
self._h5_handle = h5py.File(self._file_name, 'w', driver='mpio', comm=MPI.COMM_WORLD)
add_hdf5_version(self._h5_handle)
add_hdf5_magic(self._h5_handle)
def _create_file(self):
self._file_handle = h5py.File(self._file_name, 'w', driver='mpio', comm=MPI.COMM_WORLD)
add_hdf5_version(self._file_handle)
add_hdf5_magic(self._file_handle)

def merge(self):
pass
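A note on `CellVarRecorderParallel._calc_offset` above: the rank-by-rank Send/Recv chain gives each rank the number of segments and gids owned by all lower ranks, plus the global totals used to size the shared datasets. The sketch below shows the same pattern expressed with mpi4py collectives (exclusive scan plus allreduce); it is an illustration of the idea, not code from this PR, and assumes mpi4py is available.

```python
# Illustrative sketch only (not part of this PR): the per-rank offsets and global
# totals computed by _calc_offset, derived with an exclusive scan and an allreduce.
from mpi4py import MPI

comm = MPI.COMM_WORLD

def calc_offsets(n_segments_local, n_gids_local):
    # Exclusive prefix sum: each rank receives the totals of all lower ranks
    # (rank 0 receives None, which we treat as 0).
    seg_beg = comm.exscan(n_segments_local) or 0
    gid_beg = comm.exscan(n_gids_local) or 0
    # Global totals, used to size the shared /data and /mapping datasets.
    n_segments_all = comm.allreduce(n_segments_local)
    n_gids_all = comm.allreduce(n_gids_local)
    return (seg_beg, seg_beg + n_segments_local,
            gid_beg, gid_beg + n_gids_local,
            n_segments_all, n_gids_all)
```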