-
Notifications
You must be signed in to change notification settings - Fork 699
[RFC] Refactor CPUFunction and InterpreterFunction to remove per-run state #2274
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,63 +30,55 @@ CPUFunction::~CPUFunction() { | |
tearDownRuns(); | ||
} | ||
|
||
void CPUFunction::setupRuns() { | ||
if (!runsSetup_) { | ||
if (runtimeBundle_.getActivationsSize() != 0) { | ||
baseActivationsAddress_ = (uint8_t *)alignedAlloc( | ||
runtimeBundle_.getActivationsSize(), TensorAlignment); | ||
} | ||
|
||
if (runtimeBundle_.getMutableWeightSize() != 0) { | ||
baseMutableWeightVarsAddress_ = (uint8_t *)alignedAlloc( | ||
runtimeBundle_.getMutableWeightSize(), TensorAlignment); | ||
} | ||
runsSetup_ = true; | ||
} | ||
} | ||
|
||
void CPUFunction::collectConstants(IRFunction *F) { | ||
runtimeBundle_.collectConstants(F); | ||
} | ||
|
||
void CPUFunction::beforeRun(const Context &ctx) { | ||
void CPUFunction::loadPlaceholders(Context *ctx, | ||
uint8_t *baseMutableWeightVarsAddress) { | ||
// Copy Placeholders into allocated memory. | ||
for (auto PH : ctx.pairs()) { | ||
for (auto PH : ctx->pairs()) { | ||
auto payload = PH.second->getUnsafePtr(); | ||
auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first); | ||
auto addr = symbolInfo.offset; | ||
auto numBytes = symbolInfo.size; | ||
// copy PH to allocated memory. | ||
memcpy(baseMutableWeightVarsAddress_ + addr, payload, numBytes); | ||
memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes); | ||
} | ||
} | ||
|
||
void CPUFunction::afterRun(const Context &ctx) { | ||
void CPUFunction::updatePlaceholders(Context *ctx, | ||
uint8_t *baseMutableWeightVarsAddress) { | ||
// Copy placeholders from device back into context. | ||
for (auto PH : ctx.pairs()) { | ||
for (auto PH : ctx->pairs()) { | ||
auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first); | ||
auto payload = baseMutableWeightVarsAddress_ + symbolInfo.offset; | ||
auto payload = baseMutableWeightVarsAddress + symbolInfo.offset; | ||
auto numBytes = symbolInfo.size; | ||
auto addr = PH.second->getUnsafePtr(); | ||
// copy PH from allocated memory. | ||
memcpy(addr, payload, numBytes); | ||
} | ||
} | ||
|
||
void CPUFunction::tearDownRuns() { | ||
if (baseMutableWeightVarsAddress_) { | ||
alignedFree(baseMutableWeightVarsAddress_); | ||
baseMutableWeightVarsAddress_ = nullptr; | ||
void CPUFunction::execute(Context *ctx) { | ||
/// Base address for Activations memory block. | ||
|
||
uint8_t *baseActivationsAddress{nullptr}; | ||
|
||
/// Base address for Mutable weights memory block, Inputs and Outputs. | ||
uint8_t *baseMutableWeightVarsAddress{nullptr}; | ||
|
||
if (runtimeBundle_.getActivationsSize() != 0) { | ||
baseActivationsAddress = (uint8_t *)alignedAlloc( | ||
runtimeBundle_.getActivationsSize(), TensorAlignment); | ||
} | ||
|
||
if (baseActivationsAddress_) { | ||
alignedFree(baseActivationsAddress_); | ||
baseActivationsAddress_ = nullptr; | ||
if (runtimeBundle_.getMutableWeightSize() != 0) { | ||
baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc( | ||
runtimeBundle_.getMutableWeightSize(), TensorAlignment); | ||
} | ||
runsSetup_ = false; | ||
} | ||
|
||
void CPUFunction::execute() { | ||
loadPlaceholders(ctx, baseMutableWeightVarsAddress); | ||
|
||
auto sym = JIT_->findSymbol("jitmain"); | ||
assert(sym && "Unable to JIT the code!"); | ||
using JitFuncType = | ||
|
@@ -95,9 +87,14 @@ void CPUFunction::execute() { | |
auto address = sym.getAddress(); | ||
if (address) { | ||
JitFuncType funcPtr = reinterpret_cast<JitFuncType>(address.get()); | ||
funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress_, | ||
baseActivationsAddress_); | ||
funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress, | ||
baseActivationsAddress); | ||
} else { | ||
GLOW_ASSERT(false && "Error getting address."); | ||
} | ||
|
||
updatePlaceholders(ctx, baseMutableWeightVarsAddress); | ||
|
||
alignedFree(baseMutableWeightVarsAddress); | ||
|
||
alignedFree(baseActivationsAddress); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,12 +28,6 @@ class CPUFunction final : public CompiledFunction { | |
/// initializes the LLVM backends. | ||
std::unique_ptr<llvm::orc::GlowJIT> JIT_; | ||
|
||
/// Base address for Activations memory block. | ||
uint8_t *baseActivationsAddress_{}; | ||
|
||
/// Base address for Mutable weights memory block, Inputs and Outputs. | ||
uint8_t *baseMutableWeightVarsAddress_{}; | ||
|
||
public: | ||
/// Ctor. | ||
CPUFunction(std::unique_ptr<llvm::orc::GlowJIT> JIT, | ||
|
@@ -42,24 +36,19 @@ class CPUFunction final : public CompiledFunction { | |
/// Collects constants for runtime. | ||
void collectConstants(IRFunction *F); | ||
|
||
/// Allocate Mutable buffers on device this includes Activations and | ||
/// Placeholders. | ||
void setupRuns() override; | ||
|
||
/// Copy Input Placeholder data to position. | ||
void beforeRun(const Context &ctx) override; | ||
|
||
/// Copy Outputs to Placeholders in \p ctx. | ||
void afterRun(const Context &ctx) override; | ||
|
||
/// Final cleanup, free all allocations. | ||
void tearDownRuns() override; | ||
|
||
/// \name CompiledFunction interface | ||
///@{ | ||
~CPUFunction() override; | ||
void execute() override; | ||
void execute(Context *ctx) override; | ||
///@} | ||
private: | ||
/// Load constant tensors from \p ctx into \p weightsAddress, as defined by | ||
|
||
/// the RuntimeBundle (pre-run). | ||
void loadPlaceholders(Context *ctx, uint8_t *weightsAddress); | ||
|
||
/// Load weights from \p weightsAddress into applicable backing tensors in | ||
/// \p ctx, as defined by the RuntimeBundle (post-run). | ||
void updatePlaceholders(Context *ctx, uint8_t *weightsAddress); | ||
}; | ||
} // end namespace glow | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,29 +22,26 @@ | |
|
||
#include "llvm/Support/Casting.h" | ||
|
||
#include "llvm/Support/raw_ostream.h" | ||
|
||
using namespace glow; | ||
|
||
InterpreterFunction::InterpreterFunction(std::unique_ptr<IRFunction> F, | ||
const runtime::RuntimeBundle &bundle) | ||
: CompiledFunction(bundle), F_(std::move(F)) {} | ||
|
||
InterpreterFunction::~InterpreterFunction() { | ||
// Delete the tensors that are owned by this backend. | ||
for (const auto &p : tensors_) { | ||
for (const auto &p : constants_) { | ||
delete p.second; | ||
} | ||
tensors_.clear(); | ||
externalTensors_.clear(); | ||
constants_.clear(); | ||
|
||
alignedFree(runtimeBundle_.getConstants()); | ||
tearDownRuns(); | ||
} | ||
|
||
void InterpreterFunction::collectConstants(IRFunction *F) { | ||
runtimeBundle_.collectConstants(F); | ||
} | ||
|
||
void InterpreterFunction::setupRuns() { | ||
if (!runsSetup_) { | ||
if (constants_.empty()) { | ||
if (runtimeBundle_.getConstantWeightSize()) { | ||
for (const auto &v : F_->getGraph()->getParent()->getConstants()) { | ||
auto symbolInfo = runtimeBundle_.getSymbolInfo(v); | ||
|
@@ -53,36 +50,27 @@ void InterpreterFunction::setupRuns() { | |
constants_.emplace(std::string(v->getName()), tensor); | ||
} | ||
} | ||
runsSetup_ = true; | ||
} | ||
} | ||
|
||
void InterpreterFunction::beforeRun(const Context &ctx) { | ||
// Register the concrete tensors that back the placeholder tensors. | ||
for (auto &ph : ctx.pairs()) { | ||
auto *w = F_->getWeightForNode(ph.first); | ||
assert(!externalTensors_.count(w) && "The tensor is already registered"); | ||
externalTensors_[w] = ph.second; | ||
} | ||
} | ||
|
||
void InterpreterFunction::afterRun(const Context &ctx) { | ||
// Remove the concrete tensors that back the placeholder tensors. | ||
for (auto &ph : ctx.pairs()) { | ||
auto *w = F_->getWeightForNode(ph.first); | ||
externalTensors_.erase(w); | ||
void InterpreterFunction::execute(Context *ctx) { | ||
if (constants_.empty()) { | ||
collectConstants(F_.get()); | ||
} | ||
BoundInterpreterFunction boundFunc(constants_); | ||
boundFunc.execute(F_.get(), ctx); | ||
} | ||
|
||
void InterpreterFunction::tearDownRuns() { | ||
for (const auto &p : constants_) { | ||
BoundInterpreterFunction::~BoundInterpreterFunction() { | ||
// Delete the tensors that are owned by this backend. | ||
for (const auto &p : tensors_) { | ||
delete p.second; | ||
} | ||
constants_.clear(); | ||
runsSetup_ = false; | ||
tensors_.clear(); | ||
externalTensors_.clear(); | ||
} | ||
|
||
Tensor *InterpreterFunction::getTensor(const Value *v) const { | ||
Tensor *BoundInterpreterFunction::getTensor(const Value *v) const { | ||
auto it = tensors_.find(v); | ||
if (it != tensors_.end()) { | ||
return it->second; | ||
|
@@ -97,7 +85,7 @@ Tensor *InterpreterFunction::getTensor(const Value *v) const { | |
return ie->second; | ||
} | ||
|
||
Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) { | ||
Tensor *BoundInterpreterFunction::getOrCreateTensor(const Value *v) { | ||
auto ie = externalTensors_.find(v); | ||
if (ie != externalTensors_.end()) { | ||
return ie->second; | ||
|
@@ -117,9 +105,8 @@ Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) { | |
return it->second; | ||
} | ||
|
||
Tensor * | ||
InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src, | ||
llvm::ArrayRef<size_t> offsets) { | ||
Tensor *BoundInterpreterFunction::getOrCreateUnownedTensor( | ||
const Value *v, const Value *src, llvm::ArrayRef<size_t> offsets) { | ||
assert(llvm::isa<TensorViewInst>(v) && "Expected a tensor view"); | ||
|
||
// Pick the tensor. | ||
|
@@ -136,7 +123,7 @@ InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src, | |
return T; | ||
} | ||
|
||
void InterpreterFunction::deleteTensor(const Value *v) { | ||
void BoundInterpreterFunction::deleteTensor(const Value *v) { | ||
auto it = tensors_.find(v); | ||
if (it == tensors_.end()) { | ||
return; | ||
|
@@ -146,7 +133,14 @@ void InterpreterFunction::deleteTensor(const Value *v) { | |
tensors_.erase(it); | ||
} | ||
|
||
void InterpreterFunction::execute() { | ||
void BoundInterpreterFunction::execute(IRFunction *F, Context *ctx) { | ||
// Register the concrete tensors that back the placeholder tensors. | ||
for (auto &ph : ctx->pairs()) { | ||
auto *w = F->getWeightForNode(ph.first); | ||
assert(!externalTensors_.count(w) && "The tensor is already registered"); | ||
externalTensors_[w] = ph.second; | ||
} | ||
|
||
// Do the forward pass. | ||
#define DEF_VALUE(CLASS, NAME) | ||
#define DEF_INSTR(CLASS, NAME) \ | ||
|
@@ -156,12 +150,18 @@ void InterpreterFunction::execute() { | |
} | ||
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) | ||
// Dispatch the interpreter on each instruction in the program: | ||
for (const auto &I : F_->getInstrs()) { | ||
for (const auto &I : F->getInstrs()) { | ||
switch (I.getKind()) { | ||
#include "glow/AutoGenInstr.def" | ||
|
||
default: | ||
llvm_unreachable("Invalid instruction."); | ||
} | ||
} | ||
|
||
// Remove the concrete tensors that back the placeholder tensors. | ||
for (auto &ph : ctx->pairs()) { | ||
auto *w = F->getWeightForNode(ph.first); | ||
externalTensors_.erase(w); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
magic :) comment about ctx was already in place