Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/cifar10.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,8 @@ void testCIFAR10() {
result = bb.createReturnOp(*SM);
}

IP.optimize();
IP.optimize(OptimizationMode::kTrain);
IP.initVars();
IP.getModule().dump();
IP.getModule().dumpDAG();

// Report progress every this number of training iterations.
int reportRate = 256;
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ void testMNIST() {
result = bb.createReturnOp(*SM);
}

IP.optimize();
IP.optimize(OptimizationMode::kTrain);
IP.initVars();
IP.getModule().dumpDAG();

// Report progress every this number of training iterations.
constexpr int reportRate = 30;
Expand All @@ -110,6 +109,7 @@ void testMNIST() {
IP.train(reportRate, {A, selected}, {&imageInputs, &labelInputs});
}
std::cout << "Validating.\n";
IP.optimize(OptimizationMode::kInfer);

auto LIH = labelInputs.getHandle<size_t>();

Expand Down
7 changes: 7 additions & 0 deletions include/glow/IR/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ class Instruction : public Value {
}
}

/// \returns true if the @In arguments can share a buffer with the @Out
/// arguments. This happens when the read and write access for the buffer are
/// the same.
bool mayShareBuffers() const { return true; }

/// When printing the instruction this method prints the extra metadata.
std::string getExtraDesc() const { return ""; }

Expand All @@ -148,6 +153,8 @@ class Instruction : public Value {
void verify() const;

operator Value *() const { return getOperand(0).first; }

static bool mayShareBuffers(const Instruction *I);
};

class WeightVar;
Expand Down
17 changes: 16 additions & 1 deletion include/glow/IR/Instrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "glow/IR/IR.h"
#include "glow/IR/Type.h"
#include "glow/Support/Casting.h"

namespace glow {

Expand All @@ -23,13 +24,17 @@ class DeallocActivationInst : public Instruction {
public:
DeallocActivationInst(Value *src)
: Instruction(Kinded::Kind::DeallocActivationInstKind, src->getType(),
{{src, OperandKind::kIn}}) {}
{{src, OperandKind::kOut}}) {}

static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::DeallocActivationInstKind;
}

void verify() const;

AllocActivationInst *getAlloc() const {
return cast<AllocActivationInst>(getOperand(0).first);
}
};

class CopyInst : public Instruction {
Expand Down Expand Up @@ -66,6 +71,9 @@ class ConvolutionInst : public Instruction {
static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::ConvolutionInstKind;
}

bool mayShareBuffers() const { return false; }

std::string getExtraDesc() const;
Value *getDest() const { return getOperand(0).first; }
Value *getSrc() const { return getOperand(1).first; }
Expand Down Expand Up @@ -107,6 +115,8 @@ class PoolInst : public Instruction {
static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::PoolInstKind;
}

bool mayShareBuffers() const { return false; }
std::string getExtraDesc() const;
Value *getDest() const { return getOperand(0).first; }
Value *getSrc() const { return getOperand(1).first; }
Expand Down Expand Up @@ -135,6 +145,8 @@ class FullyConnectedInst : public Instruction {
static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::FullyConnectedInstKind;
}

bool mayShareBuffers() const { return false; }
std::string getExtraDesc() const;
Value *getDest() const { return getOperand(0).first; }
Value *getSrc() const { return getOperand(1).first; }
Expand Down Expand Up @@ -234,6 +246,8 @@ class TransposeInst : public Instruction {
static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::TransposeInstKind;
}

bool mayShareBuffers() const { return false; }
std::string getExtraDesc() const;
Value *getDest() const { return getOperand(0).first; }
Value *getSrc() const { return getOperand(1).first; }
Expand Down Expand Up @@ -280,6 +294,7 @@ class ConcatInst : public Instruction {
static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::ConcatInstKind;
}
bool mayShareBuffers() const { return false; }
std::string getExtraDesc() const;
Value *getDest() const { return getOperand(0).first; }
Value *getSrc() const { return getOperand(1).first; }
Expand Down
7 changes: 5 additions & 2 deletions include/glow/IR/UseDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ template <typename UserTy, typename UseTy> class UseDef {
void addUse(Use U) { users_.push_back(U); }

/// \returns True if the value has some users.
bool hasUsers() const { return users_.size(); }
bool hasUsers() const { return !users_.empty(); }

/// \returns the number of users that the value has.
unsigned getNumUsers() const { return users_.size(); }

/// Returns true if the user \p I is in the list.
bool hasUser(const UserTy *I) const {
for (auto &U : users_) {
for (const auto &U : users_) {
if (U.second == I)
return true;
}
Expand Down
3 changes: 2 additions & 1 deletion include/glow/Interpreter/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "glow/IR/IRBuilder.h"
#include "glow/Network/Tensor.h"
#include "glow/Network/Train.h"
#include "glow/Optimizer/Optimizer.h"

#include <unordered_map>

Expand Down Expand Up @@ -32,7 +33,7 @@ class Interpreter final {
Module &getModule() { return M_; }

/// Run the target-independent optimizations on the module.
void optimize();
void optimize(OptimizationMode mode);

/// Ctor.
Interpreter();
Expand Down
8 changes: 7 additions & 1 deletion include/glow/Optimizer/Optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ namespace glow {

class Module;

void optimize(Module &M);
/// Controls how aggressively the IR optimizer may transform a module.
enum class OptimizationMode {
  kNone,  // Don't optimize the module.
  kTrain, // Optimize the module, but only with transforms that keep it
          // trainable.
  kInfer, // Optimize the module for inference; training is no longer
          // possible afterwards.
};

void optimize(Module &M, OptimizationMode mode);

} // namespace glow

Expand Down
17 changes: 14 additions & 3 deletions src/glow/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,29 @@ void Module::verify() const {

/// \returns the extra, kind-specific description string for \p K by
/// dispatching to the concrete class's getExtraDesc(). The dispatch is
/// generated from the instruction/value list in Instrs.def because
/// getExtraDesc() is not virtual.
static std::string getExtraDesc(const Kinded *K) {
#define DEF_INSTR(CLASS, NAME)                                                 \
  if (const auto *X = dyn_cast<const CLASS>(K))                                \
    return X->getExtraDesc();
#define DEF_VALUE(CLASS, NAME)                                                 \
  if (const auto *X = dyn_cast<const CLASS>(K))                                \
    return X->getExtraDesc();
#include "glow/IR/Instrs.def"
#undef DEF_INSTR
#undef DEF_VALUE

  // Every kind must be covered by Instrs.def.
  glow_unreachable();
}

/// \returns true if instruction \p I allows its input and output operands to
/// share a single buffer. Dispatches to the concrete instruction class's
/// (non-virtual) mayShareBuffers() via the kind list in Instrs.def; the base
/// class default is true, and instructions such as convolution opt out.
bool Instruction::mayShareBuffers(const Instruction *I) {
#define DEF_INSTR(CLASS, NAME)                                                 \
  if (const auto *X = dyn_cast<const CLASS>(I))                                \
    return X->mayShareBuffers();
#define DEF_VALUE(CLASS, NAME)
#include "glow/IR/Instrs.def"
#undef DEF_INSTR
#undef DEF_VALUE
  // Every instruction kind must be listed in Instrs.def.
  glow_unreachable();
}

static std::string getDesc(const Value *v) {
std::string sb;
std::string name = v->getName();
Expand Down
3 changes: 2 additions & 1 deletion src/glow/Importer/Caffe2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,9 @@ caffe2ModelLoader::caffe2ModelLoader(const std::string &netDescFilename,
loadProtoFile(weightsDef, netWeightFilename);
loadWeights(weightsDef);
loadNetwork(networkDef);
builder_.deallocateActiveInstrs();

// Save the result of the last operator into a weight.
root_ = builder_.createReturnOp(root_);

builder_.deallocateActiveInstrs();
}
4 changes: 3 additions & 1 deletion src/glow/Interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ Interpreter::~Interpreter() {
}
}

void Interpreter::optimize() { ::glow::optimize(M_); }
void Interpreter::optimize(OptimizationMode mode) {
::glow::optimize(M_, mode);
}

void Interpreter::registerTensor(Value *v, Tensor *t) {
assert(t->getType().isEqual(v->getType()) &&
Expand Down
153 changes: 151 additions & 2 deletions src/glow/Optimizer/Optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "glow/Support/Casting.h"

#include <unordered_map>
#include <unordered_set>

using namespace glow;

Expand Down Expand Up @@ -46,10 +47,158 @@ static void hoistDealloc(Module &M) {
}
}

void glow::optimize(Module &M) {
/// Erase activation allocations (and their matching deallocations) that no
/// instruction actually reads or writes.
static void deleteDeadAllocs(Module &M) {
  auto &instrs = M.getInstrs();

  // An alloc with fewer than two users is referenced only by its own
  // dealloc, so the pair is dead. Drop the deallocs first, while the
  // allocs they refer to are still present in the instruction list.
  auto deadDealloc = [](const Instruction *I) -> bool {
    const auto *DA = dyn_cast<const DeallocActivationInst>(I);
    return DA && DA->getAlloc()->getNumUsers() < 2;
  };
  instrs.erase(std::remove_if(instrs.begin(), instrs.end(), deadDealloc),
               instrs.end());

  // Now drop the dead allocs themselves.
  auto deadAlloc = [](const Instruction *I) -> bool {
    return isa<const AllocActivationInst>(I) && I->getNumUsers() < 2;
  };
  instrs.erase(std::remove_if(instrs.begin(), instrs.end(), deadAlloc),
               instrs.end());
}

// Redirect every user of \p val to \p with, except dealloc instructions,
// which are left untouched to preserve the well-formedness of the IR
// (each alloc keeps its matching dealloc).
static void replaceAllNonDeallocUsersWith(Value *val, Value *with) {
  assert(val != with && "Replacing value with self");
  // Snapshot the use list first: rewriting an operand mutates the use
  // list and would invalidate live iterators.
  auto &users = val->getUsers();
  std::vector<Value::Use> snapshot(users.begin(), users.end());
  for (auto &use : snapshot) {
    if (!isa<DeallocActivationInst>(use.second)) {
      use.second->setOperand(use.first, with);
    }
  }
}

/// Try to merge two buffers used by the instruction \p I into one, based on
/// the liveness information in \p liveBuffers. Performs at most one merge
/// per call (note the early return below).
static void shareBuffersForInstr(const std::unordered_set<Value *> &liveBuffers,
                                 Instruction *I) {
// At this point <out> variables are marked as dead, and <in> variables have
// not been marked alive yet, so "not in liveBuffers" means the buffer's
// current content is not needed by any later instruction.
for (unsigned first = 0, e = I->getNumOperands(); first < e; first++) {
for (unsigned second = first + 1; second < e; second++) {
auto destOp = I->getOperand(first);
auto srcOp = I->getOperand(second);

// Operands must be different, but of the same type, for their storage
// to be interchangeable.
if (destOp.first->getType() != srcOp.first->getType() ||
destOp.first == srcOp.first) {
continue;
}

// If both the src and the dest operands are dead, this means that we can
// reuse the buffer storage.
// NOTE(review): only liveness and type gate the merge here — the operand
// kinds (In/Out) are not checked; confirm that's intended.
if (!liveBuffers.count(destOp.first) && !liveBuffers.count(srcOp.first)) {
replaceAllNonDeallocUsersWith(destOp.first, srcOp.first);
// Merge at most one pair per instruction: the rewrite above changes the
// use lists this scan depends on.
return;
}
}
}
}

/// Merge activation buffers whose live ranges do not overlap. Walks the
/// instruction list backwards maintaining the set of live buffers, and asks
/// each instruction that allows it (Instruction::mayShareBuffers) to merge a
/// pair of dead operand buffers.
static void shareBuffers(Module &M) {
  auto &instrs = M.getInstrs();
  std::unordered_set<Value *> liveBuffers;

  // All the weights are alive, because they are persistent.
  for (auto *W : M.getWeights()) {
    liveBuffers.insert(W);
  }

  // For each instruction, in reverse order.
  for (auto it = instrs.rbegin(), e = instrs.rend(); it != e; ++it) {
    Instruction *I = *it;

    // Remove <out> dependencies from the live set, because this instruction
    // writes into them.
    for (unsigned op = 0, ope = I->getNumOperands(); op < ope; op++) {
      auto O = I->getOperand(op);
      auto *ai = dyn_cast<AllocActivationInst>(O.first);
      if (!ai) {
        continue;
      }

      // An <Out> dependency means that the buffer is being killed (walking
      // backwards, the write is the start of the live range). Remove it from
      // the live set.
      if (O.second == OperandKind::kOut) {
        auto li = liveBuffers.find(ai);
        if (li != liveBuffers.end()) {
          liveBuffers.erase(li);
        }
        continue;
      }
      // An <InOut> operand consumes the value of the buffer, which means the
      // buffer is alive above this instruction. Add it to the live set.
      if (O.second == OperandKind::kInOut) {
        liveBuffers.insert(ai);
      }
    }

    // At this point the <out> buffers are dead and the <in> buffers have not
    // been marked alive yet — the state shareBuffersForInstr expects.
    if (Instruction::mayShareBuffers(I)) {
      shareBuffersForInstr(liveBuffers, I);
    }

    // Insert the input buffers into the live set.
    for (unsigned op = 0, ope = I->getNumOperands(); op < ope; op++) {
      auto O = I->getOperand(op);
      auto *ai = dyn_cast<AllocActivationInst>(O.first);
      if (!ai) {
        continue;
      }

      // Anything that is read (<In> or <InOut>) is alive above this
      // instruction. Add it to the live set.
      if (O.second != OperandKind::kOut) {
        liveBuffers.insert(ai);
      }
    }
  }
}

/// Run the target-independent optimizations on module \p M, as permitted by
/// \p mode. The module is re-verified after every pass.
void glow::optimize(Module &M, OptimizationMode mode) {
M.verify();

if (mode == OptimizationMode::kNone) {
return;
}

// Sharing buffers is only legal in inference mode, because reusing a buffer
// destroys intermediate values that the backprop pass would need.
if (mode == OptimizationMode::kInfer) {
shareBuffers(M);
M.verify();
}

// Remove unused allocations.
deleteDeadAllocs(M);
M.verify();

// Shorten the lifetime of buffers.
hoistDealloc(M);
M.verify();
}
Loading