Skip to content

Commit cdbbe9c

Browse files
committed
fix device code generation
1 parent 0f05703 commit cdbbe9c

File tree

5 files changed

+142
-1
lines changed

5 files changed

+142
-1
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ pub(crate) fn run_pass_manager(
585585
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
586586
}
587587

588-
if enable_gpu && !thin {
588+
if enable_gpu && !thin && !(cgcx.target_arch == "nvptx64" || cgcx.target_arch == "amdgpu") {
589589
let cx =
590590
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
591591
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,87 @@ pub(crate) unsafe fn llvm_optimize(
653653
None
654654
};
655655

656+
/// Rewrites the offload kernel `old_fn` inside module `m` by delegating to the C++ helper
/// `LLVMRustOffloadWrapper`.
///
/// `_llcx` is not used at the moment; the parameter is kept so the existing call site does
/// not have to change.
fn handle_offload(m: &llvm::Module, _llcx: &llvm::Context, old_fn: &llvm::Value) {
    // SAFETY: `m` and `old_fn` are live references to LLVM objects. The caller looks
    // `old_fn` up by name in this same module right before calling us.
    // NOTE(review): the wrapper may erase `old_fn`, so the caller must not touch the
    // value it passed in after this call returns -- confirm at the call site.
    unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) };
}
723+
724+
let consider_offload = config.offload.contains(&config::Offload::Enable);
725+
if consider_offload && (cgcx.target_arch == "amdgpu" || cgcx.target_arch == "nvptx64") {
726+
for num in 0..9 {
727+
let name = format!("kernel_{num}");
728+
let c_name = CString::new(name).unwrap();
729+
if let Some(kernel) =
730+
unsafe { llvm::LLVMGetNamedFunction(module.module_llvm.llmod(), c_name.as_ptr()) }
731+
{
732+
handle_offload(module.module_llvm.llmod(), module.module_llvm.llcx, kernel);
733+
}
734+
}
735+
}
736+
656737
let mut llvm_profiler = cgcx
657738
.prof
658739
.llvm_recording_enabled()

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,11 @@ unsafe extern "C" {
12011201

12021202
// Operations on functions
12031203
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
1204+
pub(crate) fn LLVMAddFunction<'a>(
1205+
Mod: &'a Module,
1206+
Name: *const c_char,
1207+
FunctionTy: &'a Type,
1208+
) -> &'a Value;
12041209
pub(crate) fn LLVMDeleteFunction(Fn: &Value);
12051210

12061211
// Operations about llvm intrinsics
@@ -1219,6 +1224,7 @@ unsafe extern "C" {
12191224

12201225
// Operations on basic blocks
12211226
pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
1227+
pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &BasicBlock);
12221228
pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
12231229
C: &'a Context,
12241230
Fn: &'a Value,
@@ -1892,6 +1898,7 @@ unsafe extern "C" {
18921898
) -> &Attribute;
18931899

18941900
// Operations on functions
1901+
pub(crate) fn LLVMRustOffloadWrapper<'a>(M: &'a Module, Fn: &'a Value);
18951902
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
18961903
M: &'a Module,
18971904
Name: *const c_char,

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "llvm/Support/Signals.h"
3636
#include "llvm/Support/Timer.h"
3737
#include "llvm/Support/ToolOutputFile.h"
38+
#include "llvm/Transforms/Utils/Cloning.h"
39+
#include "llvm/Transforms/Utils/ValueMapper.h"
3840
#include <iostream>
3941

4042
// for raw `write` in the bad-alloc handler
@@ -170,6 +172,56 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
170172
llvm::PrintStatistics(OS);
171173
}
172174

175+
// Replaces the function `Fn` in module `M` with a clone whose signature has an
// extra leading `ptr %dyn_ptr` parameter, as LLVM offload requires.
//
// The clone keeps the original name, linkage and visibility, and every use of
// the old function is redirected to it before the old function is erased.
// Calling this on a function whose first argument is already named `dyn_ptr`
// is a no-op, so it is safe to run more than once on the same kernel.
extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
  llvm::Module *module = llvm::unwrap(M);
  llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);

  // Already rewritten: the leading argument carries the marker name.
  if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
    return;
  }

  // 1. Create the new function type with the leading extra %dyn_ptr argument
  //    which LLVM offload requires.
  llvm::LLVMContext &ctx = module->getContext();
  llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0);
  std::vector<llvm::Type *> argTypes;
  argTypes.push_back(dynPtrType);

  for (auto &arg : oldFn->args()) {
    argTypes.push_back(arg.getType());
  }

  llvm::FunctionType *newFnType = llvm::FunctionType::get(
      oldFn->getReturnType(), argTypes, oldFn->isVarArg());

  // 2. Create the replacement under a temporary `.offload` suffix to avoid a
  //    name clash while both functions live in the module.
  llvm::Function *newFn = llvm::Function::Create(
      newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module);

  // 3. Map old arguments to new arguments. We skip the first dyn_ptr argument,
  //    since it can't be used directly by user code.
  llvm::ValueToValueMapTy vmap;
  auto newArgIt = newFn->arg_begin();
  newArgIt->setName("dyn_ptr");
  ++newArgIt; // skip %dyn_ptr
  for (auto &oldArg : oldFn->args()) {
    vmap[&oldArg] = &*newArgIt++;
  }

  // 4. Clone the body into the new function.
  llvm::SmallVector<llvm::ReturnInst *, 8> returns;
  llvm::CloneFunctionInto(newFn, oldFn, vmap,
                          llvm::CloneFunctionChangeType::LocalChangesOnly,
                          returns);
  newFn->setLinkage(oldFn->getLinkage());
  newFn->setVisibility(oldFn->getVisibility());

  // 5. Replace uses, hand the original name over to the clone, then delete the
  //    old function. The name must be transferred *before* erasing: the
  //    StringRef returned by getName() points into storage owned by oldFn, so
  //    reading it after eraseFromParent() would be a use-after-free.
  oldFn->replaceAllUsesWith(newFn);
  newFn->takeName(oldFn);
  oldFn->eraseFromParent();
}
224+
173225
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
174226
size_t NameLen) {
175227
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));

compiler/rustc_target/src/callconv/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ impl RiscvInterruptKind {
577577
///
578578
/// The signature represented by this type may not match the MIR function signature.
579579
/// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`.
580+
/// The `std::offload` module also adds an additional `dyn_ptr` argument to the GpuKernel ABI.
580581
/// While this difference is rarely relevant, it should still be kept in mind.
581582
///
582583
/// I will do my best to describe this structure, but these

0 commit comments

Comments
 (0)