From b4febff76d00c798c9a9689c89fe9393e349a74b Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 13 Sep 2015 18:05:48 -0700 Subject: [PATCH] add more detailed comments on threaded engine, move dispatcher logic outside the threadedvar complete --- src/engine/engine_impl.h | 1 + src/engine/stream_manager.h | 9 +- src/engine/threaded_engine.cc | 101 ++++++++++++---------- src/engine/threaded_engine.h | 155 ++++++++++++++++++++++++++-------- 4 files changed, 182 insertions(+), 84 deletions(-) diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h index cc5ab2e47d6a..e4c350656097 100644 --- a/src/engine/engine_impl.h +++ b/src/engine/engine_impl.h @@ -13,6 +13,7 @@ namespace mxnet { namespace engine { + /*! \brief base class of engine variables, used for type checking */ struct Var { #if ENGINE_DEBUG diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 76f18d29ec05..7b2382d60df7 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -56,11 +56,7 @@ RunContext StreamManager::GetRunContext( auto&& counter = gpu_cnt_.at(ctx.dev_id); if (counter == -1) { for (auto&& i : gpu_streams_.at(ctx.dev_id)) { - #if MXNET_USE_CUDNN == 1 - i = mshadow::NewStream(true, true); - #else - i = mshadow::NewStream(true, false); - #endif // MXNET_USE_CUDNN + i = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); } counter = 0; } @@ -88,7 +84,7 @@ RunContext StreamManager::GetIORunContext( { std::lock_guard lock{m_}; if (gpu_io_streams_.at(ctx.dev_id) == nullptr) { - gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(true, false); + gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(false, false); } } return {gpu_io_streams_.at(ctx.dev_id)}; @@ -126,7 +122,6 @@ StreamManager::~StreamManager() { } } // namespace engine - } // namespace mxnet #endif // MXNET_ENGINE_STREAM_MANAGER_H_ diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index b4da2f4b06a9..3b8d8cd0d32e 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -50,9 +50,7 @@ void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { head_->trigger = opr_block; head_->write = true; if (ready_to_read_) { - /*! - * Raise `num_pending_reads_` temporarily to avoid premature triggering. - */ + // Raise `num_pending_reads_` temporarily to avoid premature triggering. ++num_pending_reads_; pending_write_ = head_; if (--num_pending_reads_ == 0) { @@ -65,51 +63,74 @@ void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { template void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { - std::lock_guard lock{m_}; - if (--num_pending_reads_ == 0) { - if (pending_write_ != nullptr && --pending_write_->trigger->wait == 0) { - dispatcher(pending_write_->trigger); + bool trigger = false; + { + // this is lock scope + std::lock_guard lock{m_}; + if (--num_pending_reads_ == 0) { + if (pending_write_ != nullptr && --pending_write_->trigger->wait == 0) { + trigger = true; + } } } + if (trigger) { + dispatcher(pending_write_->trigger); + } } template bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { - std::lock_guard lock{m_}; - assert(ready_to_read_ == false); - auto cur_head = pending_write_->next; - VersionedVarBlock::Delete(pending_write_); - pending_write_ = nullptr; + VersionedVarBlock *old_pending_write, *end_of_dispatch_chain; + int num_reads; + { + // this is lock scope + std::lock_guard lock{m_}; + assert(ready_to_read_ == false); + // detach pending write + old_pending_write = pending_write_; + pending_write_ = nullptr; + // search for chains to trigger + VersionedVarBlock *p = old_pending_write->next; + assert(num_pending_reads_ == 0); + num_reads = 0; + while (p->next != nullptr && p->write == false) { + p = p->next; + ++num_reads; + } + num_pending_reads_ = num_reads; + end_of_dispatch_chain = p; + if (p->next == nullptr) { + ready_to_read_ = true; + } else { + assert(p->write == true); + pending_write_ = p; + } + } + // this is outside of lock scope + // the linked list is detached from variable + VersionedVarBlock *cur_head = old_pending_write->next; + VersionedVarBlock::Delete(old_pending_write); if (to_delete_) { assert(cur_head->next == nullptr); VersionedVarBlock::Delete(cur_head); return true; - } else { - while (true) { - if (cur_head->write == true) { - ++num_pending_reads_; - pending_write_ = cur_head; - if (--num_pending_reads_ == 0) { - if (--cur_head->trigger->wait == 0) { - dispatcher(cur_head->trigger); - } - } - break; - } else if (cur_head->next == nullptr) { - ready_to_read_ = true; - break; - } else { - ++num_pending_reads_; - if (--cur_head->trigger->wait == 0) { - dispatcher(cur_head->trigger); - } - auto prev = cur_head; - cur_head = cur_head->next; - VersionedVarBlock::Delete(prev); - } + } + // dispatch all the events + while (cur_head != end_of_dispatch_chain) { + if (--cur_head->trigger->wait == 0) { + dispatcher(cur_head->trigger); } - return false; + auto prev = cur_head; + cur_head = cur_head->next; + VersionedVarBlock::Delete(prev); } + // trigger pending write, if any + if (pending_write_ != nullptr && num_reads == 0) { + if (--pending_write_->trigger->wait == 0) { + dispatcher(pending_write_->trigger); + } + } + return false; } void ThreadedVar::SetToDelete() { @@ -122,14 +143,6 @@ bool ThreadedVar::ready_to_read() { return ready_to_read_; } -ThreadedVar* ThreadedVar::CastFromBase(Var* v) { - return v->Cast(); -} - -ThreadedOpr* ThreadedOpr::CastFromBase(Opr* o) { - return o->Cast(); -} - ThreadedEngine::ThreadedEngine() : pending_{0}, thread_pool_{[this]() { ThreadWorker(&task_queue_); }}, diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 9f3ae3f1c9ba..12557c6739eb 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -1,5 +1,8 @@ /*! * Copyright (c) 2015 by Contributors + * \file threaded_engine.h + * \brief Implementation of threaded engine that tracks the dependency + * and pushes actions to execute. */ #ifndef MXNET_ENGINE_THREADED_ENGINE_H_ #define MXNET_ENGINE_THREADED_ENGINE_H_ @@ -20,101 +23,186 @@ namespace mxnet { namespace engine { -/*! - * \brief Forward declarations. - */ +// Define helper macros for debug information. +#if ENGINE_DEBUG +#define DEFINE_ENGINE_DEBUG_INFO(Type) \ + static std::atomic counter; \ + Type() { LOG(INFO) << __func__ << " " << ++counter; } \ + ~Type() { LOG(INFO) << __func__ << " " << --counter; } +#else +#define DEFINE_ENGINE_DEBUG_INFO(Type) +#endif + +// Forward declarations struct ThreadedOpr; /*! - * \brief Operation in the queue. + * \brief Operation block in the scheduler. + * Each OprBlock corresponds to an operation pushed to the engine. */ struct OprBlock : public common::ObjectPoolAllocatable { -#if ENGINE_DEBUG - static std::atomic counter; - OprBlock() { LOG(INFO) << __func__ << " " << ++counter; } - ~OprBlock() { LOG(INFO) << __func__ << " " << --counter; } -#endif // ENGINE_DEBUG + /*! + * \brief wait number of pending tasks this OprBlock is waiting for. + */ std::atomic wait{0}; + /*! \brief Pointer to information on performing real operation */ ThreadedOpr* opr{nullptr}; + /*! \brief The context this operator */ Context ctx; + // define possible debug information + DEFINE_ENGINE_DEBUG_INFO(OprBlock); }; // struct OprBlock /*! - * \brief Variable with version information. + * \brief VersionedVarBlock that corresponding to a variable version. + * This is a basic unit of LinkedList in the ThreadedVar. */ struct VersionedVarBlock : public common::ObjectPoolAllocatable { -#if ENGINE_DEBUG - static std::atomic counter; - VersionedVarBlock() { LOG(INFO) << __func__ << " " << ++counter; } - ~VersionedVarBlock() { LOG(INFO) << __func__ << " " << --counter; } -#endif // ENGINE_DEBUG + /*! \brief next block in the LinkedList */ VersionedVarBlock* next{nullptr}; + /*! \brief the operation this block triggers */ OprBlock* trigger{nullptr}; + /*! \brief whether this operation is a write(mutate) operation. */ bool write{false}; + /*! \brief define possible debug information */ + DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock); }; // struct VersionedVarBlock /*! * \brief Variable implementation. + * Each ThreadedVar is a linked list(queue) of operations to be performed. */ class ThreadedVar final : public Var, public common::ObjectPoolAllocatable { public: -#if ENGINE_DEBUG - static std::atomic counter; - ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } -#endif // ENGINE_DEBUG + /*! + * \brief constructor + * \param head head block of the LinkedList, + * need to be initialized with next==nullptr and trigger=nullptr. + */ explicit ThreadedVar(VersionedVarBlock* head); + /*! + * \brief Schedule a read operation on this variable. + * If the opr_block can be runed right away, + * the wait counter of opr_block will be decreased. + * Otherwise, the opr_block will be added to waiting queue. + * \param opr_block The operation to be scheduled. + */ void AppendReadDependency(OprBlock* opr_block); + /*! + * \brief Schedule a write operation on this variable. + * If the opr_block can be runed right away, + * the wait counter of opr_block will be decreased. + * Otherwise, the opr_block will be added to waiting queue. + * \param opr_block The operation to be scheduled. + */ void AppendWriteDependency(OprBlock* opr_block); + /*! + * \brief A read operation is completed on this variable. + * This function may trigger subsequent waiting operations on this variable. + * + * \param dispatcher the function called to trigger the operation, + * when all of its dependencies are satiesfied. + * \tparam Dispatcher the function called to trigger an operation. + */ template void CompleteReadDependency(Dispatcher dispatcher); + /*! + * \brief A write operation is completed on this variable. + * This function may trigger subsequent waiting operations on this variable. + * + * \param dispatcher the function called to trigger the operation, + * when all of its dependencies are satiesfied. + * \tparam Dispatcher the function called to trigger an operation. + * \return to_delete, whether this Variable can be deleted after this functin. + */ template bool CompleteWriteDependency(Dispatcher dispatcher); + /*! \brief Mark this variable to be deleted. */ void SetToDelete(); + /*! \return whether this variable is ready to read. */ bool ready_to_read(); - - static ThreadedVar* CastFromBase(Var* ptr); + /*! + * \brief Cast a Var pointer to ThreadedVar pointer + * \param ptr pointer from base. + * \return a casted pointer. + */ + inline static ThreadedVar* CastFromBase(Var* ptr) { + return ptr->Cast(); + } + // code for debug. +#if ENGINE_DEBUG + static std::atomic counter; + ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } +#endif // ENGINE_DEBUG private: // TODO(hotpxl) change this to spinlock for faster runtime + // TODO(hotpxl) consider rename head + /*! \brief inetrnal mutex of the ThreadedVar */ std::mutex m_; + /*! \brief number of pending reads operation in the variable. */ std::size_t num_pending_reads_{0}; + /*! + * \brief Points to the last VersionedVarBlock in the queue. + * head_ always points to a empty VersionedVarBlock. + * So when we want to append an operation to the queue: + * 1) update head_->trigger to be new op + * 2) update head_->next to be a new VersionedVarBlock + * 3) move head to head->next. + */ VersionedVarBlock* head_{nullptr}; + /*! + * \brief The pointer to next write to perform. + * This pointer will only be updated when the write completes. + * This is actually the head(oldest operation) in the queue. + */ VersionedVarBlock* pending_write_{nullptr}; /*! - * If true, then there are no current or future processing of the chain. + * \brief If true, then there are no running or pending write on this variable. */ bool ready_to_read_{true}; /*! - * If true, delete after operation completes. + * \brief If true, delete after operation completes. */ bool to_delete_{false}; }; // struct ThreadedVar /*! - * \brief Operator implementation. + * \brief Operator used in ThreadedEngine. */ struct ThreadedOpr final : public Opr, public common::ObjectPoolAllocatable { -#if ENGINE_DEBUG - static std::atomic counter; - ThreadedOpr() { LOG(INFO) << __func__ << " " << ++counter; } - ~ThreadedOpr() { LOG(INFO) << __func__ << " " << --counter; } -#endif // ENGINE_DEBUG + /*! \brief The function to be invoked each time. */ Engine::AsyncFn fn; + /*! \brief The variable this operation will read from. */ std::vector const_vars; + /*! \brief The variable this operation will mutate. */ std::vector mutable_vars; + /*! \brief the property of the operator */ FnProperty prop; + /*! + * \brief Whether this is an temporary operator + * that can be deleted right after the operation completed. + */ bool temporary{false}; - - static ThreadedOpr* CastFromBase(Opr* ptr); + /*! + * \brief Cast a Opr pointer to ThreadedOpr pointer + * \param ptr pointer from base. + * \return a casted pointer. + */ + inline static ThreadedOpr* CastFromBase(Opr* ptr) { + return ptr->Cast(); + } + // define possible debug information + DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); }; // struct ThreadedOpr /*! * \brief Engine implementation. */ -class ThreadedEngine final : public Engine { +class ThreadedEngine : public Engine { public: /*! * \brief Constructor and destructor. @@ -125,7 +213,8 @@ class ThreadedEngine final : public Engine { * \brief Overriding methods. */ ThreadedVar* NewVariable() override; - ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& const_vars, + ThreadedOpr* NewOperator(AsyncFn fn, + std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop) override; void DeleteOperator(OprHandle op) override;