Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored compute forward to not pass in the src tensors each time #729

Merged
merged 8 commits into from
Feb 21, 2024
Prev Previous commit
Next Next commit
fix merge issues with flags
  • Loading branch information
siddharthvader committed Feb 17, 2024
commit f44f91c7931244775a37916ddb4e41e256a26ff2
29 changes: 20 additions & 9 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -2607,7 +2607,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.nb =*/ { 0, 0, 0, 0 },
/*.op =*/ GGML_OP_NONE,
/*.op_params =*/ { 0 },
/*.is_param =*/ false,
/*.is_param =*/ 0,
/*.grad =*/ NULL,
/*.src =*/ { NULL },
/*.perf_runs =*/ 0,
Expand Down Expand Up @@ -6509,7 +6509,7 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor) {
tensor->is_param = true;
tensor->flags |= GGML_TENSOR_FLAG_PARAM;

GGML_ASSERT(tensor->grad == NULL);
tensor->grad = ggml_dup_tensor(ctx, tensor);
Expand Down Expand Up @@ -15580,7 +15580,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(
return NULL;
}

if (node->is_param) {
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
return node;
}

Expand Down Expand Up @@ -15614,7 +15614,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(

clone->op = node->op;
clone->grad = node->grad;
clone->is_param = node->is_param;
clone->flags = node->flags;
clone->extra = node->extra;
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
clone->nb[k] = node->nb[k];
Expand Down Expand Up @@ -16646,7 +16646,7 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];

if (node->is_param) {
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
ggml_build_forward_expand(gb, node->grad);
}
Expand Down Expand Up @@ -18131,7 +18131,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
i,
node->ne[0], node->ne[1], node->ne[2],
ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
ggml_op_name(node->op), node->flags & GGML_TENSOR_FLAG_PARAM ? "x" : node->grad ? "g" : " ", node->perf_runs,
(double) node->perf_cycles / (double) ggml_cycles_per_ms(),
(double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
(double) node->perf_time_us / 1000.0,
Expand Down Expand Up @@ -18224,7 +18224,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
continue;
}

if (node->is_param) {
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
snprintf(color, sizeof(color), "yellow");
} else if (node->grad) {
if (ggml_graph_find(gf, node)) {
Expand Down Expand Up @@ -18398,7 +18398,7 @@ static enum ggml_opt_result ggml_opt_adam(
int np = 0;
int64_t nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) {
if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);

GGML_ASSERT(np < GGML_MAX_PARAMS);
Expand Down Expand Up @@ -18761,7 +18761,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
int np = 0;
int nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) {
if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);

GGML_ASSERT(np < GGML_MAX_PARAMS);
Expand Down Expand Up @@ -19236,6 +19236,17 @@ enum ggml_opt_result ggml_opt_resume_g(

////////////////////////////////////////////////////////////////////////////////

// Mark a tensor as a graph input by raising the INPUT bit in its flags.
// Other flag bits already present on the tensor are left untouched.
void ggml_set_input(struct ggml_tensor * tensor) {
    tensor->flags = tensor->flags | GGML_TENSOR_FLAG_INPUT;
}

// Mark a tensor as a graph output by raising the OUTPUT bit in its flags.
// Other flag bits already present on the tensor are left untouched.
void ggml_set_output(struct ggml_tensor * tensor) {
    tensor->flags = tensor->flags | GGML_TENSOR_FLAG_OUTPUT;
}

////////////////////////////////////////////////////////////////////////////////


void ggml_quantize_init(enum ggml_type type) {
ggml_critical_section_start();

Expand Down