Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/glow/Backends/BackendUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class RuntimeBundle {
/// given function \p F and and copies weights to their address as specified
/// by offsets contained in symbolTable_.
void collectConstants(const IRFunction *F);
void collectConstants(const Module *M);

RuntimeBundle() = default;
RuntimeBundle(std::unordered_map<std::string, RuntimeSymbolInfo> &symbolTable,
size_t constWeight, size_t mutableWeight, size_t activations)
Expand Down
7 changes: 4 additions & 3 deletions include/glow/Backends/CompiledFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ class CompiledFunction {
virtual void tearDownRuns() { runsSetup_ = false; }

/// Getter for the runtimeBundle.
const runtime::RuntimeBundle &getRuntimeBundle() const {
return runtimeBundle_;
}
runtime::RuntimeBundle &getRuntimeBundle() { return runtimeBundle_; }

/// Collects constants for runtime.
virtual void collectConstants(Module *){};

protected:
/// Flag to ensure setupRuns is only called once.
Expand Down
4 changes: 2 additions & 2 deletions include/glow/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Module final {

/// \returns a pointer to the first variable with the name \p name or nullptr
/// if no node has this name.
Constant *getConstantByName(llvm::StringRef name);
Constant *getConstantByName(llvm::StringRef name) const;

/// \returns the list of constants that the Module owns.
ConstList &getConstants() { return constants_; }
Expand All @@ -126,7 +126,7 @@ class Module final {

/// \returns a pointer to the placeholder with the name \p name or
/// nullptr if no placeholder has this name.
Placeholder *getPlaceholderByName(llvm::StringRef name);
Placeholder *getPlaceholderByName(llvm::StringRef name) const;

/// @name High-level Variable builders.
///@{
Expand Down
25 changes: 18 additions & 7 deletions lib/Backends/BackendUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,35 @@ using llvm::cast;
using llvm::isa;

void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) {
collectConstants(F->getGraph()->getParent());
}

void glow::runtime::RuntimeBundle::collectConstants(const Module *M) {
// At compile time condense constants to a single block of memory.
// This allows the graph to go away after compile time.
// If there are no constants return nullptr.
if (constantWeightVarsMemSize_ == 0) {
constants_ = nullptr;
return;
}

constants_ =
(uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment);
for (auto &v : F->getGraph()->getParent()->getConstants()) {
assert(isa<WeightVar>(F->getWeightForNode(v)));
auto *w = cast<WeightVar>(F->getWeightForNode(v));
auto payload = v->getPayload().getUnsafePtr();
auto numBytes = w->getSizeInBytes();
auto addr = getValueOffset(v);

for (const auto &symbol : symbolTable_) {
llvm::StringRef name = symbol.first;
const RuntimeSymbolInfo &info = symbol.second;

Constant *c = M->getConstantByName(name);
if (!c) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the situation when module does not have that constant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a consequence of the RuntimeBundle overlap issue you commented on in #2274, the SymbolTable contains both Constants and Placeholders. We only want to copy the contents of the Constants though, obviously.

continue;
}
auto *payload = c->getPayload().getUnsafePtr();
assert(info.size == c->getPayload().getSizeInBytes() &&
"Mismatched constant size");

// Copy weight to offset.
memcpy(constants_ + addr, payload, numBytes);
memcpy(constants_ + info.offset, payload, info.size);
}
}

Expand Down
4 changes: 3 additions & 1 deletion lib/Backends/CPU/CPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ CPUBackend::createIRGen(IRFunction *IR,
std::unique_ptr<CompiledFunction>
CPUBackend::compileIR(std::unique_ptr<IRFunction> IR) const {
auto function = compileIRWithoutConstants(IR.get());
static_cast<CPUFunction *>(function.get())->collectConstants(IR.get());
static_cast<CPUFunction *>(function.get())
->getRuntimeBundle()
.collectConstants(IR.get());
return function;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/CPU/CPUDeviceManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void CPUDeviceManager::addNetworkImpl(const Module *module,

// Add to the function name lookup map.
for (const auto &func : functions) {
// TODO: collect constants here when available.
func.second->getRuntimeBundle().collectConstants(module);
functions_.emplace(func.first, func.second);
}

Expand Down
5 changes: 2 additions & 3 deletions lib/Backends/CPU/CPUFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ CPUFunction::~CPUFunction() {
tearDownRuns();
}

void CPUFunction::collectConstants(IRFunction *F) {
runtimeBundle_.collectConstants(F);
void CPUFunction::collectConstants(Module *module) {
runtimeBundle_.collectConstants(module);
}

void CPUFunction::loadPlaceholders(Context *ctx,
Expand Down Expand Up @@ -61,7 +61,6 @@ void CPUFunction::updatePlaceholders(Context *ctx,
}

void CPUFunction::execute(Context *ctx) {
/// Base address for Activations memory block.
uint8_t *baseActivationsAddress{nullptr};

/// Base address for Mutable weights memory block, Inputs and Outputs.
Expand Down
5 changes: 2 additions & 3 deletions lib/Backends/CPU/CPUFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ class CPUFunction final : public CompiledFunction {
CPUFunction(std::unique_ptr<llvm::orc::GlowJIT> JIT,
const runtime::RuntimeBundle &runtimeBundle);

/// Collects constants for runtime.
void collectConstants(IRFunction *F);

/// \name CompiledFunction interface
///@{
~CPUFunction() override;
void execute(Context *ctx) override;

void collectConstants(Module *module) override;
///@}
private:
/// Load constant tensors from \p ctx into \p weightsAddress, as defined by
Expand Down
3 changes: 2 additions & 1 deletion lib/Backends/Interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ Interpreter::compileWithoutConstants(Function *F) const {

std::unique_ptr<CompiledFunction>
Interpreter::compileIR(std::unique_ptr<IRFunction> IR) const {
auto *mod = IR->getGraph()->getParent();
auto function = compileIRWithoutConstants(std::move(IR));
auto IFunction = static_cast<InterpreterFunction *>(function.get());
IFunction->collectConstants(IFunction->getIR());
IFunction->collectConstants(mod);
return function;
}

Expand Down
10 changes: 3 additions & 7 deletions lib/Backends/Interpreter/InterpreterFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include "llvm/Support/Casting.h"

#include "llvm/Support/raw_ostream.h"
using namespace glow;

InterpreterFunction::InterpreterFunction(std::unique_ptr<IRFunction> F,
Expand All @@ -39,11 +38,11 @@ InterpreterFunction::~InterpreterFunction() {
tearDownRuns();
}

void InterpreterFunction::collectConstants(IRFunction *F) {
runtimeBundle_.collectConstants(F);
void InterpreterFunction::collectConstants(Module *module) {
runtimeBundle_.collectConstants(module);
if (constants_.empty()) {
if (runtimeBundle_.getConstantWeightSize()) {
for (const auto &v : F_->getGraph()->getParent()->getConstants()) {
for (const auto &v : module->getConstants()) {
auto symbolInfo = runtimeBundle_.getSymbolInfo(v);
auto addr = runtimeBundle_.getConstants() + symbolInfo.offset;
auto tensor = new Tensor(addr, &symbolInfo.type);
Expand All @@ -54,9 +53,6 @@ void InterpreterFunction::collectConstants(IRFunction *F) {
}

void InterpreterFunction::execute(Context *ctx) {
if (constants_.empty()) {
collectConstants(F_.get());
}
BoundInterpreterFunction boundFunc(constants_);
boundFunc.execute(F_.get(), ctx);
}
Expand Down
6 changes: 3 additions & 3 deletions lib/Backends/Interpreter/InterpreterFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ class InterpreterFunction final : public CompiledFunction {
///@{
~InterpreterFunction() override;

/// Collects constants for runtime.
void collectConstants(IRFunction *F);

void execute(Context *ctx) override;

/// Collects constants for runtime.
void collectConstants(Module *module) override;

/// Get reference to IR function.
IRFunction *getIR() { return F_.get(); }
///@}
Expand Down
7 changes: 4 additions & 3 deletions lib/Backends/OpenCL/OpenCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1527,14 +1527,15 @@ cl_mem OpenCLFunction::allocDeviceBuffer(uint64_t size) {

void OpenCLFunction::freeDeviceBuffer(cl_mem buf) { clReleaseMemObject(buf); }

void OpenCLFunction::collectConstants(IRFunction *F) {
runtimeBundle_.collectConstants(F);
void OpenCLFunction::collectConstants(Module *module) {
runtimeBundle_.collectConstants(module);
}
std::unique_ptr<CompiledFunction>
OCLBackend::compileIR(std::unique_ptr<IRFunction> IR) const {
auto *module = IR->getGraph()->getParent();
auto function = compileIRWithoutConstants(std::move(IR));
auto OCLFunction = static_cast<OpenCLFunction *>(function.get());
OCLFunction->collectConstants(OCLFunction->getIR());
OCLFunction->collectConstants(module);
return function;
}

Expand Down
8 changes: 4 additions & 4 deletions lib/Backends/OpenCL/OpenCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class OpenCLFunction final : public CompiledFunction {
~OpenCLFunction() override;

void execute(Context *ctx) override;
///@}
/// Allocates on device buffer and copies Constant weights to device.
void setupRuns() override;
/// Per run setup, copies Inputs from \p ctx to on device memory.
Expand All @@ -107,12 +106,13 @@ class OpenCLFunction final : public CompiledFunction {
/// Final cleanup, currently an empty function in OpenCL.
void tearDownRuns() override;

/// Collects constants for runtime.
void collectConstants(Module *module) override;
///@}

/// Returns IR function pointer.
IRFunction *getIR() { return F_.get(); }

/// Collects constants for runtime.
void collectConstants(IRFunction *F);

private:
/// Copy the value from a device to a provided buffer.
/// \returns number of copied bytes.
Expand Down
4 changes: 2 additions & 2 deletions lib/Graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2368,15 +2368,15 @@ void Module::eraseConstant(ConstList::iterator I) {

void Function::eraseNode(NodesList::iterator I) { nodes_.erase(I); }

Constant *Module::getConstantByName(llvm::StringRef name) {
Constant *Module::getConstantByName(llvm::StringRef name) const {
for (auto *V : getConstants()) {
if (V->getName() == name)
return V;
}
return nullptr;
}

Placeholder *Module::getPlaceholderByName(llvm::StringRef name) {
Placeholder *Module::getPlaceholderByName(llvm::StringRef name) const {
for (auto *P : getPlaceholders()) {
if (P->getName() == name) {
return P;
Expand Down
3 changes: 3 additions & 0 deletions tests/unittests/BackendTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ TEST_P(BackendTest, decoupleCodegenFromGraph) {
auto *saveTensor = ctx.allocate(save->getPlaceholder());
EE_.compile(CompilationMode::Infer, F);

// Collect constants to fill out the RuntimeBundle.
EE_.getCompiledFunction().collectConstants(&mod);

// Erase all of the functions to ensure that the compiled code does not
// depend on the graph.
mod.eraseFunctions();
Expand Down