Skip to content

Commit

Permalink
Merge pull request #6 from anshuman23/dev
Browse files Browse the repository at this point in the history
Returning list of op names in get_graph_ops and extended error atoms to all TF error codes
  • Loading branch information
anshuman23 authored May 19, 2018
2 parents a22de5f + 9adf180 commit 2398996
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 23 deletions.
83 changes: 75 additions & 8 deletions c_src/Tensorflex.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
#include <stdio.h>
#include <stdlib.h>

#define BASE_STRING_LENGTH 255

void free_buffer(void* data, size_t length) {
free(data);
}



ErlNifResourceType *graph_resource, *op_desc_resource, *tensor_resource, *session_resource, *op_resource, *buffer_resource, *status_resource, *graph_opts_resource;

void graph_destr(ErlNifEnv *env, void *res) {
Expand Down Expand Up @@ -75,7 +75,6 @@ static ERL_NIF_TERM string_constant(ErlNifEnv *env, int argc, const ERL_NIF_TERM
return enif_make_string(env, buf, ERL_NIF_LATIN1);
}


static ERL_NIF_TERM new_import_graph_def_opts(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
TF_ImportGraphDefOptions **graph_opts_resource_alloc = enif_alloc_resource(graph_opts_resource, sizeof(TF_ImportGraphDefOptions *));
Expand Down Expand Up @@ -119,6 +118,46 @@ static ERL_NIF_TERM new_op(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
return op_desc;
}

static ERL_NIF_TERM error_to_atom(ErlNifEnv *env, TF_Status* status)
{
switch(TF_GetCode(status))
{
case TF_CANCELLED: return enif_make_atom(env,"cancelled");
break;
case TF_UNKNOWN: return enif_make_atom(env,"unknown");
break;
case TF_INVALID_ARGUMENT: return enif_make_atom(env,"invalid_argument");
break;
case TF_DEADLINE_EXCEEDED: return enif_make_atom(env,"deadline_exceeded");
break;
case TF_NOT_FOUND: return enif_make_atom(env,"not_found");
break;
case TF_ALREADY_EXISTS: return enif_make_atom(env, "already_exists");
break;
case TF_PERMISSION_DENIED: return enif_make_atom(env,"permission_denied");
break;
case TF_UNAUTHENTICATED: return enif_make_atom(env,"unauthenticated");
break;
case TF_RESOURCE_EXHAUSTED: return enif_make_atom(env,"resource_exhausted");
break;
case TF_FAILED_PRECONDITION: return enif_make_atom(env,"failed_precondition");
break;
case TF_ABORTED: return enif_make_atom(env,"aborted");
break;
case TF_OUT_OF_RANGE: return enif_make_atom(env,"out_of_range");
break;
case TF_UNIMPLEMENTED:return enif_make_atom(env,"unimplemented");
break;
case TF_INTERNAL: return enif_make_atom(env,"internal");
break;
case TF_UNAVAILABLE: return enif_make_atom(env,"unavailable");
break;
case TF_DATA_LOSS: return enif_make_atom(env,"data_loss");
break;
default: return enif_make_atom(env,"unlisted_code");
}
}

static ERL_NIF_TERM read_graph(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
ErlNifBinary filepath;
Expand Down Expand Up @@ -148,7 +187,7 @@ static ERL_NIF_TERM read_graph(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv

TF_GraphImportGraphDef(graph, buf, graph_opts, status);
if (TF_GetCode(status) != TF_OK) {
return enif_make_tuple2(env,enif_make_atom(env,"error"),enif_make_string(env, "Unable to import graph", ERL_NIF_LATIN1));
return enif_make_tuple2(env, enif_make_atom(env,"error"), error_to_atom(env,status));
}
else {
fprintf(stderr, "Successfully imported graph\n");
Expand All @@ -162,6 +201,37 @@ static ERL_NIF_TERM read_graph(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv

}

static ERL_NIF_TERM get_graph_ops(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
TF_Graph **graph;
enif_get_resource(env, argv[0], graph_resource, (void *) &graph);

int n_ops = 0;
size_t pos = 0;
TF_Operation *op_count;
while ((op_count = TF_GraphNextOperation(*graph, &pos)) != NULL) {
n_ops++;
}

ERL_NIF_TERM *op_list;
ERL_NIF_TERM op_list_eterm;
TF_Operation *op_temp;
ErlNifBinary erl_str;
op_list = malloc(sizeof(ERL_NIF_TERM)*n_ops);
pos = 0;

for(int i=0; i<n_ops; i++) {
op_temp = TF_GraphNextOperation(*graph, &pos);
enif_alloc_binary(strlen((char*) TF_OperationName(op_temp)), &erl_str);
memcpy(erl_str.data, (char*) TF_OperationName(op_temp), strlen((char*) TF_OperationName(op_temp)));
op_list[i] = enif_make_binary(env, &erl_str);
}

op_list_eterm = enif_make_list_from_array(env, op_list, n_ops);
free(op_list);
return op_list_eterm;
}

static ERL_NIF_TERM create_and_run_sess(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
TF_Graph **graph;
Expand Down Expand Up @@ -207,11 +277,8 @@ static ERL_NIF_TERM create_and_run_sess(ErlNifEnv *env, int argc, const ERL_NIF_
static ErlNifFunc nif_funcs[] =
{
{ "version", 0, version },
{ "new_graph", 0, new_graph },
{ "new_op", 3, new_op },
{ "read_graph", 1, read_graph },
{ "string_constant", 1, string_constant },
{ "create_and_run_sess", 3, create_and_run_sess }
{ "get_graph_ops", 1, get_graph_ops },
};

ERL_NIF_INIT(Elixir.Tensorflex, nif_funcs, res_loader, NULL, NULL, NULL)
Expand Down
18 changes: 3 additions & 15 deletions lib/tensorflex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,12 @@ defmodule Tensorflex do
raise "NIF tf_version/0 not implemented"
end

def new_graph do
raise "NIF tf_new_graph/0 not implemented"
end

def new_op(_graph, _op, _label) do
raise "NIF tf_new_op/3 not implemented"
end

def read_graph(_filepath) do
raise "NIF read_graph/1 not implemented"
end

def string_constant(_value) do
raise "NIF tf_string_constant/1 not implemented"
def get_graph_ops(_graph) do
raise "NIF get_graph_ops/1 not implemented"
end

def create_and_run_sess(_graph, _opdesc, _tensor) do
raise "NIF tf_create_and_run_sess/3 not implemented"
end


end

0 comments on commit 2398996

Please sign in to comment.