Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama 3.1 #38

Merged
merged 15 commits into from
Sep 3, 2024
Merged

Llama 3.1 #38

merged 15 commits into from
Sep 3, 2024

Conversation

laggui
Copy link
Member

@laggui laggui commented Aug 22, 2024

Ported the new Llama weights to our implementation along with their custom RoPE frequency scaling.

Made a couple additional changes to fix minor things along the way:

  • Changed the default llama3 model for the chat example to use 3.1
    • Added command line argument to select between the versions
  • Added multiple stop tokens condition to properly support end of turn + end of text for newer llama models
  • Fixed the max_seq_len to be properly propagated to (command argument value was not used before)
  • Added import feature flag to keep pytorch weights loading code optional

TODO:

  • Use Burn version 0.14

@laggui laggui marked this pull request as ready for review August 29, 2024 16:43
@laggui
Copy link
Member Author

laggui commented Aug 30, 2024

Some notes:

wgpu backend fails due to lack of memory (even with TinyLlama):

// WSL
Loading record...
thread 'main' panicked at /home/laggui/.cargo/registry/src/index.crates.io-6f17d22bba15001f/cubecl-runtime-0.2.0/src/memory_management/dynamic.rs:156:9:
No memory pool big enough to reserve 262144000 bytes.
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
// Windows
Loading record...
thread 'main' panicked at C:\Users\guila\.cargo\registry\src\index.crates.io-6f17d22bba15001f\wgpu-22.1.0\src\backend\wgpu_core.rs:3411:5:
wgpu error: Validation Error

Caused by:
  In Device::create_buffer
    Not enough memory left.


note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
error: process didn't exit successfully: `target\release\examples\chat.exe` (exit code: 101)

cuda backend currently uses f32 because f16 has some compilation errors:

Loading record...
Loaded in 1s
Processing prompt: How many helicopters can a human eat in one sitting?
thread 'main' panicked at /home/laggui/.cargo/registry/src/index.crates.io-6f17d22bba15001f/cubecl-cuda-0.2.0/src/compute/server.rs:237:17:
[Compilation Error]
    default_program(56): error: class "__half2" has no member "i_0"
      l_0_8.i_0 = __half(0.0);
            ^
    default_program(57): error: class "__half2" has no member "i_1"
      l_0_8.i_1 = __half(0.0);
            ^
    default_program(66): error: class "__half2" has no member "i_0"
      l_0_9.i_0 = __half(0.0);
            ^
    default_program(67): error: class "__half2" has no member "i_1"
      l_0_9.i_1 = __half(0.0);
            ^
    4 errors detected in the compilation of "default_program".
[Source]
#include <cuda_fp16.h>
typedef unsigned int uint;


extern "C" __global__ void kernel(
__half2 input_0[],__half2 input_1[],uint info[]
) {

    int3 absoluteIdx = make_int3(
        blockIdx.x * blockDim.x + threadIdx.x,
        blockIdx.y * blockDim.y + threadIdx.y,
        blockIdx.z * blockDim.z + threadIdx.z
    );

    uint idxGlobal = (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x;
uint rank = info[0];
uint rank_2 = rank * 2;
uint l_0_0;
uint l_0_1;
uint l_0_2;
uint l_0_3;
bool l_0_4;
uint l_0_5;
uint l_0_6;
uint l_0_7;
__half2 l_0_8;
__half2 l_0_9;
l_0_0 = idxGlobal;
l_0_1 = idxGlobal;
l_0_2 = idxGlobal;
l_0_3 = info[(2 * 2 * info[0]) + 1] / 2;
l_0_4 = l_0_0 >= l_0_3;
if (l_0_4) {
return;}
l_0_3 = l_0_0 * uint(2);
l_0_5 = uint(0);

for (uint l_1_0 = uint(0); l_1_0 < rank; ++l_1_0) {
l_0_6 = info[(0 * rank_2) + l_1_0 + 1];
l_0_6 = l_0_3 / l_0_6;
l_0_7 = info[(1 * rank_2) + rank + l_1_0 + 1];
l_0_6 = l_0_6 % l_0_7;
l_0_7 = info[(1 * rank_2) + l_1_0 + 1];
l_0_6 = l_0_6 * l_0_7;
l_0_5 = l_0_5 + l_0_6;
}
l_0_5 = l_0_5 / uint(2);
l_0_2 = l_0_5;
uint l_0_10;
bool l_0_11;
l_0_10 = info[(2 * 2 * info[0]) + 1] / 2;
l_0_11 = l_0_1 < l_0_10;
if (l_0_11) {
l_0_8 = input_0[l_0_1];
} else {
l_0_8.i_0 = __half(0.0);
l_0_8.i_1 = __half(0.0);
}
uint l_0_12;
bool l_0_13;
l_0_12 = info[(2 * 2 * info[0]) + 2] / 2;
l_0_13 = l_0_2 < l_0_12;
if (l_0_13) {
l_0_9 = input_1[l_0_2];
} else {
l_0_9.i_0 = __half(0.0);
l_0_9.i_1 = __half(0.0);
}
l_0_8 = l_0_8 * l_0_9;
uint l_0_14;
bool l_0_15;
l_0_14 = info[(2 * 2 * info[0]) + 1] / 2;
l_0_15 = l_0_0 < l_0_14;
if (l_0_15) {
input_0[l_0_0] = l_0_8;
}

}
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

@laggui laggui merged commit 877996b into main Sep 3, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants