Skip to content

Commit 773389b

Browse files
committed
working
1 parent 215bcfc commit 773389b

File tree

4 files changed

+145
-2
lines changed

4 files changed

+145
-2
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,84 @@ pub(crate) unsafe fn llvm_optimize(
653653
None
654654
};
655655

656+
fn handle_offload(m: &llvm::Module, llcx: &llvm::Context, old_fn: &llvm::Value) {
657+
dbg!(&old_fn);
658+
unsafe {llvm::LLVMRustOffloadWrapper(m, old_fn)};
659+
//unsafe {llvm::LLVMDumpModule(m);}
660+
//unsafe {
661+
// // Get the old function type
662+
// let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
663+
// dbg!(&old_fn_ty);
664+
// let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
665+
// dbg!(&old_param_count);
666+
667+
// // Get the old parameter types
668+
// let mut old_param_types = Vec::with_capacity(old_param_count as usize);
669+
// llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
670+
// old_param_types.set_len(old_param_count as usize);
671+
672+
// // Create the new parameter list, with ptr as the first argument
673+
// let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
674+
// let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
675+
// new_param_types.push(ptr_ty);
676+
// for old_param in old_param_types {
677+
// new_param_types.push(old_param);
678+
// }
679+
// dbg!(&new_param_types);
680+
681+
// // Create the new function type
682+
// let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
683+
// let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
684+
// dbg!(&new_fn_ty);
685+
686+
// // Create the new function
687+
// let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
688+
// //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
689+
// let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
690+
// let new_fn_cstr = CString::new(new_fn_name).unwrap();
691+
// let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
692+
// dbg!(&new_fn);
693+
// let a0 = llvm::LLVMGetParam(new_fn, 0);
694+
// llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
695+
// dbg!(&new_fn);
696+
697+
// // Move basic blocks
698+
// let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
699+
// //dbg!(&bb);
700+
// llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
701+
// //while !bb.is_null() {
702+
// // let next = llvm::LLVMGetNextBasicBlock(bb);
703+
// // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
704+
// // bb = next;
705+
// //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
706+
// let old_n = llvm::LLVMCountParams(old_fn);
707+
// for i in 0..old_n {
708+
// let old_arg = llvm::LLVMGetParam(old_fn, i);
709+
// let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
710+
// llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
711+
// }
712+
713+
// // Copy linkage and visibility
714+
// //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
715+
// //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
716+
717+
// // Replace all uses of old_fn with new_fn (RAUW)
718+
// llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
719+
720+
// // Optionally, remove the old function
721+
// llvm::LLVMDeleteFunction(old_fn);
722+
//}
723+
}
724+
725+
for num in 0..9 {
726+
let name = format!("kernel_{num}");
727+
let c_name = CString::new(name).unwrap();
728+
if let Some(kernel) = unsafe {llvm::LLVMGetNamedFunction(module.module_llvm.llmod() , c_name.as_ptr())} {
729+
dbg!("found offload kernel asfd");
730+
handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel);
731+
}
732+
}
733+
656734
let mut llvm_profiler = cgcx
657735
.prof
658736
.llvm_recording_enabled()

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,7 @@ unsafe extern "C" {
12011201

12021202
// Operations on functions
12031203
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
1204+
pub(crate) fn LLVMAddFunction<'a>(Mod: &'a Module, Name: *const c_char, FunctionTy: &'a Type) -> &'a Value;
12041205
pub(crate) fn LLVMDeleteFunction(Fn: &Value);
12051206

12061207
// Operations about llvm intrinsics
@@ -1219,6 +1220,7 @@ unsafe extern "C" {
12191220

12201221
// Operations on basic blocks
12211222
pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
1223+
pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock);
12221224
pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
12231225
C: &'a Context,
12241226
Fn: &'a Value,
@@ -1892,6 +1894,10 @@ unsafe extern "C" {
18921894
) -> &Attribute;
18931895

18941896
// Operations on functions
1897+
pub(crate) fn LLVMRustOffloadWrapper<'a>(
1898+
M: &'a Module,
1899+
Fn: &'a Value,
1900+
);
18951901
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
18961902
M: &'a Module,
18971903
Name: *const c_char,

compiler/rustc_codegen_llvm/src/mono_item.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
6868
let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());
6969
let fn_abi = if fn_abi.conv == rustc_abi::CanonAbi::GpuKernel {
7070
dbg!("found gpu fn!");
71-
my_fn_abi(fn_abi)
71+
fn_abi.clone()
72+
//my_fn_abi(fn_abi)
7273
} else {
73-
dbg!("asdf!");
7474
fn_abi.clone()
7575
};
7676
let lldecl = self.declare_fn(symbol_name, &fn_abi, Some(instance));

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "LLVMWrapper.h"
22

3+
#include "llvm/Transforms/Utils/ValueMapper.h"
4+
#include "llvm/Transforms/Utils/Cloning.h"
35
#include "llvm-c/Analysis.h"
46
#include "llvm-c/Core.h"
57
#include "llvm-c/DebugInfo.h"
@@ -170,6 +172,63 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
170172
llvm::PrintStatistics(OS);
171173
}
172174

175+
extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
176+
// Convert to C++ types
177+
llvm::Module *module = llvm::unwrap(M);
178+
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);
179+
180+
if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
181+
return;
182+
}
183+
184+
// 1. Create new function type
185+
llvm::LLVMContext &ctx = module->getContext();
186+
llvm::Type *dynPtrType = llvm::PointerType::get(ctx,0);
187+
std::vector<llvm::Type *> argTypes;
188+
argTypes.push_back(dynPtrType); // First argument
189+
190+
for (auto &arg : oldFn->args()) {
191+
argTypes.push_back(arg.getType());
192+
}
193+
194+
llvm::FunctionType *newFnType = llvm::FunctionType::get(
195+
oldFn->getReturnType(), argTypes, oldFn->isVarArg()
196+
);
197+
198+
199+
// 2. Create new function
200+
llvm::Function *newFn = llvm::Function::Create(
201+
//newFnType, oldFn->getLinkage(), oldFn->getName(), module
202+
newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module
203+
);
204+
205+
// Map old arguments to new arguments (skip first argument)
206+
llvm::ValueToValueMapTy vmap;
207+
auto newArgIt = newFn->arg_begin();
208+
newArgIt->setName("dyn_ptr");
209+
++newArgIt; // skip %dyn_ptr
210+
for (auto &oldArg : oldFn->args()) {
211+
vmap[&oldArg] = &*newArgIt++;
212+
}
213+
214+
// 2. Clone body
215+
llvm::SmallVector<llvm::ReturnInst *, 8> returns;
216+
llvm::CloneFunctionInto(newFn, oldFn, vmap, llvm::CloneFunctionChangeType::LocalChangesOnly, returns);
217+
//llvm::CloneFunctionInto(newFn, oldFn, vmap, false, returns);
218+
//
219+
newFn->setLinkage(oldFn->getLinkage());
220+
newFn->setVisibility(oldFn->getVisibility());
221+
222+
// 3. Print new function
223+
newFn->print(llvm::errs());
224+
225+
// Replace uses and delete old function
226+
oldFn->replaceAllUsesWith(newFn);
227+
auto name = oldFn->getName();
228+
oldFn->eraseFromParent();
229+
newFn->setName(name);
230+
}
231+
173232
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
174233
size_t NameLen) {
175234
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));

0 commit comments

Comments
 (0)