From 6875e7781e1dc79eed8f118a8621f9d4f4a76235 Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Tue, 22 Jan 2019 10:54:55 -0800 Subject: [PATCH] Remove IRFunction requirement from CollectConstants --- include/glow/Backends/BackendUtils.h | 2 ++ include/glow/Backends/CompiledFunction.h | 7 +++--- include/glow/Graph/Graph.h | 4 +-- lib/Backends/BackendUtils.cpp | 25 +++++++++++++------ lib/Backends/CPU/CPUBackend.cpp | 4 ++- lib/Backends/CPU/CPUDeviceManager.cpp | 2 +- lib/Backends/CPU/CPUFunction.cpp | 5 ++-- lib/Backends/CPU/CPUFunction.h | 5 ++-- lib/Backends/Interpreter/Interpreter.cpp | 3 ++- .../Interpreter/InterpreterFunction.cpp | 10 +++----- .../Interpreter/InterpreterFunction.h | 6 ++--- lib/Backends/OpenCL/OpenCL.cpp | 7 +++--- lib/Backends/OpenCL/OpenCL.h | 8 +++--- lib/Graph/Graph.cpp | 4 +-- tests/unittests/BackendTest.cpp | 3 +++ 15 files changed, 55 insertions(+), 40 deletions(-) diff --git a/include/glow/Backends/BackendUtils.h b/include/glow/Backends/BackendUtils.h index d7ac1919a5..4198bbd46b 100644 --- a/include/glow/Backends/BackendUtils.h +++ b/include/glow/Backends/BackendUtils.h @@ -66,6 +66,8 @@ class RuntimeBundle { /// given function \p F and and copies weights to their address as specified /// by offsets contained in symbolTable_. void collectConstants(const IRFunction *F); + void collectConstants(const Module *M); + RuntimeBundle() = default; RuntimeBundle(std::unordered_map &symbolTable, size_t constWeight, size_t mutableWeight, size_t activations) diff --git a/include/glow/Backends/CompiledFunction.h b/include/glow/Backends/CompiledFunction.h index 27483a03ba..e32aa42422 100644 --- a/include/glow/Backends/CompiledFunction.h +++ b/include/glow/Backends/CompiledFunction.h @@ -55,9 +55,10 @@ class CompiledFunction { virtual void tearDownRuns() { runsSetup_ = false; } /// Getter for the runtimeBundle. - const runtime::RuntimeBundle &getRuntimeBundle() const { - return runtimeBundle_; - } + runtime::RuntimeBundle &getRuntimeBundle() { return runtimeBundle_; } + + /// Collects constants for runtime. + virtual void collectConstants(Module *){}; protected: /// Flag to ensure setupRuns is only called once. diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h index ad7c9a1718..6f5ce8f414 100644 --- a/include/glow/Graph/Graph.h +++ b/include/glow/Graph/Graph.h @@ -112,7 +112,7 @@ class Module final { /// \returns a pointer to the first variable with the name \p name or nullptr /// if no node has this name. - Constant *getConstantByName(llvm::StringRef name); + Constant *getConstantByName(llvm::StringRef name) const; /// \returns the list of constants that the Module owns. ConstList &getConstants() { return constants_; } @@ -126,7 +126,7 @@ class Module final { /// \returns a pointer to the placeholder with the name \p name or /// nullptr if no placeholder has this name. - Placeholder *getPlaceholderByName(llvm::StringRef name); + Placeholder *getPlaceholderByName(llvm::StringRef name) const; /// @name High-level Variable builders. ///@{ diff --git a/lib/Backends/BackendUtils.cpp b/lib/Backends/BackendUtils.cpp index 077c21df69..159936f48f 100644 --- a/lib/Backends/BackendUtils.cpp +++ b/lib/Backends/BackendUtils.cpp @@ -21,7 +21,10 @@ using llvm::cast; using llvm::isa; void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) { + collectConstants(F->getGraph()->getParent()); +} +void glow::runtime::RuntimeBundle::collectConstants(const Module *M) { // At compile time condense constants to a single block of memory. // This allows the graph to go away after compile time. // If there are no constants return nullptr. @@ -29,16 +32,24 @@ void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) { constants_ = nullptr; return; } + constants_ = (uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment); - for (auto &v : F->getGraph()->getParent()->getConstants()) { - assert(isa(F->getWeightForNode(v))); - auto *w = cast(F->getWeightForNode(v)); - auto payload = v->getPayload().getUnsafePtr(); - auto numBytes = w->getSizeInBytes(); - auto addr = getValueOffset(v); + + for (const auto &symbol : symbolTable_) { + llvm::StringRef name = symbol.first; + const RuntimeSymbolInfo &info = symbol.second; + + Constant *c = M->getConstantByName(name); + if (!c) { + continue; + } + auto *payload = c->getPayload().getUnsafePtr(); + assert(info.size == c->getPayload().getSizeInBytes() && + "Mismatched constant size"); + // Copy weight to offset. - memcpy(constants_ + addr, payload, numBytes); + memcpy(constants_ + info.offset, payload, info.size); } } diff --git a/lib/Backends/CPU/CPUBackend.cpp b/lib/Backends/CPU/CPUBackend.cpp index db8bee6b80..41833465f7 100644 --- a/lib/Backends/CPU/CPUBackend.cpp +++ b/lib/Backends/CPU/CPUBackend.cpp @@ -100,7 +100,9 @@ CPUBackend::createIRGen(IRFunction *IR, std::unique_ptr CPUBackend::compileIR(std::unique_ptr IR) const { auto function = compileIRWithoutConstants(IR.get()); - static_cast(function.get())->collectConstants(IR.get()); + static_cast(function.get()) + ->getRuntimeBundle() + .collectConstants(IR.get()); return function; } diff --git a/lib/Backends/CPU/CPUDeviceManager.cpp b/lib/Backends/CPU/CPUDeviceManager.cpp index 37e33514c1..26b1dc5ff4 100644 --- a/lib/Backends/CPU/CPUDeviceManager.cpp +++ b/lib/Backends/CPU/CPUDeviceManager.cpp @@ -52,7 +52,7 @@ void CPUDeviceManager::addNetworkImpl(const Module *module, // Add to the function name lookup map. for (const auto &func : functions) { - // TODO: collect constants here when available. + func.second->getRuntimeBundle().collectConstants(module); functions_.emplace(func.first, func.second); } diff --git a/lib/Backends/CPU/CPUFunction.cpp b/lib/Backends/CPU/CPUFunction.cpp index ffc36f10f8..b6d37e4349 100644 --- a/lib/Backends/CPU/CPUFunction.cpp +++ b/lib/Backends/CPU/CPUFunction.cpp @@ -30,8 +30,8 @@ CPUFunction::~CPUFunction() { tearDownRuns(); } -void CPUFunction::collectConstants(IRFunction *F) { - runtimeBundle_.collectConstants(F); +void CPUFunction::collectConstants(Module *module) { + runtimeBundle_.collectConstants(module); } void CPUFunction::loadPlaceholders(Context *ctx, @@ -61,7 +61,6 @@ void CPUFunction::updatePlaceholders(Context *ctx, } void CPUFunction::execute(Context *ctx) { - /// Base address for Activations memory block. uint8_t *baseActivationsAddress{nullptr}; /// Base address for Mutable weights memory block, Inputs and Outputs. diff --git a/lib/Backends/CPU/CPUFunction.h b/lib/Backends/CPU/CPUFunction.h index d35277f39b..96e220fa23 100644 --- a/lib/Backends/CPU/CPUFunction.h +++ b/lib/Backends/CPU/CPUFunction.h @@ -33,13 +33,12 @@ class CPUFunction final : public CompiledFunction { CPUFunction(std::unique_ptr JIT, const runtime::RuntimeBundle &runtimeBundle); - /// Collects constants for runtime. - void collectConstants(IRFunction *F); - /// \name CompiledFunction interface ///@{ ~CPUFunction() override; void execute(Context *ctx) override; + + void collectConstants(Module *module) override; ///@} private: /// Load constant tensors from \p ctx into \p weightsAddress, as defined by diff --git a/lib/Backends/Interpreter/Interpreter.cpp b/lib/Backends/Interpreter/Interpreter.cpp index 0c08e41015..62af91c72c 100644 --- a/lib/Backends/Interpreter/Interpreter.cpp +++ b/lib/Backends/Interpreter/Interpreter.cpp @@ -37,9 +37,10 @@ Interpreter::compileWithoutConstants(Function *F) const { std::unique_ptr Interpreter::compileIR(std::unique_ptr IR) const { + auto *mod = IR->getGraph()->getParent(); auto function = compileIRWithoutConstants(std::move(IR)); auto IFunction = static_cast(function.get()); - IFunction->collectConstants(IFunction->getIR()); + IFunction->collectConstants(mod); return function; } diff --git a/lib/Backends/Interpreter/InterpreterFunction.cpp b/lib/Backends/Interpreter/InterpreterFunction.cpp index 6ab0f503bc..64b2f465db 100644 --- a/lib/Backends/Interpreter/InterpreterFunction.cpp +++ b/lib/Backends/Interpreter/InterpreterFunction.cpp @@ -22,7 +22,6 @@ #include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" using namespace glow; InterpreterFunction::InterpreterFunction(std::unique_ptr F, @@ -39,11 +38,11 @@ InterpreterFunction::~InterpreterFunction() { tearDownRuns(); } -void InterpreterFunction::collectConstants(IRFunction *F) { - runtimeBundle_.collectConstants(F); +void InterpreterFunction::collectConstants(Module *module) { + runtimeBundle_.collectConstants(module); if (constants_.empty()) { if (runtimeBundle_.getConstantWeightSize()) { - for (const auto &v : F_->getGraph()->getParent()->getConstants()) { + for (const auto &v : module->getConstants()) { auto symbolInfo = runtimeBundle_.getSymbolInfo(v); auto addr = runtimeBundle_.getConstants() + symbolInfo.offset; auto tensor = new Tensor(addr, &symbolInfo.type); @@ -54,9 +53,6 @@ void InterpreterFunction::collectConstants(IRFunction *F) { } void InterpreterFunction::execute(Context *ctx) { - if (constants_.empty()) { - collectConstants(F_.get()); - } BoundInterpreterFunction boundFunc(constants_); boundFunc.execute(F_.get(), ctx); } diff --git a/lib/Backends/Interpreter/InterpreterFunction.h b/lib/Backends/Interpreter/InterpreterFunction.h index 8ac6832e1a..b956bc3a25 100644 --- a/lib/Backends/Interpreter/InterpreterFunction.h +++ b/lib/Backends/Interpreter/InterpreterFunction.h @@ -57,11 +57,11 @@ class InterpreterFunction final : public CompiledFunction { ///@{ ~InterpreterFunction() override; - /// Collects constants for runtime. - void collectConstants(IRFunction *F); - void execute(Context *ctx) override; + /// Collects constants for runtime. + void collectConstants(Module *module) override; + /// Get reference to IR function. IRFunction *getIR() { return F_.get(); } ///@} diff --git a/lib/Backends/OpenCL/OpenCL.cpp b/lib/Backends/OpenCL/OpenCL.cpp index 6bfef00d11..d392261463 100644 --- a/lib/Backends/OpenCL/OpenCL.cpp +++ b/lib/Backends/OpenCL/OpenCL.cpp @@ -1527,14 +1527,15 @@ cl_mem OpenCLFunction::allocDeviceBuffer(uint64_t size) { void OpenCLFunction::freeDeviceBuffer(cl_mem buf) { clReleaseMemObject(buf); } -void OpenCLFunction::collectConstants(IRFunction *F) { - runtimeBundle_.collectConstants(F); +void OpenCLFunction::collectConstants(Module *module) { + runtimeBundle_.collectConstants(module); } std::unique_ptr OCLBackend::compileIR(std::unique_ptr IR) const { + auto *module = IR->getGraph()->getParent(); auto function = compileIRWithoutConstants(std::move(IR)); auto OCLFunction = static_cast(function.get()); - OCLFunction->collectConstants(OCLFunction->getIR()); + OCLFunction->collectConstants(module); return function; } diff --git a/lib/Backends/OpenCL/OpenCL.h b/lib/Backends/OpenCL/OpenCL.h index 83d5e31fd7..9e8758a51a 100644 --- a/lib/Backends/OpenCL/OpenCL.h +++ b/lib/Backends/OpenCL/OpenCL.h @@ -97,7 +97,6 @@ class OpenCLFunction final : public CompiledFunction { ~OpenCLFunction() override; void execute(Context *ctx) override; - ///@} /// Allocates on device buffer and copies Constant weights to device. void setupRuns() override; /// Per run setup, copies Inputs from \p ctx to on device memory. @@ -107,12 +106,13 @@ class OpenCLFunction final : public CompiledFunction { /// Final cleanup, currently an empty function in OpenCL. void tearDownRuns() override; + /// Collects constants for runtime. + void collectConstants(Module *module) override; + ///@} + /// Returns IR function pointer. IRFunction *getIR() { return F_.get(); } - /// Collects constants for runtime. - void collectConstants(IRFunction *F); - private: /// Copy the value from a device to a provided buffer. /// \returns number of copied bytes. diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp index d9f0c05c35..76ee8086ec 100644 --- a/lib/Graph/Graph.cpp +++ b/lib/Graph/Graph.cpp @@ -2368,7 +2368,7 @@ void Module::eraseConstant(ConstList::iterator I) { void Function::eraseNode(NodesList::iterator I) { nodes_.erase(I); } -Constant *Module::getConstantByName(llvm::StringRef name) { +Constant *Module::getConstantByName(llvm::StringRef name) const { for (auto *V : getConstants()) { if (V->getName() == name) return V; @@ -2376,7 +2376,7 @@ Constant *Module::getConstantByName(llvm::StringRef name) { return nullptr; } -Placeholder *Module::getPlaceholderByName(llvm::StringRef name) { +Placeholder *Module::getPlaceholderByName(llvm::StringRef name) const { for (auto *P : getPlaceholders()) { if (P->getName() == name) { return P; diff --git a/tests/unittests/BackendTest.cpp b/tests/unittests/BackendTest.cpp index cc4d0682e8..6d47983771 100644 --- a/tests/unittests/BackendTest.cpp +++ b/tests/unittests/BackendTest.cpp @@ -200,6 +200,9 @@ TEST_P(BackendTest, decoupleCodegenFromGraph) { auto *saveTensor = ctx.allocate(save->getPlaceholder()); EE_.compile(CompilationMode::Infer, F); + // Collect constants to fill out the RuntimeBundle. + EE_.getCompiledFunction().collectConstants(&mod); + // Erase all of the functions to ensure that the compiled code does not // depend on the graph. mod.eraseFunctions();