Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*!
* Copyright (c) 2015 by Contributors
* \file static_graph.h
* \brief the static graph of symbols
*/
#ifndef MXNET_STATIC_GRAPH_H_
#define MXNET_STATIC_GRAPH_H_

#include <vector>
#include <unordered_map>
#include <string>
#include <memory>
#include "./atomic_symbol.h"
namespace mxnet {

struct StaticGraph {
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
};
std::unordered_map<std::string, int> name_id_map;
std::vector<StaticNode> nodes;
std::vector<std::vector<int> > output_index;
std::vector<std::vector<int> > connected_graph;
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
} else {
id = name_id_map[name];
}
return id;
}
};
}
#endif
37 changes: 32 additions & 5 deletions include/mxnet/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
#include "./base.h"
#include "./tensor_blob.h"
#include "./operator.h"
#include "./static_graph.h"

namespace mxnet {
class CompositeOperator;
/*!
* \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol
* should support expressions and often passed by value. While AtomicSymbol have many subclasses,
Expand All @@ -27,7 +29,7 @@ namespace mxnet {
* A atomic symbol can be seen as a special case of the composite symbol with only the head node.
*/
class Symbol {
protected:
public:
/*!
* \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol
* with input symbols.
Expand All @@ -46,12 +48,19 @@ class Symbol {
/*!
* \brief constructor
*/
explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "");
explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") :
sym_(sym), name_(name) {
}
/*!
* \brief destructor
*/
~Node();
~Node() {
if (sym_) {
delete sym_;
}
}
};
protected:
/*! \brief the head node of the Symbol, it could be shared in many graphs */
std::shared_ptr<Node> head_;
/*! \brief if the head has multiple return values, index is used to specify */
Expand All @@ -60,7 +69,13 @@ class Symbol {
std::shared_ptr<std::vector<std::pair<Node*, int> > > arg_users_;
/*! \brief find arg users */
void FindArgUsers();

/**
* @brief Recursively parse the symbol to equivalent static graph.
*
* @param node The current node in dfs
* @param graph The static graph
*/
void Dfs(const std::shared_ptr<Node> node, StaticGraph& graph);
public:
/*!
* \brief declare virtual destructor in case it is subclassed.
Expand All @@ -71,7 +86,14 @@ class Symbol {
* \param ctx context of the operator
* \return returns the pointer to a created operator. It is on the user to delete.
*/
virtual Operator* Bind(Context ctx) const { return nullptr; }
virtual CompositeOperator* Bind(Context ctx) const { return nullptr; }
/**
* @brief Bind the symbol to a composite operator
*
* @param in A map denotes name and corresponding NArray for binding
* @return The composite operator
*/
virtual CompositeOperator* Bind(Context ctx, const std::unordered_map<std::string, NArray>& in);
/*!
* \brief copy the symbol
* \return a deep copy of the graph
Expand All @@ -98,6 +120,11 @@ class Symbol {
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
*/
virtual std::vector<std::string> ListArgs();
/**
* @brief Convert current symbol to its equivalent static graph representation.
* @return the static graph
*/
virtual StaticGraph ToStaticGraph();
/*!
* \brief create Symbol by wrapping AtomicSymbol
*/
Expand Down
40 changes: 28 additions & 12 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,11 @@
#include <dmlc/logging.h>
#include <mxnet/symbol.h>
#include <mxnet/registry.h>
#include <mxnet/static_graph.h>
#include <iterator>

namespace mxnet {

Symbol::Node::Node(AtomicSymbol* sym, const std::string& name)
: sym_(sym), name_(name) {
}

Symbol::Node::~Node() {
if (sym_) {
delete sym_;
}
}

void Symbol::FindArgUsers() {
arg_users_.reset(new std::vector<std::pair<Node*, int> >);
// depth first traversing
Expand Down Expand Up @@ -144,14 +135,14 @@ Symbol Symbol::Create(AtomicSymbol *atomic_symbol) {
std::vector<std::string> args = atomic_symbol->DescribeArguments();
std::vector<std::string> rets = atomic_symbol->DescribeReturns();
// set head_
s.head_ = std::make_shared<Symbol::Node>(atomic_symbol, "");
s.head_ = std::make_shared<Node>(atomic_symbol, "");
// set index_
s.index_ = rets.size() > 1 ? -1 : 0;
// set head_->in_index_
s.head_->in_index_ = std::vector<int>(args.size(), 0);
// set head_->in_symbol_
for (auto name : args) {
s.head_->in_symbol_.push_back(std::make_shared<Symbol::Node>(nullptr, name));
s.head_->in_symbol_.push_back(std::make_shared<Node>(nullptr, name));
}
// set head_->out_shape_
s.head_->out_shape_ = std::vector<TShape>(rets.size());
Expand All @@ -169,4 +160,29 @@ Symbol Symbol::Create(const std::string& type_name,
return Create(atomic_symbol);
}

StaticGraph Symbol::ToStaticGraph() {
StaticGraph graph;
Dfs(this->head_, graph);
return graph;
}

CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map<std::string, NArray>& in) {
StaticGraph graph = this->ToStaticGraph();
return NULL;
//TODO: pass the graph and in to initlialize a composite op.
}

void Symbol::Dfs(const std::shared_ptr<Node> node, StaticGraph& graph) {
int id = graph.FindNodeByName(node->name_, node->sym_);
for (size_t i = 0; i < node->in_symbol_.size(); ++i) {
std::shared_ptr<Node> parent = node->in_symbol_[i];
int parent_id = graph.FindNodeByName(parent->name_, parent->sym_);
graph.connected_graph[parent_id].push_back(id);
graph.output_index[parent_id].push_back(node->in_index_[i]);
if (parent->sym_) {
Dfs(parent, graph);
}
}
}

} // namespace mxnet