diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs
index 62998003ca114..6217fac86b6ec 100644
--- a/compiler/rustc_codegen_llvm/src/back/write.rs
+++ b/compiler/rustc_codegen_llvm/src/back/write.rs
@@ -653,6 +653,84 @@ pub(crate) unsafe fn llvm_optimize(
         None
     };
 
+    // Rewrite each offload kernel so that it takes `%dyn_ptr` as its first argument, as the
+    // offload runtime expects. The actual rewriting (cloning the function with an extended
+    // parameter list) is done on the C++ side in `LLVMRustOffloadWrapper`; an earlier attempt
+    // to do the same through individual LLVM-C calls (LLVMAddFunction,
+    // LLVMAppendExistingBasicBlock, LLVMReplaceAllUsesWith, ...) has been superseded by it.
+    fn handle_offload(m: &llvm::Module, _llcx: &llvm::Context, old_fn: &llvm::Value) {
+        unsafe { llvm::LLVMRustOffloadWrapper(m, old_fn) };
+    }
+
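+    // A sketch of the transformation performed by `LLVMRustOffloadWrapper`, with attributes
+    // elided:
+    //
+    //   before: define void @kernel_1(ptr %x)
+    //   after:  define void @kernel_1(ptr %dyn_ptr, ptr %x)
+    //
+    // The body is cloned unchanged; only the leading `%dyn_ptr`, through which the offload
+    // runtime passes implicit kernel state, is new.
+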
+    for num in 0..9 {
+        let name = format!("kernel_{num}");
+        let c_name = CString::new(name).unwrap();
+        let llmod = module.module_llvm.llmod();
+        if let Some(kernel) = unsafe { llvm::LLVMGetNamedFunction(llmod, c_name.as_ptr()) } {
+            handle_offload(llmod, module.module_llvm.llcx, kernel);
+        }
+    }
+
     let mut llvm_profiler = cgcx
         .prof
         .llvm_recording_enabled()
diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
index 1280ab1442a09..ea15ff99ae9d6 100644
--- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
@@ -18,23 +18,42 @@ pub(crate) fn handle_gpu_code<'ll>(
     // The offload memory transfer type for each kernel
     let mut o_types = vec![];
     let mut kernels = vec![];
-    let offload_entry_ty = add_tgt_offload_entry(&cx);
+    let mut region_ids = vec![];
+    let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
     for num in 0..9 {
         let kernel = cx.get_function(&format!("kernel_{num}"));
         if let Some(kernel) = kernel {
-            o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
+            let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
+            o_types.push(o);
+            region_ids.push(k);
             kernels.push(kernel);
         }
     }
-    gen_call_handling(&cx, &kernels, &o_types);
+    gen_call_handling(&cx, &kernels, &o_types, &region_ids);
+}
+
+// ; Function Attrs: nounwind
+// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
+fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
+    let tptr = cx.type_ptr();
+    let ti64 = cx.type_i64();
+    let ti32 = cx.type_i32();
+    let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
+    let tgt_fn_ty = cx.type_func(&args, ti32);
+    let name = "__tgt_target_kernel";
+    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
+    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
+    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
+    (tgt_decl, tgt_fn_ty)
 }
 
 // What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
 // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
 // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
+// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
+// offloaded was defined.
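+// Once we track source locations, @0 would follow the usual clang layout
+// ";file;function;line;column;;", e.g. c";lib.rs;kernel_1;0;0;;\00" (illustrative).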
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { - // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 let unknown_txt = ";unknown;unknown;0;0;;"; let c_entry_name = CString::new(unknown_txt).unwrap(); let c_val = c_entry_name.as_bytes_with_nul(); @@ -59,15 +78,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { at_one } -pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { - let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); - let tptr = cx.type_ptr(); - let ti64 = cx.type_i64(); - let ti32 = cx.type_i32(); - let ti16 = cx.type_i16(); - // For each kernel to run on the gpu, we will later generate one entry of this type. - // copied from LLVM - // typedef struct { +struct TgtOffloadEntry { // uint64_t Reserved; // uint16_t Version; // uint16_t Kind; @@ -77,21 +88,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty // uint64_t Size; Size of the entry info (0 if it is a function) // uint64_t Data; // void *AuxAddr; - // } __tgt_offload_entry; - let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; - cx.set_struct_body(offload_entry_ty, &entry_elements, false); - offload_entry_ty } -fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { - let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); - let tptr = cx.type_ptr(); - let ti64 = cx.type_i64(); - let ti32 = cx.type_i32(); - let tarr = cx.type_array(ti32, 3); +impl TgtOffloadEntry { + pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { + let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let ti16 = cx.type_i16(); + // For each kernel to run on the gpu, we will later generate one entry of this type. + // copied from LLVM + let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; + cx.set_struct_body(offload_entry_ty, &entry_elements, false); + offload_entry_ty + } + + fn new<'ll>( + cx: &'ll SimpleCx<'_>, + region_id: &'ll Value, + llglobal: &'ll Value, + ) -> Vec<&'ll Value> { + let reserved = cx.get_const_i64(0); + let version = cx.get_const_i16(1); + let kind = cx.get_const_i16(1); + let flags = cx.get_const_i32(0); + let size = cx.get_const_i64(0); + let data = cx.get_const_i64(0); + let aux_addr = cx.const_null(cx.type_ptr()); + vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr] + } +} - // Taken from the LLVM APITypes.h declaration: - //struct KernelArgsTy { +// Taken from the LLVM APITypes.h declaration: +struct KernelArgsTy { // uint32_t Version = 0; // Version of this struct for ABI compatibility. // uint32_t NumArgs = 0; // Number of arguments in each input pointer. // void **ArgBasePtrs = @@ -102,25 +132,64 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { // void **ArgNames = nullptr; // Name of the data for debugging, possibly null. // void **ArgMappers = nullptr; // User-defined mappers, possibly null. // uint64_t Tripcount = - // 0; // Tripcount for the teams / distribute loop, 0 otherwise. - // struct { + // 0; // Tripcount for the teams / distribute loop, 0 otherwise. + // struct { // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause. // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA. 
     //  uint64_t Unused : 62;
-    //  } Flags = {0, 0, 0};
+    //  } Flags = {0, 0, 0}; // adds up to 64 bits (8 bytes)
     //  // The number of teams (for x,y,z dimension).
     //  uint32_t NumTeams[3] = {0, 0, 0};
     //  // The number of threads (for x,y,z dimension).
     //  uint32_t ThreadLimit[3] = {0, 0, 0};
     //  uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
-    //};
-    let kernel_elements =
-        vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
-
-    cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
-    // For now we don't handle kernels, so for now we just add a global dummy
-    // to make sure that the __tgt_offload_entry is defined and handled correctly.
-    cx.declare_global("my_struct_global2", kernel_arguments_ty);
+}
+
+impl KernelArgsTy {
+    const OFFLOAD_VERSION: u64 = 3;
+    const FLAGS: u64 = 0;
+    const TRIPCOUNT: u64 = 0;
+
+    fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
+        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
+        let tptr = cx.type_ptr();
+        let ti64 = cx.type_i64();
+        let ti32 = cx.type_i32();
+        let tarr = cx.type_array(ti32, 3);
+
+        let kernel_elements =
+            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
+
+        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
+        kernel_arguments_ty
+    }
+
+    fn new<'ll>(
+        cx: &'ll SimpleCx<'_>,
+        num_args: u64,
+        o_types: &[&'ll Value],
+        geps: [&'ll Value; 3],
+    ) -> [(Align, &'ll Value); 13] {
+        let four = Align::from_bytes(4).expect("4 byte alignment should work");
+        let eight = Align::EIGHT;
+        let mut values = vec![];
+        values.push((four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)));
+        values.push((four, cx.get_const_i32(num_args)));
+        values.push((eight, geps[0]));
+        values.push((eight, geps[1]));
+        values.push((eight, geps[2]));
+        values.push((eight, o_types[0]));
+        // The next two are debug info pointers. FIXME(offload): set them.
+        values.push((eight, cx.const_null(cx.type_ptr())));
+        values.push((eight, cx.const_null(cx.type_ptr())));
+        values.push((eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)));
+        values.push((eight, cx.get_const_i64(KernelArgsTy::FLAGS)));
+        // NumTeams = {2097152 (= 2^21), 0, 0} and ThreadLimit = {256, 0, 0}; these match the
+        // constants passed to __tgt_target_kernel below.
+        let ti32 = cx.type_i32();
+        let ci32_0 = cx.get_const_i32(0);
+        values.push((four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])));
+        values.push((four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])));
+        values.push((four, cx.get_const_i32(0)));
+        values.try_into().expect("tgt_kernel_arguments construction failed")
+    }
 }
 
 fn gen_tgt_data_mappers<'ll>(
@@ -187,7 +256,7 @@ fn gen_define_handling<'ll>(
     kernel: &'ll llvm::Value,
     offload_entry_ty: &'ll llvm::Type,
     num: i64,
-) -> &'ll llvm::Value {
+) -> (&'ll llvm::Value, &'ll llvm::Value) {
     let types = cx.func_params_types(cx.get_type_of_global(kernel));
     // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
     // reference) types.
@@ -205,10 +274,14 @@ fn gen_define_handling<'ll>(
     // or both to and from the gpu (=3). Other values shouldn't affect us for now.
     // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
     // will be 2. For now, everything is 3, until we have our frontend set up.
-    let o_types =
-        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
+    // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (add one extra input ptr per function, to be used later).
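+    // These values mirror LLVM's OpenMPOffloadMappingFlags: 0x1 = OMP_MAP_TO, 0x2 = OMP_MAP_FROM,
+    // 0x20 = OMP_MAP_TARGET_PARAM, so each entry below becomes 35 (0x23), the value that the
+    // @.offload_maptypes.{num} globals in the codegen test expect.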
+    let o_types = add_priv_unnamed_arr(
+        &cx,
+        &format!(".offload_maptypes.{num}"),
+        &vec![1 + 2 + 32; num_ptr_types],
+    );
     // Next: For each function, generate these three entries. A weak constant,
-    // the llvm.rodata entry name, and the omp_offloading_entries value
+    // the llvm.rodata entry name, and the llvm_offload_entries value
 
     let name = format!(".kernel_{num}.region_id");
     let initializer = cx.get_const_i8(0);
@@ -222,19 +295,10 @@ fn gen_define_handling<'ll>(
     let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
     llvm::set_alignment(llglobal, Align::ONE);
     llvm::set_section(llglobal, c".llvm.rodata.offloading");
-
-    // Not actively used yet, for calling real kernels
     let name = format!(".offloading.entry.kernel_{num}");
     // See the __tgt_offload_entry documentation above.
-    let reserved = cx.get_const_i64(0);
-    let version = cx.get_const_i16(1);
-    let kind = cx.get_const_i16(1);
-    let flags = cx.get_const_i32(0);
-    let size = cx.get_const_i64(0);
-    let data = cx.get_const_i64(0);
-    let aux_addr = cx.const_null(cx.type_ptr());
-    let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
+    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
 
     let initializer = crate::common::named_struct(offload_entry_ty, &elems);
     let c_name = CString::new(name).unwrap();
@@ -242,13 +306,13 @@ fn gen_define_handling<'ll>(
     llvm::set_global_constant(llglobal, true);
     llvm::set_linkage(llglobal, WeakAnyLinkage);
     llvm::set_initializer(llglobal, initializer);
-    llvm::set_alignment(llglobal, Align::ONE);
-    let c_section_name = CString::new(".omp_offloading_entries").unwrap();
+    llvm::set_alignment(llglobal, Align::EIGHT);
+    let c_section_name = CString::new("llvm_offload_entries").unwrap();
     llvm::set_section(llglobal, &c_section_name);
-    o_types
+    (o_types, region_id)
 }
 
-fn declare_offload_fn<'ll>(
+pub(crate) fn declare_offload_fn<'ll>(
     cx: &'ll SimpleCx<'_>,
     name: &str,
     ty: &'ll llvm::Type,
@@ -287,7 +351,9 @@ fn gen_call_handling<'ll>(
     cx: &'ll SimpleCx<'_>,
     _kernels: &[&'ll llvm::Value],
     o_types: &[&'ll llvm::Value],
+    region_ids: &[&'ll llvm::Value],
 ) {
+    let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
     // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
     let tptr = cx.type_ptr();
     let ti32 = cx.type_i32();
@@ -295,7 +361,7 @@ fn gen_call_handling<'ll>(
     let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
     cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
 
-    gen_tgt_kernel_global(&cx);
+    let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
     let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
 
     let main_fn = cx.get_function("main");
@@ -329,35 +395,32 @@ fn gen_call_handling<'ll>(
     // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
     let ty2 = cx.type_array(cx.type_i64(), num_args);
     let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
+
+    // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
+    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
+
+    // Step 1)
+    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
+    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
+
+    // For each pointer argument of the kernel call, compute a GEP to the data it points at,
+    // to be handed to the mapper functions below.
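+    // E.g. for `kernel_1(&mut [f32; 256])` this results in a single GEP, roughly:
+    //   %1 = getelementptr inbounds float, ptr %x, i32 0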
let mut vals = vec![]; let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); - for (index, in_ty) in types.iter().enumerate() { - // get function arg, store it into the alloca, and read it. - let p = llvm::get_param(called, index as u32); - let name = llvm::get_value_name(p); - let name = str::from_utf8(&name).unwrap(); - let arg_name = format!("{name}.addr"); - let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name); - - builder.store(p, alloca, Align::EIGHT); - let val = builder.load(in_ty, alloca, Align::EIGHT); - let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]); - vals.push(val); + for index in 0..types.len() { + let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() }; + let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); + vals.push(v); geps.push(gep); } - // Step 1) - unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; - builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); - let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void()); let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty); let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty); let init_ty = cx.type_func(&[], cx.type_void()); let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty); + // FIXME(offload): Later we want to add them to the wrapper code, rather than our main function. // call void @__tgt_register_lib(ptr noundef %6) builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None); // call void @__tgt_init_all_rtls() @@ -386,19 +449,19 @@ fn gen_call_handling<'ll>( a1: &'ll Value, a2: &'ll Value, a4: &'ll Value, - ) -> (&'ll Value, &'ll Value, &'ll Value) { + ) -> [&'ll Value; 3] { let i32_0 = cx.get_const_i32(0); let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]); let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]); let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]); - (gep1, gep2, gep3) + [gep1, gep2, gep3] } fn generate_mapper_call<'a, 'll>( builder: &mut SBuilder<'a, 'll>, cx: &'ll SimpleCx<'ll>, - geps: (&'ll Value, &'ll Value, &'ll Value), + geps: [&'ll Value; 3], o_type: &'ll Value, fn_to_call: &'ll Value, fn_ty: &'ll Type, @@ -409,7 +472,7 @@ fn gen_call_handling<'ll>( let i64_max = cx.get_const_i64(u64::MAX); let num_args = cx.get_const_i32(num_args); let args = - vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr]; + vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr]; builder.call(fn_ty, fn_to_call, &args, None); } @@ -418,22 +481,42 @@ fn gen_call_handling<'ll>( let o = o_types[0]; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t); + let values = KernelArgsTy::new(&cx, num_args, o_types, geps); // Step 3) - // Here we will add code for the actual kernel launches in a follow-up PR. - // FIXME(offload): launch kernels + // Here we fill the KernelArgsTy, see the documentation above + for (i, value) in values.iter().enumerate() { + let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); + builder.store(value.1, ptr, value.0); + } - // Step 4) - unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; + let args = vec![ + s_ident_t, + // FIXME(offload) give users a way to select which GPU to use. + cx.get_const_i64(u64::MAX), // MAX == -1. 
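+        // In libomptarget, device ID -1 (OFFLOAD_DEVICE_DEFAULT) means "use the default device".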
+        // FIXME(offload): Don't hardcode the number of teams and threads in the future.
+        cx.get_const_i32(2097152),
+        cx.get_const_i32(256),
+        region_ids[0],
+        a5,
+    ];
+    let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
+    // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
+    unsafe {
+        // The instruction following our launcher call is the original call to the CPU version
+        // of the kernel; the GPU launch replaces it, so erase it.
+        let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
+        llvm::LLVMRustPositionAfter(builder.llbuilder, next);
+        llvm::LLVMInstructionEraseFromParent(next);
+    }
+
+    // Step 4)
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
     builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
 
-    // With this we generated the following begin and end mappers. We could easily generate the
-    // update mapper in an update.
-    // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-    // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-    // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
+    drop(builder);
+    // FIXME(offload): Right now we emit a call to the GPU version of the kernel and then delete
+    // the call to the CPU version. In the future, we should instead use an intrinsic which
+    // directly resolves to a call to the GPU version.
+    unsafe { llvm::LLVMDeleteFunction(called) };
 }
diff --git a/compiler/rustc_codegen_llvm/src/callee.rs b/compiler/rustc_codegen_llvm/src/callee.rs
index 791a71d73ae58..b06b12043fa2a 100644
--- a/compiler/rustc_codegen_llvm/src/callee.rs
+++ b/compiler/rustc_codegen_llvm/src/callee.rs
@@ -31,6 +31,11 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
     debug!("get_fn({:?}: {:?}) => {}", instance, instance.ty(cx.tcx(), cx.typing_env()), sym);
 
     let fn_abi = cx.fn_abi_of_instance(instance, ty::List::empty());
+    if fn_abi.conv == rustc_abi::CanonAbi::GpuKernel {
+        debug!("get_fn: found GpuKernel ABI for {sym}");
+    }
 
     let llfn = if let Some(llfn) = cx.get_declared_value(sym) {
         llfn
@@ -69,6 +74,11 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
         llvm::set_dllimport_storage_class(llfn);
         llfn
     } else {
+        if fn_abi.conv == rustc_abi::CanonAbi::GpuKernel {
+            debug!("get_fn: declaring GpuKernel fn {sym}");
+        }
         cx.declare_fn(sym, fn_abi, Some(instance))
     };
     debug!("get_fn: not casting pointer!");
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 2461f70a86e35..0d4f60f499015 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -1201,6 +1201,8 @@ unsafe extern "C" {
 
     // Operations on functions
     pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
+    pub(crate) fn LLVMAddFunction<'a>(
+        Mod: &'a Module,
+        Name: *const c_char,
+        FunctionTy: &'a Type,
+    ) -> &'a Value;
+    pub(crate) fn LLVMDeleteFunction(Fn: &Value);
 
     // Operations about llvm intrinsics
     pub(crate) fn LLVMLookupIntrinsicID(Name: *const c_char, NameLen: size_t) -> c_uint;
@@ -1218,6 +1220,7 @@ unsafe extern "C" {
 
     // Operations on basic blocks
     pub(crate) fn LLVMGetBasicBlockParent(BB: &BasicBlock) -> &Value;
+    pub(crate) fn LLVMAppendExistingBasicBlock<'a>(Fn: &'a Value, BB: &'a BasicBlock);
     pub(crate) fn LLVMAppendBasicBlockInContext<'a>(
         C: &'a Context,
         Fn: &'a Value,
@@ -1230,6 +1233,8 @@ unsafe extern "C" {
     pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
     pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
     pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;
+    pub(crate) fn LLVMGetNextInstruction(Val: &Value) -> Option<&Value>;
+    pub(crate) fn LLVMInstructionEraseFromParent(Val: &Value);
 
     // Operations on call sites
     pub(crate) fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint);
@@ -1889,6 +1894,10 @@ unsafe extern "C" {
     ) -> &Attribute;
 
     // Operations on functions
+    pub(crate) fn LLVMRustOffloadWrapper(M: &Module, Fn: &Value);
     pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
         M: &'a Module,
         Name: *const c_char,
diff --git a/compiler/rustc_codegen_llvm/src/mono_item.rs b/compiler/rustc_codegen_llvm/src/mono_item.rs
index 5075befae8a54..78139baf8b833 100644
--- a/compiler/rustc_codegen_llvm/src/mono_item.rs
+++ b/compiler/rustc_codegen_llvm/src/mono_item.rs
@@ -5,16 +5,28 @@
 use rustc_hir::def_id::{DefId, LOCAL_CRATE};
 use rustc_middle::bug;
 use rustc_middle::mir::mono::Visibility;
 use rustc_middle::ty::layout::{FnAbiOf, HasTypingEnv, LayoutOf};
-use rustc_middle::ty::{self, Instance, TypeVisitableExt};
+use rustc_middle::ty::{self, Instance, Ty, TypeVisitableExt};
 use rustc_session::config::CrateType;
 use rustc_target::spec::RelocModel;
 use tracing::debug;
 
+use rustc_target::callconv::FnAbi;
+
 use crate::context::CodegenCx;
 use crate::errors::SymbolAlreadyDefined;
 use crate::type_of::LayoutLlvmExt;
 use crate::{base, llvm};
 
+// Experimental: duplicate the first argument of an ABI, to prototype inserting the implicit
+// `dyn_ptr` argument for GpuKernel functions on the Rust side. Unused for now, see
+// `predefine_fn` below.
+#[allow(dead_code)]
+fn my_fn_abi<'tcx>(abi: &FnAbi<'tcx, Ty<'tcx>>) -> FnAbi<'tcx, Ty<'tcx>> {
+    let mut my_abi = abi.clone();
+    let val = my_abi.args[0].clone();
+    my_abi.args = Box::new([val.clone(), val]);
+    my_abi
+}
+
 impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
     fn predefine_static(
@@ -54,7 +66,14 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
         assert!(!instance.args.has_infer());
 
         let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());
-        let lldecl = self.declare_fn(symbol_name, fn_abi, Some(instance));
+        let fn_abi = if fn_abi.conv == rustc_abi::CanonAbi::GpuKernel {
+            debug!("predefine_fn: found GpuKernel ABI for {symbol_name}");
+            // FIXME(offload): here we could rewrite the ABI instead, e.g. via `my_fn_abi`.
+            fn_abi.clone()
+        } else {
+            fn_abi.clone()
+        };
+        let lldecl = self.declare_fn(symbol_name, &fn_abi, Some(instance));
         llvm::set_linkage(lldecl, base::linkage_to_llvm(linkage));
         let attrs = self.tcx.codegen_instance_attrs(instance.def);
         base::set_link_section(lldecl, &attrs);
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index e699e4b9c13f8..df4e249f0d789 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -1,5 +1,7 @@
 #include "LLVMWrapper.h"
 
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
 #include "llvm-c/Analysis.h"
 #include "llvm-c/Core.h"
 #include "llvm-c/DebugInfo.h"
@@ -170,6 +172,63 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
   llvm::PrintStatistics(OS);
 }
 
+// Clones an offload kernel into a variant that takes the implicit `%dyn_ptr` as its first
+// argument, rewires all uses, and removes the original function.
+extern "C" void LLVMRustOffloadWrapper(LLVMModuleRef M, LLVMValueRef Fn) {
+  llvm::Module *module = llvm::unwrap(M);
+  llvm::Function *oldFn = llvm::unwrap<llvm::Function>(Fn);
+
+  // Nothing to do if the kernel already takes `%dyn_ptr` as its first argument.
+  if (oldFn->arg_size() > 0 && oldFn->getArg(0)->getName() == "dyn_ptr") {
+    return;
+  }
+
+  // 1. Create the new function type, with `ptr %dyn_ptr` as the leading parameter.
+  llvm::LLVMContext &ctx = module->getContext();
+  llvm::Type *dynPtrType = llvm::PointerType::get(ctx, 0);
+  std::vector<llvm::Type *> argTypes;
+  argTypes.push_back(dynPtrType); // The new first argument.
+  for (auto &arg : oldFn->args()) {
+    argTypes.push_back(arg.getType());
+  }
+  llvm::FunctionType *newFnType = llvm::FunctionType::get(
+      oldFn->getReturnType(), argTypes, oldFn->isVarArg());
+
+  // 2. Create the new function.
+  llvm::Function *newFn = llvm::Function::Create(
+      newFnType, oldFn->getLinkage(), oldFn->getName() + ".offload", module);
+
+  // Map the old arguments to the new arguments (skipping the new first argument).
+  llvm::ValueToValueMapTy vmap;
+  auto newArgIt = newFn->arg_begin();
+  newArgIt->setName("dyn_ptr");
+  ++newArgIt; // skip %dyn_ptr
+  for (auto &oldArg : oldFn->args()) {
+    vmap[&oldArg] = &*newArgIt++;
+  }
+
+  // 3. Clone the body of the old function into the new one.
+  llvm::SmallVector<llvm::ReturnInst *, 8> returns;
+  llvm::CloneFunctionInto(newFn, oldFn, vmap,
+                          llvm::CloneFunctionChangeType::LocalChangesOnly, returns);
+  newFn->setLinkage(oldFn->getLinkage());
+  newFn->setVisibility(oldFn->getVisibility());
+
+  // 4. Replace all uses of the old function, then let the new function take over its name.
+  oldFn->replaceAllUsesWith(newFn);
+  // Copy the name out before erasing oldFn: getName() returns a StringRef into its storage.
+  std::string name = oldFn->getName().str();
+  oldFn->eraseFromParent();
+  newFn->setName(name);
+}
+
 extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
                                               size_t NameLen) {
   return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));
diff --git a/compiler/rustc_target/src/callconv/mod.rs b/compiler/rustc_target/src/callconv/mod.rs
index 5f2a6f7ba38a1..88225137588aa 100644
--- a/compiler/rustc_target/src/callconv/mod.rs
+++ b/compiler/rustc_target/src/callconv/mod.rs
@@ -577,6 +577,7 @@ impl RiscvInterruptKind {
 ///
 /// The signature represented by this type may not match the MIR function signature.
 /// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`.
+/// The `std::offload` module also adds an additional `dyn_ptr` argument to the GpuKernel ABI.
 /// While this difference is rarely relevant, it should still be kept in mind.
/// /// I will do my best to describe this structure, but these diff --git a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index 513e27426bc0e..fac4054d1b7ff 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -21,16 +21,15 @@ fn main() { } // CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr } -// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } // CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr } // CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } +// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } // CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024] -// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 35] // CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0 // CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1 -// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1 -// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments +// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 // CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 // CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 @@ -43,34 +42,61 @@ fn main() { // CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8 -// CHECK-NEXT: %x.addr = alloca ptr, align 8 -// CHECK-NEXT: store ptr %x, ptr %x.addr, align 8 -// CHECK-NEXT: %1 = load ptr, ptr %x.addr, align 8 -// CHECK-NEXT: %2 = getelementptr inbounds float, ptr %1, i32 0 +// CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8 // CHECK: call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false) +// CHECK-NEXT: %1 = getelementptr inbounds float, ptr %x, i32 0 // CHECK-NEXT: call void @__tgt_register_lib(ptr %EmptyDesc) // CHECK-NEXT: call void @__tgt_init_all_rtls() -// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %2 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %x, ptr %2, align 8 +// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 // CHECK-NEXT: store ptr %1, ptr %3, align 8 -// CHECK-NEXT: %4 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: store ptr %2, ptr %4, align 8 -// CHECK-NEXT: %5 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: store i64 1024, ptr %5, align 8 -// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 -// CHECK-NEXT: %7 = getelementptr inbounds 
[1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: %8 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %6, ptr %7, ptr %8, ptr @.offload_maptypes.1, ptr null, ptr null) -// CHECK-NEXT: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x) -// CHECK-NEXT: %9 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 -// CHECK-NEXT: %10 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: %11 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %9, ptr %10, ptr %11, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: %4 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: store i64 1024, ptr %4, align 8 +// CHECK-NEXT: %5 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %7 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %5, ptr %6, ptr %7, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: %8 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0 +// CHECK-NEXT: store i32 3, ptr %8, align 4 +// CHECK-NEXT: %9 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1 +// CHECK-NEXT: store i32 1, ptr %9, align 4 +// CHECK-NEXT: %10 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2 +// CHECK-NEXT: store ptr %5, ptr %10, align 8 +// CHECK-NEXT: %11 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3 +// CHECK-NEXT: store ptr %6, ptr %11, align 8 +// CHECK-NEXT: %12 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4 +// CHECK-NEXT: store ptr %7, ptr %12, align 8 +// CHECK-NEXT: %13 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5 +// CHECK-NEXT: store ptr @.offload_maptypes.1, ptr %13, align 8 +// CHECK-NEXT: %14 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6 +// CHECK-NEXT: store ptr null, ptr %14, align 8 +// CHECK-NEXT: %15 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7 +// CHECK-NEXT: store ptr null, ptr %15, align 8 +// CHECK-NEXT: %16 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8 +// CHECK-NEXT: store i64 0, ptr %16, align 8 +// CHECK-NEXT: %17 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9 +// CHECK-NEXT: store i64 0, ptr %17, align 8 +// CHECK-NEXT: %18 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10 +// CHECK-NEXT: store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %18, align 4 +// CHECK-NEXT: %19 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11 +// CHECK-NEXT: store [3 x i32] [i32 256, i32 0, i32 0], ptr %19, align 4 +// CHECK-NEXT: %20 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12 +// CHECK-NEXT: store i32 0, ptr %20, align 4 +// CHECK-NEXT: %21 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr 
@.kernel_1.region_id, ptr %kernel_args) +// CHECK-NEXT: %22 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %23 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %24 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %22, ptr %23, ptr %24, ptr @.offload_maptypes.1, ptr null, ptr null) // CHECK-NEXT: call void @__tgt_unregister_lib(ptr %EmptyDesc) // CHECK: store ptr %x, ptr %0, align 8 // CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0) // CHECK: ret void // CHECK-NEXT: } +// CHECK: Function Attrs: nounwind +// CHECK: declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) + #[unsafe(no_mangle)] #[inline(never)] pub fn kernel_1(x: &mut [f32; 256]) {