From d06ea78e2a151d422a6fb3cec05264f7f8582e71 Mon Sep 17 00:00:00 2001 From: Nadav Rotem Date: Mon, 2 Oct 2017 21:09:04 -0700 Subject: [PATCH 1/2] IR: optimize memory buffers by reusing tensor buffers of the same type. This brings a major memory use reduction. It brings the resnet50 model from 600Mb to 400Mb. This commit includes a few changes that had to go in together: 1. Change the optimizer interface to include the kind of optimization level: kNone, kTrain, kInfer. 2. Add a method to all of the instructions that describes if they can share the memory buffers. 3. A new optimization for sharing the memory buffers. 4. A new dead-buffer optimization. 5. Fix the C2 loader that did not construct the module in the right order. 6. Add a new method, UseDef::getNumUsers. --- examples/cifar10.cpp | 4 +- examples/mnist.cpp | 4 +- include/glow/IR/IR.h | 7 ++ include/glow/IR/Instrs.h | 17 ++- include/glow/IR/UseDef.h | 5 + include/glow/Interpreter/Interpreter.h | 3 +- include/glow/Optimizer/Optimizer.h | 8 +- src/glow/IR/IR.cpp | 13 ++- src/glow/Importer/Caffe2.cpp | 3 +- src/glow/Interpreter/Interpreter.cpp | 4 +- src/glow/Optimizer/Optimizer.cpp | 154 ++++++++++++++++++++++++- tests/unittests/IRGradCheck.cpp | 13 +-- tests/unittests/IRTest.cpp | 2 - tests/unittests/InterpreterTest.cpp | 15 ++- tools/loader/loader.cpp | 4 +- 15 files changed, 225 insertions(+), 31 deletions(-) diff --git a/examples/cifar10.cpp b/examples/cifar10.cpp index cba99c60b0..3018ac26fd 100644 --- a/examples/cifar10.cpp +++ b/examples/cifar10.cpp @@ -99,10 +99,8 @@ void testCIFAR10() { result = bb.createReturnOp(*SM); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); - IP.getModule().dump(); - IP.getModule().dumpDAG(); // Report progress every this number of training iterations. 
int reportRate = 256; diff --git a/examples/mnist.cpp b/examples/mnist.cpp index f36c8f4d69..02df1a5672 100644 --- a/examples/mnist.cpp +++ b/examples/mnist.cpp @@ -92,9 +92,8 @@ void testMNIST() { result = bb.createReturnOp(*SM); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); - IP.getModule().dumpDAG(); // Report progress every this number of training iterations. constexpr int reportRate = 30; @@ -110,6 +109,7 @@ void testMNIST() { IP.train(reportRate, {A, selected}, {&imageInputs, &labelInputs}); } std::cout << "Validating.\n"; + IP.optimize(OptimizationMode::kInfer); auto LIH = labelInputs.getHandle(); diff --git a/include/glow/IR/IR.h b/include/glow/IR/IR.h index 34f1c170fb..5867e00212 100644 --- a/include/glow/IR/IR.h +++ b/include/glow/IR/IR.h @@ -129,6 +129,11 @@ class Instruction : public Value { } } + /// \returns true if the @In arguments can share a buffer with the @Out + /// arguments. This happens when the read and write access for the buffer are + /// the same. + bool mayShareBuffers() const { return true; } + /// When printing the instruction this method prints the extra metadata. 
std::string getExtraDesc() const { return ""; } @@ -148,6 +153,8 @@ class Instruction : public Value { void verify() const; operator Value *() const { return getOperand(0).first; } + + static bool mayShareBuffers(const Instruction *I); }; class WeightVar; diff --git a/include/glow/IR/Instrs.h b/include/glow/IR/Instrs.h index 55bd5d48fa..01eefb8ee2 100644 --- a/include/glow/IR/Instrs.h +++ b/include/glow/IR/Instrs.h @@ -3,6 +3,7 @@ #include "glow/IR/IR.h" #include "glow/IR/Type.h" +#include "glow/Support/Casting.h" namespace glow { @@ -23,13 +24,17 @@ class DeallocActivationInst : public Instruction { public: DeallocActivationInst(Value *src) : Instruction(Kinded::Kind::DeallocActivationInstKind, src->getType(), - {{src, OperandKind::kIn}}) {} + {{src, OperandKind::kOut}}) {} static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::DeallocActivationInstKind; } void verify() const; + + AllocActivationInst *getAlloc() const { + return cast(getOperand(0).first); + } }; class CopyInst : public Instruction { @@ -66,6 +71,9 @@ class ConvolutionInst : public Instruction { static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::ConvolutionInstKind; } + + bool mayShareBuffers() const { return false; } + std::string getExtraDesc() const; Value *getDest() const { return getOperand(0).first; } Value *getSrc() const { return getOperand(1).first; } @@ -107,6 +115,8 @@ class PoolInst : public Instruction { static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::PoolInstKind; } + + bool mayShareBuffers() const { return false; } std::string getExtraDesc() const; Value *getDest() const { return getOperand(0).first; } Value *getSrc() const { return getOperand(1).first; } @@ -135,6 +145,8 @@ class FullyConnectedInst : public Instruction { static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::FullyConnectedInstKind; } + + bool mayShareBuffers() const { return false; } std::string getExtraDesc() const; Value 
*getDest() const { return getOperand(0).first; } Value *getSrc() const { return getOperand(1).first; } @@ -234,6 +246,8 @@ class TransposeInst : public Instruction { static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::TransposeInstKind; } + + bool mayShareBuffers() const { return false; } std::string getExtraDesc() const; Value *getDest() const { return getOperand(0).first; } Value *getSrc() const { return getOperand(1).first; } @@ -280,6 +294,7 @@ class ConcatInst : public Instruction { static bool classof(const Kinded *k) { return k->getKind() == Kinded::Kind::ConcatInstKind; } + bool mayShareBuffers() const { return false; } std::string getExtraDesc() const; Value *getDest() const { return getOperand(0).first; } Value *getSrc() const { return getOperand(1).first; } diff --git a/include/glow/IR/UseDef.h b/include/glow/IR/UseDef.h index 15af29e34b..63b2bad4dc 100644 --- a/include/glow/IR/UseDef.h +++ b/include/glow/IR/UseDef.h @@ -32,6 +32,11 @@ template class UseDef { /// \returns True if the value has some users. bool hasUsers() const { return users_.size(); } + /// \returns the number of users that the value has. + unsigned getNumUsers() const { + return std::distance(users_.begin(), users_.end()); + } + /// Returns true if the user \p I is in the list. bool hasUser(const UserTy *I) const { for (auto &U : users_) { diff --git a/include/glow/Interpreter/Interpreter.h b/include/glow/Interpreter/Interpreter.h index 2fdb657894..f3697aeeeb 100644 --- a/include/glow/Interpreter/Interpreter.h +++ b/include/glow/Interpreter/Interpreter.h @@ -5,6 +5,7 @@ #include "glow/IR/IRBuilder.h" #include "glow/Network/Tensor.h" #include "glow/Network/Train.h" +#include "glow/Optimizer/Optimizer.h" #include @@ -32,7 +33,7 @@ class Interpreter final { Module &getModule() { return M_; } /// Run the target-independent optimizations on the module. - void optimize(); + void optimize(OptimizationMode mode); /// Ctor. 
Interpreter(); diff --git a/include/glow/Optimizer/Optimizer.h b/include/glow/Optimizer/Optimizer.h index c6983152b4..ef5401d0f4 100644 --- a/include/glow/Optimizer/Optimizer.h +++ b/include/glow/Optimizer/Optimizer.h @@ -5,7 +5,13 @@ namespace glow { class Module; -void optimize(Module &M); +enum class OptimizationMode { + kNone, // Don't optimize the module. + kTrain, // Optimize the module but allow training. + kInfer, // Optimize the module and break training. +}; + +void optimize(Module &M, OptimizationMode mode); } // namespace glow diff --git a/src/glow/IR/IR.cpp b/src/glow/IR/IR.cpp index 4b54406d3c..803fc0d3a7 100644 --- a/src/glow/IR/IR.cpp +++ b/src/glow/IR/IR.cpp @@ -98,12 +98,23 @@ static std::string getExtraDesc(const Kinded *K) { if (auto *X = dyn_cast(K)) \ return X->getExtraDesc(); #include "glow/IR/Instrs.def" -#undef DEF_INSTRE +#undef DEF_INSTR #undef DEF_VALUE glow_unreachable(); } +bool Instruction::mayShareBuffers(const Instruction *I) { +#define DEF_INSTR(CLASS, NAME) \ + if (auto *X = dyn_cast(I)) \ + return X->mayShareBuffers(); +#define DEF_VALUE(CLASS, NAME) +#include "glow/IR/Instrs.def" +#undef DEF_INSTR +#undef DEF_VALUE + glow_unreachable(); +} + static std::string getDesc(const Value *v) { std::string sb; std::string name = v->getName(); diff --git a/src/glow/Importer/Caffe2.cpp b/src/glow/Importer/Caffe2.cpp index 62b23b6e5e..506fe06db7 100644 --- a/src/glow/Importer/Caffe2.cpp +++ b/src/glow/Importer/Caffe2.cpp @@ -405,8 +405,9 @@ caffe2ModelLoader::caffe2ModelLoader(const std::string &netDescFilename, loadProtoFile(weightsDef, netWeightFilename); loadWeights(weightsDef); loadNetwork(networkDef); - builder_.deallocateActiveInstrs(); // Save the result of the last operator into a weight. 
root_ = builder_.createReturnOp(root_); + + builder_.deallocateActiveInstrs(); } diff --git a/src/glow/Interpreter/Interpreter.cpp b/src/glow/Interpreter/Interpreter.cpp index b9357c235c..1a08bff1ab 100644 --- a/src/glow/Interpreter/Interpreter.cpp +++ b/src/glow/Interpreter/Interpreter.cpp @@ -19,7 +19,9 @@ Interpreter::~Interpreter() { } } -void Interpreter::optimize() { ::glow::optimize(M_); } +void Interpreter::optimize(OptimizationMode mode) { + ::glow::optimize(M_, mode); +} void Interpreter::registerTensor(Value *v, Tensor *t) { assert(t->getType().isEqual(v->getType()) && diff --git a/src/glow/Optimizer/Optimizer.cpp b/src/glow/Optimizer/Optimizer.cpp index 7f1566a2cc..6ce3f12615 100644 --- a/src/glow/Optimizer/Optimizer.cpp +++ b/src/glow/Optimizer/Optimizer.cpp @@ -4,6 +4,7 @@ #include "glow/Support/Casting.h" #include +#include using namespace glow; @@ -46,10 +47,159 @@ static void hoistDealloc(Module &M) { } } -void glow::optimize(Module &M) { +/// Delete alloc instructions that have no readers or writers. +static void deleteDeadAllocs(Module &M) { + auto &instrs = M.getInstrs(); + + // C++ does not have a clean way to erase reverse iterators, so we'll need to + // collect a set of instructions to delete. + std::unordered_set toDelete; + + // Collect pairs of alloc-dealloc to remove. + for (auto it = instrs.begin(), e = instrs.end(); it != e; ++it) { + if (auto *da = dyn_cast(*it)) { + auto *alloc = da->getAlloc(); + // Delete the dealloc, if this is the only user of the alloc. + if (alloc->getNumUsers() < 2) { + toDelete.insert(da); + continue; + } + } + // Erase allocs with no users. + if (auto *alloc = dyn_cast(*it)) { + if (alloc->getNumUsers() < 2) { + toDelete.insert(alloc); + continue; + } + } + } + + // Delete the instructions. 
+ for (auto it = instrs.begin(), e = instrs.end(); it != e; + /* nop */) { + if (toDelete.count(*it)) { + it = instrs.erase(it); + continue; + } + it++; + } +} + +// Replace all users of some value with another value, but don't touch the +// dealloc instruction, because we need to preserve the well formdness of the +// IR. +static void replaceAllNonDeallocUsersWith(Value *val, Value *with) { + assert(val != with && "Replacing value with self"); + auto &lst = val->getUsers(); + std::vector users(lst.begin(), lst.end()); + for (auto &U : users) { + // Ignore dealloc instrs. + if (isa(U.second)) { + continue; + } + + U.second->setOperand(U.first, with); + } +} + +/// Optimize the input/output buffer for the instruction \p I, based on the +/// liveness information in \p liveBuffers. +static void shareBuffersForInstr(const std::unordered_set &liveBuffers, + Instruction *I) { + // At this point variables are marked as dead, and variables have + // not been marked alive yet. + for (unsigned first = 0, e = I->getNumOperands(); first < e; first++) { + for (unsigned second = first + 1; second < e; second++) { + auto destOp = I->getOperand(first); + auto srcOp = I->getOperand(second); + + // Operands must be different, but of the same type. + if (destOp.first->getType() != srcOp.first->getType() || + destOp.first == srcOp.first) { + continue; + } + + // If both the src and the dest operands are dead, this means that we can + // reuse the buffer storage. + if (!liveBuffers.count(destOp.first) && !liveBuffers.count(srcOp.first)) { + replaceAllNonDeallocUsersWith(destOp.first, srcOp.first); + return; + } + } + } +} + +static void shareBuffers(Module &M) { + auto &instrs = M.getInstrs(); + std::unordered_set liveBuffers; + + // All the weights are alive, because they are persistent. + for (auto *W : M.getWeights()) { + liveBuffers.insert(W); + } + + // For each instruction, in reverse order. 
+ for (auto it = instrs.rbegin(), e = instrs.rend(); it != e; ++it) { + Instruction *I = *it; + + // Remove dependencies from the live set, because this instruction + // writes into them. + for (unsigned op = 0, ope = I->getNumOperands(); op < ope; op++) { + auto O = I->getOperand(op); + auto ai = dyn_cast(O.first); + + // dependency means that the buffer is being killed. Remove from the + // live list. + if (ai && O.second == OperandKind::kOut) { + auto it = liveBuffers.find(ai); + if (it != liveBuffers.end()) { + liveBuffers.erase(it); + } + continue; + } + // The means that the value of the buffer is being consumed, + // which means that it is alive. Add to the live set. + if (ai && O.second == OperandKind::kInOut) { + liveBuffers.insert(ai); + } + } + + if (Instruction::mayShareBuffers(I)) + shareBuffersForInstr(liveBuffers, I); + + // Insert the input buffers into the live set. + for (unsigned op = 0, ope = I->getNumOperands(); op < ope; op++) { + auto O = I->getOperand(op); + auto ai = dyn_cast(O.first); + + // The means that the value of the buffer is being consumed, + // which means that it is alive. Add to the live set. + if (ai && O.second != OperandKind::kOut) { + liveBuffers.insert(ai); + } + } + } +} + +void glow::optimize(Module &M, OptimizationMode mode) { M.verify(); - hoistDealloc(M); + if (mode == OptimizationMode::kNone) { + return; + } + + // Sharing buffers is not legal in training mode because it kills the + // backprop. + if (mode == OptimizationMode::kInfer) { + shareBuffers(M); + M.verify(); + } + // Remove unused allocations. + deleteDeadAllocs(M); + M.verify(); + + // Shorten the lifetime of buffers. 
+ hoistDealloc(M); M.verify(); } diff --git a/tests/unittests/IRGradCheck.cpp b/tests/unittests/IRGradCheck.cpp index 0707ffeed8..13ad1196a3 100644 --- a/tests/unittests/IRGradCheck.cpp +++ b/tests/unittests/IRGradCheck.cpp @@ -103,7 +103,7 @@ TEST(Network, gradientCheck_FC_Concat_RELU) { result = bb.createReturnOp(*O); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor inputs(ElemKind::FloatTy, {{1, numInputElem}}); @@ -142,7 +142,7 @@ TEST(Network, gradientCheck_Conv) { result = bb.createReturnOp(*O); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 1}); @@ -180,7 +180,6 @@ TEST(Network, gradientCheck_AvgPool) { } IP.getModule().verify(); - IP.getModule().dump(); IP.initVars(); Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 1}); @@ -217,7 +216,7 @@ TEST(Network, gradientCheck_batchNorm) { result = bb.createReturnOp(*O); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 3}); @@ -263,7 +262,7 @@ TEST(Network, gradientCheck_Arithmetic) { result = bb.createReturnOp(*O); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor iA(ElemKind::FloatTy, {1, numDim}); @@ -351,7 +350,7 @@ TEST(Network, gradientCheck_FC_Concat_Tanh) { result = bb.createReturnOp(*FA); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor inputs(ElemKind::FloatTy, {{1, numInputElem}}); @@ -386,7 +385,7 @@ TEST(Network, gradientCheck_Transpose) { result = bb.createReturnOp(*TA); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor inputs(ElemKind::FloatTy, {1, 5, 10, 15}); diff --git a/tests/unittests/IRTest.cpp b/tests/unittests/IRTest.cpp index 207a7a94bb..19532575d5 100644 --- a/tests/unittests/IRTest.cpp +++ b/tests/unittests/IRTest.cpp @@ -109,7 +109,6 @@ TEST(IR, allInstrs) { builder.createArithmeticInst(I1, I0, I0, 
ArithmeticInst::OpKind::kMul); } M.verify(); - M.dump(); } TEST(IR, highLevelBuilder) { @@ -136,7 +135,6 @@ TEST(IR, highLevelBuilder) { (void)rshp; (void)aa; } - M.dump(); M.verify(); } diff --git a/tests/unittests/InterpreterTest.cpp b/tests/unittests/InterpreterTest.cpp index 948e50c834..50f3692882 100644 --- a/tests/unittests/InterpreterTest.cpp +++ b/tests/unittests/InterpreterTest.cpp @@ -40,7 +40,7 @@ TEST(Interpreter, interpret) { builder.createReturnOp(*SM); } - IP.optimize(); + IP.optimize(OptimizationMode::kInfer); IP.initVars(); IP.infer({input}, {&inputs}); } @@ -75,7 +75,7 @@ TEST(Interpreter, trainASimpleNetwork) { inputs.getHandle() = {0.15, 0.15, 0.15, 0.15}; expected.getHandle() = {0.9, 0.9, 0.9, 0.9}; - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); // Train the network. Learn 1000 batches. @@ -83,6 +83,7 @@ TEST(Interpreter, trainASimpleNetwork) { // Testing the output vector. + IP.optimize(OptimizationMode::kInfer); IP.infer({A}, {&inputs}); auto RNWH = IP.getTensorForValue(result)->getHandle(); (void)RNWH; @@ -122,7 +123,7 @@ TEST(Interpreter, simpleRegression) { auto I = inputs.getHandle(); auto E = expected.getHandle(); - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); // Train the network: @@ -194,7 +195,7 @@ TEST(Interpreter, learnXor) { TL.at({i, 0}) = a ^ b; } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); // Train the network: @@ -277,7 +278,7 @@ TEST(Network, circle) { result = bb.createReturnOp(*SM); } - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); Tensor coordinates(ElemKind::FloatTy, {numSamples, 2}); @@ -377,12 +378,14 @@ TEST(Network, learnSingleValueConcat) { inputs.getHandle().clear(0.15); expected.getHandle().clear(0.9); - IP.optimize(); + IP.optimize(OptimizationMode::kTrain); IP.initVars(); // Train the network: IP.train(1000, {A, B, Ex}, {&inputs, &inputs, &expected}); + IP.optimize(OptimizationMode::kInfer); + // Testing the 
output vector. IP.infer({A}, {&inputs}); auto RNWH = IP.getTensorForValue(result)->getHandle(); diff --git a/tools/loader/loader.cpp b/tools/loader/loader.cpp index 9b69e2fcc8..44592d0331 100644 --- a/tools/loader/loader.cpp +++ b/tools/loader/loader.cpp @@ -78,10 +78,8 @@ int main(int argc, char **argv) { {"data", "gpu_0/data", "softmax_expected"}, {&data, &data, &expected_softmax}, IP); - IP.optimize(); + IP.optimize(OptimizationMode::kInfer); IP.initVars(); - IP.getModule().dump(); - IP.getModule().dumpDAG(); auto *SM = LD.getRoot(); Value *i0 = LD.getOrCreateNodeByName("gpu_0/data"); From c1a43b48eab8dfd877a7f032fb0923fa36f27bda Mon Sep 17 00:00:00 2001 From: Nadav Rotem Date: Tue, 3 Oct 2017 10:28:40 -0700 Subject: [PATCH 2/2] Address CR changes. --- include/glow/IR/UseDef.h | 8 ++-- src/glow/IR/IR.cpp | 6 +-- src/glow/Optimizer/Optimizer.cpp | 73 ++++++++++++++++---------------- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/include/glow/IR/UseDef.h b/include/glow/IR/UseDef.h index 63b2bad4dc..e21aeca081 100644 --- a/include/glow/IR/UseDef.h +++ b/include/glow/IR/UseDef.h @@ -30,16 +30,14 @@ template class UseDef { void addUse(Use U) { users_.push_back(U); } /// \returns True if the value has some users. - bool hasUsers() const { return users_.size(); } + bool hasUsers() const { return !users_.empty(); } /// \returns the number of users that the value has. - unsigned getNumUsers() const { - return std::distance(users_.begin(), users_.end()); - } + unsigned getNumUsers() const { return users_.size(); } /// Returns true if the user \p I is in the list. 
bool hasUser(const UserTy *I) const { - for (auto &U : users_) { + for (const auto &U : users_) { if (U.second == I) return true; } diff --git a/src/glow/IR/IR.cpp b/src/glow/IR/IR.cpp index 803fc0d3a7..5b97de7fbf 100644 --- a/src/glow/IR/IR.cpp +++ b/src/glow/IR/IR.cpp @@ -92,10 +92,10 @@ void Module::verify() const { static std::string getExtraDesc(const Kinded *K) { #define DEF_INSTR(CLASS, NAME) \ - if (auto *X = dyn_cast(K)) \ + if (const auto *X = dyn_cast(K)) \ return X->getExtraDesc(); #define DEF_VALUE(CLASS, NAME) \ - if (auto *X = dyn_cast(K)) \ + if (const auto *X = dyn_cast(K)) \ return X->getExtraDesc(); #include "glow/IR/Instrs.def" #undef DEF_INSTR @@ -106,7 +106,7 @@ static std::string getExtraDesc(const Kinded *K) { bool Instruction::mayShareBuffers(const Instruction *I) { #define DEF_INSTR(CLASS, NAME) \ - if (auto *X = dyn_cast(I)) \ + if (const auto *X = dyn_cast(I)) \ return X->mayShareBuffers(); #define DEF_VALUE(CLASS, NAME) #include "glow/IR/Instrs.def" diff --git a/src/glow/Optimizer/Optimizer.cpp b/src/glow/Optimizer/Optimizer.cpp index 6ce3f12615..0a2fdf5b73 100644 --- a/src/glow/Optimizer/Optimizer.cpp +++ b/src/glow/Optimizer/Optimizer.cpp @@ -51,38 +51,27 @@ static void hoistDealloc(Module &M) { static void deleteDeadAllocs(Module &M) { auto &instrs = M.getInstrs(); - // C++ does not have a clean way to erase reverse iterators, so we'll need to - // collect a set of instructions to delete. - std::unordered_set toDelete; - - // Collect pairs of alloc-dealloc to remove. - for (auto it = instrs.begin(), e = instrs.end(); it != e; ++it) { - if (auto *da = dyn_cast(*it)) { - auto *alloc = da->getAlloc(); - // Delete the dealloc, if this is the only user of the alloc. - if (alloc->getNumUsers() < 2) { - toDelete.insert(da); - continue; - } - } - // Erase allocs with no users. - if (auto *alloc = dyn_cast(*it)) { - if (alloc->getNumUsers() < 2) { - toDelete.insert(alloc); - continue; - } - } - } - - // Delete the instructions. 
- for (auto it = instrs.begin(), e = instrs.end(); it != e; - /* nop */) { - if (toDelete.count(*it)) { - it = instrs.erase(it); - continue; - } - it++; - } + // Remove all of the DeallocActivationInst that close unused allocs. + instrs.erase( + std::remove_if(instrs.begin(), instrs.end(), + [](const Instruction *I) -> bool { + if (const auto *DA = + dyn_cast(I)) { + return DA->getAlloc()->getNumUsers() < 2; + } + return false; + }), + instrs.end()); + + // Remove the unused allocs. + instrs.erase(std::remove_if(instrs.begin(), instrs.end(), + [](const Instruction *I) -> bool { + if (isa(I)) { + return I->getNumUsers() < 2; + } + return false; + }), + std::end(instrs)); } // Replace all users of some value with another value, but don't touch the @@ -90,9 +79,11 @@ static void deleteDeadAllocs(Module &M) { // IR. static void replaceAllNonDeallocUsersWith(Value *val, Value *with) { assert(val != with && "Replacing value with self"); - auto &lst = val->getUsers(); - std::vector users(lst.begin(), lst.end()); - for (auto &U : users) { + auto &users = val->getUsers(); + // We use a vector here because changing the operands of the user changes the + // uselist, and this invalidates the iterator. + std::vector usersVec(users.begin(), users.end()); + for (auto &U : usersVec) { // Ignore dealloc instrs. if (isa(U.second)) { continue; @@ -148,9 +139,13 @@ static void shareBuffers(Module &M) { auto O = I->getOperand(op); auto ai = dyn_cast(O.first); + if (!ai) { + continue; + } + // dependency means that the buffer is being killed. Remove from the // live list. - if (ai && O.second == OperandKind::kOut) { + if (O.second == OperandKind::kOut) { auto it = liveBuffers.find(ai); if (it != liveBuffers.end()) { liveBuffers.erase(it); @@ -172,9 +167,13 @@ static void shareBuffers(Module &M) { auto O = I->getOperand(op); auto ai = dyn_cast(O.first); + if (!ai) { + continue; + } + // The means that the value of the buffer is being consumed, // which means that it is alive. 
Add to the live set. - if (ai && O.second != OperandKind::kOut) { + if (O.second != OperandKind::kOut) { liveBuffers.insert(ai); } }