Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added cudaCheck wherever missing. #686

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
added cudaCheck where missing for proper error checking.
  • Loading branch information
indianspeedster committed Jul 13, 2024
commit 6bb562bd88a6d55c4f4db33b33c434fcfc3d007f
4 changes: 2 additions & 2 deletions llmc/layernorm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ void layernorm_forward(floatX* out, float* mean, float* rstd,
// this may fail, in which case we fall back to the smem free implementation.
cudaCheck(cudaGetLastError());
auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
cudaGetLastError();
cudaCheck(cudaGetLastError());
if (status == cudaSuccess) {
layernorm_forward_kernel6<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(out, mean, rstd, inp, weight, bias, N, C);
} else {
Expand Down Expand Up @@ -477,7 +477,7 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
// this may fail, in which case we fall back to the smem free implementation.
cudaCheck(cudaGetLastError());
auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
cudaGetLastError();
cudaCheck(cudaGetLastError());
if(status == cudaSuccess) {
fused_residual_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
mean, rstd, inp1, inp2,
Expand Down
6 changes: 3 additions & 3 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1153,14 +1153,14 @@ void gpt2_free(GPT2 *model) {
void common_start(bool override_enable_tf32 = true, bool print_device_info = true) {

// get CUDA device infos
cudaGetDeviceProperties(&deviceProp, multi_gpu_config.local_device_idx);
cudaCheck(cudaGetDeviceProperties(&deviceProp, multi_gpu_config.local_device_idx));
if (print_device_info) {
printf("[System]\n");
printf("Device %d: %s\n", multi_gpu_config.local_device_idx, deviceProp.name);
}

// set up the cuda streams. atm everything is on the single main stream
cudaStreamCreate(&main_stream);
cudaCheck(cudaStreamCreate(&main_stream));
nvtxNameCudaStreamA(main_stream, "main stream");

// set up cuBLAS and cuBLASLt
Expand Down Expand Up @@ -1788,7 +1788,7 @@ int main(int argc, char *argv[]) {
dataloader_reset(&train_loader);
}
// do one training step, doing forward/backward/update on total_batch_size tokens
cudaEventRecord(start);
cudaCheck(cudaEventRecord(start));
// gradient and loss accumulation loop over micro-batches
for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) {
// fetch the next data batch
Expand Down