Skip to content

Commit

Permalink
llama: Don't double count the sampling time (ggerganov#2107)
Browse files Browse the repository at this point in the history
  • Loading branch information
howard0su committed Jul 5, 2023
1 parent 9e4475f commit 051c70d
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1905,10 +1905,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
return;
}

const int64_t t_start_sample_us = ggml_time_us();

llama_sample_softmax(ctx, candidates);

const int64_t t_start_sample_us = ggml_time_us();

// Compute the cumulative probabilities
float cum_sum = 0.0f;
size_t last_idx = candidates->size;
Expand Down Expand Up @@ -1937,9 +1937,8 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
return;
}

const int64_t t_start_sample_us = ggml_time_us();

llama_sample_softmax(nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();

// Compute the first and second derivatives
std::vector<float> first_derivatives(candidates->size - 1);
Expand Down Expand Up @@ -1991,11 +1990,11 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
return;
}

const int64_t t_start_sample_us = ggml_time_us();

// Compute the softmax of logits and calculate entropy
llama_sample_softmax(nullptr, candidates);

const int64_t t_start_sample_us = ggml_time_us();

float entropy = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
entropy += -candidates->data[i].p * logf(candidates->data[i].p);
Expand Down Expand Up @@ -2164,13 +2163,11 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_

if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
ctx->n_sample++;
}
return X;
}

llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
assert(ctx);
int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us();

Expand All @@ -2185,13 +2182,14 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
candidates->size = 1;
}

if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}

// Normalize the probabilities of the remaining words
llama_sample_softmax(ctx, candidates);

// Sample the next word X from the remaining words
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
llama_token X = llama_sample_token(ctx, candidates);
t_start_sample_us = ggml_time_us();

Expand Down

0 comments on commit 051c70d

Please sign in to comment.