Skip to content

Commit

Permalink
Merge branch 'hongbinl/fix_dtype_in_sample_index' into 'main'
Browse files: browse the repository at this point in the history
Change dtype of sample_idx from int32 to int64

See merge request ADLR/megatron-lm!1449
  • Loading branch information
ericharper committed Jun 28, 2024
2 parents 6cb81d7 + 1ba2198 commit e30252b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions megatron/core/datasets/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
{
num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length);
}
- int32_t *sample_idx = new int32_t[2 * (num_samples + 1)];
+ int64_t *sample_idx = new int64_t[2 * (num_samples + 1)];

// Index into sample_idx.
int64_t sample_index = 0;
Expand Down Expand Up @@ -228,11 +228,11 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_)
{
- int32_t *mem = reinterpret_cast<int32_t*>(mem_);
+ int64_t *mem = reinterpret_cast<int64_t*>(mem_);
delete[] mem; });

// Return the numpy array.
- const auto byte_size = sizeof(int32_t);
+ const auto byte_size = sizeof(int64_t);
return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
{2 * byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
Expand Down

0 comments on commit e30252b

Please sign in to comment.