Skip to content

Commit

Permalink
api: execution model: stream and primitive::execution
Browse files Browse the repository at this point in the history
  • Loading branch information
Fomenko, Evarist M committed Mar 7, 2019
1 parent daf2892 commit 783718e
Show file tree
Hide file tree
Showing 64 changed files with 415 additions and 705 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ include_directories(include)

add_subdirectory(src)
add_subdirectory(examples)
add_subdirectory(tests)
# add_subdirectory(tests)

# Cannot use CMAKE_INSTALL_DOCDIR since it uses PROJECT_NAME and not LIB_NAME
install(FILES LICENSE DESTINATION ${CMAKE_INSTALL_DATAROOTDIR}/doc/${LIB_NAME})
12 changes: 6 additions & 6 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ set_if(UNIX LIBM m)

register_exe(simple-net-c simple_net.c "test")
register_exe(simple-net-cpp simple_net.cpp "test")
register_exe(simple-training-net-c simple_training_net.c "test" ${LIBM})
register_exe(simple-training-net-cpp simple_training_net.cpp "test" ${LIBM})
register_exe(simple-net-int8-cpp simple_net_int8.cpp "test")
register_exe(simple-rnn-cpp simple_rnn.cpp "test")
register_exe(simple-rnn-int8-cpp simple_rnn_int8.cpp "test")
register_exe(simple-rnn-training-cpp simple_rnn_training.cpp "test")
# register_exe(simple-training-net-c simple_training_net.c "test" ${LIBM})
# register_exe(simple-training-net-cpp simple_training_net.cpp "test" ${LIBM})
# register_exe(simple-net-int8-cpp simple_net_int8.cpp "test")
# register_exe(simple-rnn-cpp simple_rnn.cpp "test")
# register_exe(simple-rnn-int8-cpp simple_rnn_int8.cpp "test")
# register_exe(simple-rnn-training-cpp simple_rnn_training.cpp "test")
6 changes: 3 additions & 3 deletions examples/simple_net.c
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ mkldnn_status_t simple_net() {
if (pool_reorder_dst) net[n++] = pool_reorder_dst;

mkldnn_stream_t stream;
CHECK(mkldnn_stream_create(&stream, mkldnn_eager));
CHECK(mkldnn_stream_submit(stream, n, net, NULL));
CHECK(mkldnn_stream_wait(stream, n, NULL));
CHECK(mkldnn_stream_create(&stream, engine, mkldnn_stream_kind_default));
for (uint32_t i = 0; i < n; ++i)
mkldnn_primitive_execute(net[i], stream);

/* clean-up */
CHECK(mkldnn_primitive_desc_destroy(conv_pd));
Expand Down
7 changes: 4 additions & 3 deletions examples/simple_net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,10 @@ void simple_net(int times = 100) {
net.push_back(reorder(fc8_dst_memory, user_dst_memory));
}

stream(stream::kind::eager).submit(net_weights).wait();
stream s(cpu_engine);
for (int j = 0; j < times; ++j) {
stream(stream::kind::eager).submit(net).wait();
for (const auto &p: net)
p.execute(s);
}
}

Expand All @@ -806,7 +807,7 @@ int main(int argc, char **argv) {
auto begin = chrono::duration_cast<chrono::milliseconds>(
chrono::steady_clock::now().time_since_epoch())
.count();
int times = 1000;
int times = 100;
simple_net(times);
auto end = chrono::duration_cast<chrono::milliseconds>(
chrono::steady_clock::now().time_since_epoch())
Expand Down
26 changes: 6 additions & 20 deletions include/mkldnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ mkldnn_status_t MKLDNN_API mkldnn_primitive_create(
const mkldnn_primitive_at_t *inputs,
const_mkldnn_primitive_t *outputs);

/** Executes a @p primitive using a @p stream. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_execute(
const_mkldnn_primitive_t primitive, mkldnn_stream_t stream);

/** Retrieves a reference to the @p primitive_desc descriptor of given @p
* primitive.
*
Expand Down Expand Up @@ -1711,27 +1715,9 @@ mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine);
/** @addtogroup c_api_stream Execution stream operations
* @{ */

/** Creates an execution @p stream of @p stream_kind. */
/** Creates an execution @p stream of @p stream_kind with @p engine. */
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream,
mkldnn_stream_kind_t stream_kind);

/** Submits @p primitives to an execution @p stream. The number of primitives
* is @p n. All or none of the primitives can be lazy. In case of an error,
* returns the offending @p error_primitive if it is not @c NULL. */
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream,
size_t n, mkldnn_primitive_t primitives[],
mkldnn_primitive_t *error_primitive);

/** Waits for all primitives in the execution @p stream to finish. Returns
* immediately if @p block is zero. In case of an error, returns
* the offending @p error_primitive if it is not @c NULL. */
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream,
int block, mkldnn_primitive_t *error_primitive);

/** Reruns all the primitives within the @p stream. In case of an error,
* returns the offending @p error_primitive if it is not @c NULL. */
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream,
mkldnn_primitive_t *error_primitive);
mkldnn_engine_t engine, mkldnn_stream_kind_t stream_kind);

/** Destroys an execution @p stream. */
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream);
Expand Down
69 changes: 12 additions & 57 deletions include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ class primitive: public handle<mkldnn_primitive_t> {
/// Returns the descriptor of the underlying C API primitive.
inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
// TODO: use the C++ API wrapper structure.

void execute(struct stream &astream) const;
};

inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
Expand Down Expand Up @@ -3295,77 +3297,30 @@ template <> struct handle_traits<mkldnn_stream_t> {
struct stream: public handle<mkldnn_stream_t> {
using handle::handle;

enum kind { any = mkldnn_stream_kind_t::mkldnn_any_stream,
eager = mkldnn_stream_kind_t::mkldnn_eager,
lazy = mkldnn_stream_kind_t::mkldnn_lazy };
enum kind { stream_kind_default = mkldnn_stream_kind_default, };

static mkldnn_stream_kind_t convert_to_c(kind akind) {
return static_cast<mkldnn_stream_kind_t>(akind);
}

/// Constructs a stream.
stream(kind akind) {
stream(const engine &aengine, kind akind = stream_kind_default) {
mkldnn_stream_t astream;
error::wrap_c_api(mkldnn_stream_create(&astream,
convert_to_c(akind)),
"could not create a stream");
error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(),
convert_to_c(akind)), "could not create a stream");
reset(astream);
}

/// Submits a vector of primitives to a stream for computations.
///
/// @param primitives The vector of primitives to submit.
/// @returns The stream.
stream &submit(std::vector<primitive> primitives) {
// TODO: find a proper way to convert vector<primitive> to
// vector<mkldnn_primitive_t>
if (primitives.size() == 0) return *this;
std::vector<mkldnn_primitive_t> c_api_primitives;
c_api_primitives.reserve(primitives.size());
auto convert_to_c = [](primitive p) { return p.get(); };
std::transform(primitives.begin(), primitives.end(),
std::back_inserter(c_api_primitives), convert_to_c);

mkldnn_primitive_t c_api_error_primitive;
error::wrap_c_api(
mkldnn_stream_submit(get(),
c_api_primitives.size(), &c_api_primitives[0],
&c_api_error_primitive),
"could not submit primitives to a stream",
&c_api_error_primitive);

return *this;
}

/// Waits for all computations submitted to the stream to complete.
///
/// @param block Specifies whether the operation should wait indefinitely or
/// return immediately.
/// @returns @c true if all computations completed.
/// @returns @c false if not all computations completed.
bool wait(bool block = true) {
mkldnn_primitive_t c_api_error_primitive;
mkldnn_status_t status = mkldnn_stream_wait(get(),
block, &c_api_error_primitive);
if (status != mkldnn_success
&& status != mkldnn_try_again)
error::wrap_c_api(status, "could not wait on a stream",
&c_api_error_primitive);
return (status == mkldnn_success);
}

stream &rerun() {
mkldnn_primitive_t c_api_error_primitive;
error::wrap_c_api(
mkldnn_stream_rerun(get(), &c_api_error_primitive),
"could not rerun a stream", &c_api_error_primitive);
return *this;
}
};

#undef REG_QUERY_MPD

/// @}

inline void primitive::execute(stream &astream) const {
error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get()),
"could not execute a primitive");
}

/// @} C++ API

} // namespace mkldnn
Expand Down
8 changes: 2 additions & 6 deletions include/mkldnn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1252,12 +1252,8 @@ typedef enum {

/** @brief Kinds of streams. */
typedef enum {
/** An unspecified engine. */
mkldnn_any_stream,
/** Eager stream. */
mkldnn_eager,
/** Lazy stream. */
mkldnn_lazy,
/** A default kind of stream. */
mkldnn_stream_kind_default,
} mkldnn_stream_kind_t;

/** @struct mkldnn_stream
Expand Down
4 changes: 1 addition & 3 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,7 @@ using primitive_at_t = mkldnn_primitive_at_t;

using stream_kind_t = mkldnn_stream_kind_t;
namespace stream_kind {
const stream_kind_t any_stream = mkldnn_any_stream;
const stream_kind_t eager = mkldnn_eager;
const stream_kind_t lazy = mkldnn_lazy;
const stream_kind_t stream_kind_default = mkldnn_stream_kind_default;
}
using stream_t = mkldnn_stream;

Expand Down
34 changes: 0 additions & 34 deletions src/common/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,6 @@ struct mkldnn_engine: public mkldnn::impl::c_compatible {
/** get kind of the current engine */
virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }

/** submits a primitive @p p for execution
*
* @param p (input)
* primitive to execute
* @param e (output)
* resulting event (to be passed to p->execute(e))
* @param prerequisites (input)
* vector of prerequisite events that must be finished before @p p is run
*
* @return
* status of the operation
*
* @remark @b Rational.
* Prerequisites are separated from input-resources. Though memory is a
* primitive, it becomes a singularity point in the sense of signaling
* that it is ready (either it should not have corresponding event or the
* event should be always returns it is ready). Let engine has pretty
* simple logic wrt primitive run. Also this approach allows to reduce
* the amount of prerequisites checks -- usually the \# of real
* dependencies < the \# of primitive inputs).
*
* @warning
* Engine does not track dependencies and their consistencies. Internal
* library code may easily submit a primitive with the same resulting and
* prerequisite event, obtaining dead-lock. Engine won't even try to
* prevent such a situation.
*
* @note
* if any of @p prerequisites is finished with @c event::error or @c
* event::aborted primitive @p p would not be executed, its event's @p e
* state is automatically set to @c event::aborted */
virtual mkldnn::impl::status_t submit(mkldnn::impl::primitive_t *p,
mkldnn::impl::event_t *e, event_vector &prerequisites) = 0;

/* implementation section */
virtual mkldnn::impl::status_t memory_primitive_desc_create(
mkldnn::impl::memory_pd_t **memory_pd,
Expand Down
46 changes: 45 additions & 1 deletion src/common/primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,35 @@
#include <assert.h>

#include "c_types_map.hpp"
#include "engine.hpp"
#include "memory_pd.hpp"
#include "primitive_desc.hpp"
#include "primitive.hpp"
#include "engine.hpp"
#include "type_helpers.hpp"
#include "stream.hpp"
#include "utils.hpp"

using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::primitive_kind;

namespace {
// XXX: this is a huge hammer. This disables all and any msan checks on
// primitives outputs.
//
// A proper approach would be an implementation-specific unpoisoning.
void unpoison_outputs(const primitive_t *p) {
if (p->engine()->kind() != engine_kind::cpu) return;
for(auto o: p->outputs()) {
assert(o->kind() == primitive_kind::memory);
void *p;
o->get_data_handle(&p);
size_t s = ((memory_pd_t *)o->pd())->get_size();
msan_unpoison(p, s);
}
}
}

status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
if (primitive_desc) delete primitive_desc;
return success;
Expand All @@ -53,6 +72,31 @@ status_t mkldnn_primitive_create(primitive_t **primitive,
return primitive_desc->create_primitive(primitive, inputs, outputs);
}

status_t mkldnn_primitive_execute(const primitive_t *primitive,
stream_t *stream) {
bool ok = true
&& !utils::any_null(primitive, stream)
&& primitive->engine() == stream->engine();
if (!ok) return invalid_arguments;

exec_ctx_t ctx(stream);
status_t status;

if (mkldnn_verbose()->level) {
double ms = get_msec();
status = primitive->execute(ctx);
ms = get_msec() - ms;
printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
fflush(0);
} else {
status = primitive->execute(ctx);
}

if (msan_enabled) unpoison_outputs(primitive);

return status;
}

status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
const primitive_desc_t **primitive_desc) {
if (utils::any_null(primitive, primitive_desc))
Expand Down
33 changes: 22 additions & 11 deletions src/common/primitive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@
#include "nstl.hpp"
#include "primitive_desc.hpp"

namespace mkldnn {
namespace impl {

/** Primitive execution context (helps passing stream, memories, and events. */
struct exec_ctx_t {
exec_ctx_t(const exec_ctx_t &) = default;
exec_ctx_t(exec_ctx_t &&) = default;

exec_ctx_t(stream_t *stream): stream_(stream) {}

stream_t *stream() const { return stream_; }

private:
stream_t *stream_;
};

}
}

/** \brief A pure virtual primitive class
*
* Primitive contains links to its inputs & outputs, though it does not track
Expand Down Expand Up @@ -69,17 +88,9 @@ struct mkldnn_primitive: public mkldnn::impl::c_compatible {
/** returns primitive's kind */
mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }

/** executes primitive with resulting event @p e
*
* @p e (output)
* a resulting event. It is primitive responsibility to set @p e state
* after actual computations are done
*
* @remark @b Rational.
* Suppose engine has a task pool and for some reasons submission failed.
* In this case primitive will set @p e's state to event::error
*/
virtual void execute(mkldnn::impl::event_t *e) const = 0;
/** executes primitive with execution context @p ctx */
virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
const = 0;

/** returns data handle. Applicable for memory primitives only. */
virtual mkldnn::impl::status_t get_data_handle(void **handle) const {
Expand Down
Loading

0 comments on commit 783718e

Please sign in to comment.