From fefe3e7fc8b219beb2119184e5cc684c25f3d643 Mon Sep 17 00:00:00 2001 From: Philip Dhingra <195001+philipkd@users.noreply.github.com> Date: Wed, 3 Jan 2024 11:09:00 -0800 Subject: [PATCH] fix minibatch_size being referenced before assignment (openai#25) --- weak_to_strong/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/weak_to_strong/train.py b/weak_to_strong/train.py index 2a389a9..5e151f0 100644 --- a/weak_to_strong/train.py +++ b/weak_to_strong/train.py @@ -231,9 +231,11 @@ def maybe_load_model(model): ).to("cuda") already_trained = maybe_load_model(model) # data parallel: currently not supported with model parallel + + minibatch_size = min(minibatch_size_per_device * torch.cuda.device_count(), batch_size) + if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model, output_device=0) - minibatch_size = min(minibatch_size_per_device * torch.cuda.device_count(), batch_size) print( "Using", torch.cuda.device_count(),