From 504f735529be0e22a5677f79a61090b5dcbe3c6f Mon Sep 17 00:00:00 2001 From: winsty Date: Sat, 1 Aug 2015 21:06:27 +0800 Subject: [PATCH 1/3] static graph --- include/mxnet/static_graph.h | 67 ++++++++++++++++++++++++++++++++++++ include/mxnet/symbol.h | 29 +++------------- src/symbol/symbol.cc | 35 ++++++++++++------- 3 files changed, 94 insertions(+), 37 deletions(-) create mode 100644 include/mxnet/static_graph.h diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h new file mode 100644 index 000000000000..a9566fb8cdff --- /dev/null +++ b/include/mxnet/static_graph.h @@ -0,0 +1,67 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file static_graph.h + * \brief the static graph of symbols + */ +#ifndef MXNET_STATIC_GRAPH_H_ +#define MXNET_STATIC_GRAPH_H_ + +#include +#include +#include +#include +#include "./atomic_symbol.h" +namespace mxnet { + /*! + * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol + * with input symbols. + */ + struct Node { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + /*! \brief inputs to this node */ + std::vector > in_symbol_; + /*! \brief index of the inputs if the inputs are tuple */ + std::vector in_index_; + /*! \brief the output shape of the wrapped symbol */ + std::vector out_shape_; + /*! + * \brief constructor + */ + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : + sym_(sym), name_(name) { + } + /*! + * \brief destructor + */ + ~Node() { + if (sym_) { + delete sym_; + } + } + }; + + struct StaticGraph { + std::unordered_map name_id_map; + std::vector > nodes; + std::vector > output_index; + std::vector > connected_graph; + + int FindNodeByName(const std::string& name, const std::shared_ptr node) { + int id = 0; + if (name_id_map.find(name) == name_id_map.end()) { + name_id_map[name] = name_id_map.size(); + nodes.push_back(node); + output_index.push_back(std::vector()); + connected_graph.push_back(std::vector()); + id = name_id_map.size(); + } else { + id = name_id_map[name]; + } + return id; + } + }; +} +#endif diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 0b69005f7a16..71c73f54a4c4 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -16,6 +16,7 @@ #include "./base.h" #include "./tensor_blob.h" #include "./operator.h" +#include "./static_graph.h" namespace mxnet { /*! @@ -28,30 +29,6 @@ namespace mxnet { */ class Symbol { protected: - /*! - * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol - * with input symbols. - */ - struct Node { - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - /*! \brief inputs to this node */ - std::vector > in_symbol_; - /*! \brief index of the inputs if the inputs are tuple */ - std::vector in_index_; - /*! \brief the output shape of the wrapped symbol */ - std::vector out_shape_; - /*! - * \brief constructor - */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = ""); - /*! - * \brief destructor - */ - ~Node(); - }; /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; /*! \brief if the head has multiple return values, index is used to specify */ @@ -60,7 +37,7 @@ class Symbol { std::shared_ptr > > arg_users_; /*! \brief find arg users */ void FindArgUsers(); - + void dfs_(const std::shared_ptr node, StaticGraph& graph); public: /*! * \brief declare virtual destructor in case it is subclassed. @@ -98,6 +75,8 @@ class Symbol { * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ virtual std::vector ListArgs(); + + virtual StaticGraph ToStaticGraph(); /*! * \brief create Symbol by wrapping AtomicSymbol */ diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index a4d966ba422f..c89a71995cab 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -6,20 +6,11 @@ #include #include #include +#include #include namespace mxnet { -Symbol::Node::Node(AtomicSymbol* sym, const std::string& name) - : sym_(sym), name_(name) { -} - -Symbol::Node::~Node() { - if (sym_) { - delete sym_; - } -} - void Symbol::FindArgUsers() { arg_users_.reset(new std::vector >); // depth first traversing @@ -144,14 +135,14 @@ Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { std::vector args = atomic_symbol->DescribeArguments(); std::vector rets = atomic_symbol->DescribeReturns(); // set head_ - s.head_ = std::make_shared(atomic_symbol, ""); + s.head_ = std::make_shared(atomic_symbol, ""); // set index_ s.index_ = rets.size() > 1 ? -1 : 0; // set head_->in_index_ s.head_->in_index_ = std::vector(args.size(), 0); // set head_->in_symbol_ for (auto name : args) { - s.head_->in_symbol_.push_back(std::make_shared(nullptr, name)); + s.head_->in_symbol_.push_back(std::make_shared(nullptr, name)); } // set head_->out_shape_ s.head_->out_shape_ = std::vector(rets.size()); @@ -169,4 +160,24 @@ Symbol Symbol::Create(const std::string& type_name, return Create(atomic_symbol); } +StaticGraph Symbol::ToStaticGraph() { + StaticGraph graph; + dfs_(this->head_, graph); + return graph; +} + + +void Symbol::dfs_(const std::shared_ptr node, StaticGraph& graph) { + int id = graph.FindNodeByName(node->name_, node); + for (size_t i = 0; i < node->in_symbol_.size(); ++i) { + std::shared_ptr parent = node->in_symbol_[i]; + int parent_id = graph.FindNodeByName(parent->name_, node); + graph.connected_graph[parent_id].push_back(id); + graph.output_index[parent_id].push_back(node->in_index_[i]); + if (parent->sym_) { + dfs_(parent, graph); + } + } +} + } // namespace mxnet From 8f320c4f8472efa7a2f8089279f994ca12077e88 Mon Sep 17 00:00:00 2001 From: winsty Date: Wed, 5 Aug 2015 22:39:17 +0800 Subject: [PATCH 2/3] simplify static graph --- include/mxnet/static_graph.h | 26 ++++++++++++++++---------- src/symbol/symbol.cc | 18 +++++++++--------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h index a9566fb8cdff..86db91546d32 100644 --- a/include/mxnet/static_graph.h +++ b/include/mxnet/static_graph.h @@ -12,15 +12,20 @@ #include #include "./atomic_symbol.h" namespace mxnet { + struct NodeMetaInfo{ + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + }; + /*! * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol * with input symbols. */ struct Node { - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; + + NodeMetaInfo info_; /*! \brief inputs to this node */ std::vector > in_symbol_; /*! \brief index of the inputs if the inputs are tuple */ @@ -30,22 +35,23 @@ namespace mxnet { /*! * \brief constructor */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : - sym_(sym), name_(name) { + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") { + info_.sym_ = sym; + info_.name_ = name; } /*! * \brief destructor */ ~Node() { - if (sym_) { - delete sym_; + if (info_.sym_) { + delete info_.sym_; } } }; struct StaticGraph { std::unordered_map name_id_map; - std::vector > nodes; + std::vector nodes; std::vector > output_index; std::vector > connected_graph; @@ -53,7 +59,7 @@ namespace mxnet { int id = 0; if (name_id_map.find(name) == name_id_map.end()) { name_id_map[name] = name_id_map.size(); - nodes.push_back(node); + nodes.push_back(node->info_); output_index.push_back(std::vector()); connected_graph.push_back(std::vector()); id = name_id_map.size(); diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index c89a71995cab..7c35804e7f38 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -22,7 +22,7 @@ void Symbol::FindArgUsers() { stk.pop_back(); } else { Node* next_level = back.first->in_symbol_[back.second].get(); - if (next_level->sym_) { + if (next_level->info_.sym_) { stk.push_back({next_level, 0}); } else { // back uses next_level which is a placeholder arg_users_->push_back({back.first, back.second}); @@ -42,10 +42,10 @@ Symbol Symbol::Copy() const { Node* back = stk.back(); stk.pop_back(); if (old_new.count(back) == 0) { - if (back->sym_) { - old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); + if (back->info_.sym_) { + old_new[back] = std::make_shared(back->info_.sym_->Copy(), back->info_.name_); } else { - old_new[back] = std::make_shared(nullptr, back->name_); + old_new[back] = std::make_shared(nullptr, back->info_.name_); } } for (const std::shared_ptr& n : back->in_symbol_) { @@ -98,7 +98,7 @@ Symbol Symbol::operator () (const std::unordered_map& kwarg << s.arg_users_->size() << " provided " << kwargs.size(); for (size_t i = 0; i < s.arg_users_->size(); ++i) { const std::pair& arg_user = (*s.arg_users_)[i]; - const std::string& name = arg_user.first->name_; + const std::string& name = arg_user.first->info_.name_; if (!(name == "") && kwargs.count(name) != 0) { const Symbol& bind = kwargs.at(name); arg_user.first->in_symbol_[arg_user.second] = bind.head_; @@ -125,7 +125,7 @@ std::vector Symbol::ListArgs() { } std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret), [&](const std::pair& n) -> std::string { - return n.first->in_symbol_[n.second]->name_; + return n.first->in_symbol_[n.second]->info_.name_; }); return ret; } @@ -168,13 +168,13 @@ StaticGraph Symbol::ToStaticGraph() { void Symbol::dfs_(const std::shared_ptr node, StaticGraph& graph) { - int id = graph.FindNodeByName(node->name_, node); + int id = graph.FindNodeByName(node->info_.name_, node); for (size_t i = 0; i < node->in_symbol_.size(); ++i) { std::shared_ptr parent = node->in_symbol_[i]; - int parent_id = graph.FindNodeByName(parent->name_, node); + int parent_id = graph.FindNodeByName(parent->info_.name_, node); graph.connected_graph[parent_id].push_back(id); graph.output_index[parent_id].push_back(node->in_index_[i]); - if (parent->sym_) { + if (parent->info_.sym_) { dfs_(parent, graph); } } From bccfbeb10ba786c3a4354dd6c7c5b6c22c5d1ed3 Mon Sep 17 00:00:00 2001 From: winsty Date: Fri, 7 Aug 2015 00:09:19 +0800 Subject: [PATCH 3/3] fix static graph --- include/mxnet/static_graph.h | 54 +++++++++--------------------------- include/mxnet/symbol.h | 54 ++++++++++++++++++++++++++++++++++-- src/symbol/symbol.cc | 29 +++++++++++-------- 3 files changed, 81 insertions(+), 56 deletions(-) diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h index 86db91546d32..747090695b99 100644 --- a/include/mxnet/static_graph.h +++ b/include/mxnet/static_graph.h @@ -12,54 +12,26 @@ #include #include "./atomic_symbol.h" namespace mxnet { - struct NodeMetaInfo{ - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - }; - - /*! - * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol - * with input symbols. - */ - struct Node { - - NodeMetaInfo info_; - /*! \brief inputs to this node */ - std::vector > in_symbol_; - /*! \brief index of the inputs if the inputs are tuple */ - std::vector in_index_; - /*! \brief the output shape of the wrapped symbol */ - std::vector out_shape_; - /*! - * \brief constructor - */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") { - info_.sym_ = sym; - info_.name_ = name; - } - /*! - * \brief destructor - */ - ~Node() { - if (info_.sym_) { - delete info_.sym_; - } - } - }; - + struct StaticGraph { + struct StaticNode { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + }; std::unordered_map name_id_map; - std::vector nodes; + std::vector nodes; std::vector > output_index; std::vector > connected_graph; - - int FindNodeByName(const std::string& name, const std::shared_ptr node) { + int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { int id = 0; if (name_id_map.find(name) == name_id_map.end()) { name_id_map[name] = name_id_map.size(); - nodes.push_back(node->info_); + StaticNode static_node; + static_node.sym_ = sym->Copy(); + static_node.name_ = name; + nodes.push_back(static_node); output_index.push_back(std::vector()); connected_graph.push_back(std::vector()); id = name_id_map.size(); diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 71c73f54a4c4..d72410809731 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -19,6 +19,7 @@ #include "./static_graph.h" namespace mxnet { +class CompositeOperator; /*! * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol * should support expressions and often passed by value. While AtomicSymbol have many subclasses, @@ -28,6 +29,37 @@ namespace mxnet { * A atomic symbol can be seen as a special case of the composite symbol with only the head node. */ class Symbol { + public: + /*! + * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol + * with input symbols. + */ + struct Node { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + /*! \brief inputs to this node */ + std::vector > in_symbol_; + /*! \brief index of the inputs if the inputs are tuple */ + std::vector in_index_; + /*! \brief the output shape of the wrapped symbol */ + std::vector out_shape_; + /*! + * \brief constructor + */ + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : + sym_(sym), name_(name) { + } + /*! + * \brief destructor + */ + ~Node() { + if (sym_) { + delete sym_; + } + } + }; protected: /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; @@ -37,7 +69,13 @@ class Symbol { std::shared_ptr > > arg_users_; /*! \brief find arg users */ void FindArgUsers(); - void dfs_(const std::shared_ptr node, StaticGraph& graph); + /** + * @brief Recursively parse the symbol to equivalent static graph. + * + * @param node The current node in dfs + * @param graph The static graph + */ + void Dfs(const std::shared_ptr node, StaticGraph& graph); public: /*! * \brief declare virtual destructor in case it is subclassed. @@ -48,7 +86,14 @@ class Symbol { * \param ctx context of the operator * \return returns the pointer to a created operator. It is on the user to delete. */ - virtual Operator* Bind(Context ctx) const { return nullptr; } + virtual CompositeOperator* Bind(Context ctx) const { return nullptr; } + /** + * @brief Bind the symbol to a composite operator + * + * @param in A map denotes name and corresponding NArray for binding + * @return The composite operator + */ + virtual CompositeOperator* Bind(Context ctx, const std::unordered_map& in); /*! * \brief copy the symbol * \return a deep copy of the graph @@ -75,7 +120,10 @@ class Symbol { * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ virtual std::vector ListArgs(); - + /** + * @brief Convert current symbol to its equivalent static graph representation. + * @return the static graph + */ virtual StaticGraph ToStaticGraph(); /*! * \brief create Symbol by wrapping AtomicSymbol diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 7c35804e7f38..2506b49af65f 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -22,7 +22,7 @@ void Symbol::FindArgUsers() { stk.pop_back(); } else { Node* next_level = back.first->in_symbol_[back.second].get(); - if (next_level->info_.sym_) { + if (next_level->sym_) { stk.push_back({next_level, 0}); } else { // back uses next_level which is a placeholder arg_users_->push_back({back.first, back.second}); @@ -42,10 +42,10 @@ Symbol Symbol::Copy() const { Node* back = stk.back(); stk.pop_back(); if (old_new.count(back) == 0) { - if (back->info_.sym_) { - old_new[back] = std::make_shared(back->info_.sym_->Copy(), back->info_.name_); + if (back->sym_) { + old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); } else { - old_new[back] = std::make_shared(nullptr, back->info_.name_); + old_new[back] = std::make_shared(nullptr, back->name_); } } for (const std::shared_ptr& n : back->in_symbol_) { @@ -98,7 +98,7 @@ Symbol Symbol::operator () (const std::unordered_map& kwarg << s.arg_users_->size() << " provided " << kwargs.size(); for (size_t i = 0; i < s.arg_users_->size(); ++i) { const std::pair& arg_user = (*s.arg_users_)[i]; - const std::string& name = arg_user.first->info_.name_; + const std::string& name = arg_user.first->name_; if (!(name == "") && kwargs.count(name) != 0) { const Symbol& bind = kwargs.at(name); arg_user.first->in_symbol_[arg_user.second] = bind.head_; @@ -125,7 +125,7 @@ std::vector Symbol::ListArgs() { } std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret), [&](const std::pair& n) -> std::string { - return n.first->in_symbol_[n.second]->info_.name_; + return n.first->in_symbol_[n.second]->name_; }); return ret; } @@ -162,20 +162,25 @@ Symbol Symbol::Create(const std::string& type_name, StaticGraph Symbol::ToStaticGraph() { StaticGraph graph; - dfs_(this->head_, graph); + Dfs(this->head_, graph); return graph; } +CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map& in) { + StaticGraph graph = this->ToStaticGraph(); + return NULL; + //TODO: pass the graph and in to initlialize a composite op. +} -void Symbol::dfs_(const std::shared_ptr node, StaticGraph& graph) { - int id = graph.FindNodeByName(node->info_.name_, node); +void Symbol::Dfs(const std::shared_ptr node, StaticGraph& graph) { + int id = graph.FindNodeByName(node->name_, node->sym_); for (size_t i = 0; i < node->in_symbol_.size(); ++i) { std::shared_ptr parent = node->in_symbol_[i]; - int parent_id = graph.FindNodeByName(parent->info_.name_, node); + int parent_id = graph.FindNodeByName(parent->name_, parent->sym_); graph.connected_graph[parent_id].push_back(id); graph.output_index[parent_id].push_back(node->in_index_[i]); - if (parent->info_.sym_) { - dfs_(parent, graph); + if (parent->sym_) { + Dfs(parent, graph); } } }