From f70c6f4c3e35df56aff94414dd0108fd755418c0 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 21 Aug 2025 12:01:32 -0700 Subject: [PATCH 1/2] fix host code --- .../src/builder/gpu_offload.rs | 141 +++++++++++++----- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 + tests/codegen-llvm/gpu_offload/gpu_host.rs | 70 ++++++--- 3 files changed, 151 insertions(+), 63 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 1280ab1442a09..eae9034e3c608 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -16,23 +16,41 @@ pub(crate) fn handle_gpu_code<'ll>( cx: &'ll SimpleCx<'_>, ) { // The offload memory transfer type for each kernel - let mut o_types = vec![]; - let mut kernels = vec![]; + let mut memtransfer_types = vec![]; + let mut region_ids = vec![]; let offload_entry_ty = add_tgt_offload_entry(&cx); for num in 0..9 { let kernel = cx.get_function(&format!("kernel_{num}")); if let Some(kernel) = kernel { - o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num)); - kernels.push(kernel); + let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num); + memtransfer_types.push(o); + region_ids.push(k); } } - gen_call_handling(&cx, &kernels, &o_types); + gen_call_handling(&cx, &memtransfer_types, ®ion_ids); +} + +// ; Function Attrs: nounwind +// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2 +fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) { + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let args = vec![tptr, ti64, ti32, ti32, tptr, tptr]; + let tgt_fn_ty = cx.type_func(&args, ti32); + let name = "__tgt_target_kernel"; + let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty); + let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx); + attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]); + (tgt_decl, tgt_fn_ty) } // What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper: // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 +// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be +// offloaded was defined. fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 let unknown_txt = ";unknown;unknown;0;0;;"; @@ -83,7 +101,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty offload_entry_ty } -fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { +fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); let tptr = cx.type_ptr(); let ti64 = cx.type_i64(); @@ -107,7 +125,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause. // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA. // uint64_t Unused : 62; - // } Flags = {0, 0, 0}; + // } Flags = {0, 0, 0}; // totals to 64 Bit, 8 Byte // // The number of teams (for x,y,z dimension). // uint32_t NumTeams[3] = {0, 0, 0}; // // The number of threads (for x,y,z dimension). 
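For orientation, the KernelArgsTy layout documented in the comment block above lowers to the %struct.__tgt_kernel_arguments type checked in gpu_host.rs further down. The following #[repr(C)] sketch is hand-written for this review and is not part of the patch; field names follow the APITypes.h comment:

use std::ffi::c_void;

#[repr(C)]
struct KernelArgsLayout {
    version: u32,                    // ABI version of this struct, emitted as 3 below
    num_args: u32,                   // number of kernel arguments
    arg_base_ptrs: *mut *mut c_void, // .offload_baseptrs
    arg_ptrs: *mut *mut c_void,      // .offload_ptrs
    arg_sizes: *mut i64,             // .offload_sizes
    arg_types: *mut i64,             // .offload_maptypes, see gen_define_handling below
    arg_names: *mut *mut c_void,     // debug names, currently null
    arg_mappers: *mut *mut c_void,   // user-defined mappers, currently null
    tripcount: u64,
    flags: u64,                      // NoWait:1, IsCUDA:1, Unused:62 packed into one 64-bit field
    num_teams: [u32; 3],             // {2097152, 0, 0} for now
    thread_limit: [u32; 3],          // {256, 0, 0} for now
    dyn_cgroup_mem: u32,
}
// Element for element this is { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64,
// [3 x i32], [3 x i32], i32 }, the struct body set below.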
@@ -118,9 +136,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); - // For now we don't handle kernels, so for now we just add a global dummy - // to make sure that the __tgt_offload_entry is defined and handled correctly. - cx.declare_global("my_struct_global2", kernel_arguments_ty); + kernel_arguments_ty } fn gen_tgt_data_mappers<'ll>( @@ -187,7 +203,7 @@ fn gen_define_handling<'ll>( kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, num: i64, -) -> &'ll llvm::Value { +) -> (&'ll llvm::Value, &'ll llvm::Value) { let types = cx.func_params_types(cx.get_type_of_global(kernel)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or // reference) types. @@ -205,10 +221,14 @@ fn gen_define_handling<'ll>( // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten // will be 2. For now, everything is 3, until we have our frontend set up. - let o_types = - add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]); + // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). + let memtransfer_types = add_priv_unnamed_arr( + &cx, + &format!(".offload_maptypes.{num}"), + &vec![1 + 2 + 32; num_ptr_types], + ); // Next: For each function, generate these three entries. A weak constant, - // the llvm.rodata entry name, and the omp_offloading_entries value + // the llvm.rodata entry name, and the llvm_offload_entries value let name = format!(".kernel_{num}.region_id"); let initializer = cx.get_const_i8(0); @@ -242,13 +262,13 @@ fn gen_define_handling<'ll>( llvm::set_global_constant(llglobal, true); llvm::set_linkage(llglobal, WeakAnyLinkage); llvm::set_initializer(llglobal, initializer); - llvm::set_alignment(llglobal, Align::ONE); - let c_section_name = CString::new(".omp_offloading_entries").unwrap(); + llvm::set_alignment(llglobal, Align::EIGHT); + let c_section_name = CString::new("llvm_offload_entries").unwrap(); llvm::set_section(llglobal, &c_section_name); - o_types + (memtransfer_types, region_id) } -fn declare_offload_fn<'ll>( +pub(crate) fn declare_offload_fn<'ll>( cx: &'ll SimpleCx<'_>, name: &str, ty: &'ll llvm::Type, @@ -285,9 +305,10 @@ fn declare_offload_fn<'ll>( // 6. generate __tgt_target_data_end calls to move data from the GPU fn gen_call_handling<'ll>( cx: &'ll SimpleCx<'_>, - _kernels: &[&'ll llvm::Value], - o_types: &[&'ll llvm::Value], + memtransfer_types: &[&'ll llvm::Value], + region_ids: &[&'ll llvm::Value], ) { + let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } let tptr = cx.type_ptr(); let ti32 = cx.type_i32(); @@ -295,7 +316,7 @@ fn gen_call_handling<'ll>( let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc"); cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false); - gen_tgt_kernel_global(&cx); + let tgt_kernel_decl = gen_tgt_kernel_global(&cx); let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); let main_fn = cx.get_function("main"); @@ -329,35 +350,32 @@ fn gen_call_handling<'ll>( // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16. 
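// Concretely, for the `kernel_1(x: &mut [f32; 256])` test in gpu_host.rs this is a
// single entry of 4 * 256 = 1024 bytes (the `store i64 1024` in the test), and the
// matching .offload_maptypes entry is 1 + 2 + 32 = 35 (the `[i64 35]` constant).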
let ty2 = cx.type_array(cx.type_i64(), num_args); let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes"); + + //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8 + let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args"); + + // Step 1) + unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; + builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); + // Now we allocate once per function param, a copy to be passed to one of our maps. let mut vals = vec![]; let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); - for (index, in_ty) in types.iter().enumerate() { - // get function arg, store it into the alloca, and read it. - let p = llvm::get_param(called, index as u32); - let name = llvm::get_value_name(p); - let name = str::from_utf8(&name).unwrap(); - let arg_name = format!("{name}.addr"); - let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name); - - builder.store(p, alloca, Align::EIGHT); - let val = builder.load(in_ty, alloca, Align::EIGHT); - let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]); - vals.push(val); + for index in 0..types.len() { + let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() }; + let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); + vals.push(v); geps.push(gep); } - // Step 1) - unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; - builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); - let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void()); let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty); let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty); let init_ty = cx.type_func(&[], cx.type_void()); let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty); + // FIXME(offload): Later we want to add them to the wrapper code, rather than our main function. // call void @__tgt_register_lib(ptr noundef %6) builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None); // call void @__tgt_init_all_rtls() @@ -415,22 +433,63 @@ fn gen_call_handling<'ll>( // Step 2) let s_ident_t = generate_at_one(&cx); - let o = o_types[0]; + let o = memtransfer_types[0]; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t); // Step 3) - // Here we will add code for the actual kernel launches in a follow-up PR. - // FIXME(offload): launch kernels + let mut values = vec![]; + let offload_version = cx.get_const_i32(3); + values.push((4, offload_version)); + values.push((4, cx.get_const_i32(num_args))); + values.push((8, geps.0)); + values.push((8, geps.1)); + values.push((8, geps.2)); + values.push((8, memtransfer_types[0])); + // The next two are debug infos. 
FIXME(offload) set them + values.push((8, cx.const_null(cx.type_ptr()))); + values.push((8, cx.const_null(cx.type_ptr()))); + values.push((8, cx.get_const_i64(0))); + values.push((8, cx.get_const_i64(0))); + let ti32 = cx.type_i32(); + let ci32_0 = cx.get_const_i32(0); + values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0]))); + values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0]))); + values.push((4, cx.get_const_i32(0))); + + for (i, value) in values.iter().enumerate() { + let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); + builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap()); + } + + let args = vec![ + s_ident_t, + // MAX == -1 + cx.get_const_i64(u64::MAX), + cx.get_const_i32(2097152), + cx.get_const_i32(256), + region_ids[0], + a5, + ]; + let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None); + // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) + unsafe { + let next = llvm::LLVMGetNextInstruction(offload_success).unwrap(); + llvm::LLVMRustPositionAfter(builder.llbuilder, next); + llvm::LLVMInstructionEraseFromParent(next); + } // Step 4) - unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; + //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t); builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); + drop(builder); + unsafe { llvm::LLVMDeleteFunction(called) }; + // With this we generated the following begin and end mappers. We could easily generate the // update mapper in an update. 
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 2461f70a86e35..5dead7f4e7ee5 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1201,6 +1201,7 @@ unsafe extern "C" { // Operations on functions pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); + pub(crate) fn LLVMDeleteFunction(Fn: &Value); // Operations about llvm intrinsics pub(crate) fn LLVMLookupIntrinsicID(Name: *const c_char, NameLen: size_t) -> c_uint; @@ -1230,6 +1231,8 @@ unsafe extern "C" { pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>; pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock; pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>; + pub(crate) fn LLVMGetNextInstruction(Val: &Value) -> Option<&Value>; + pub(crate) fn LLVMInstructionEraseFromParent(Val: &Value); // Operations on call sites pub(crate) fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint); diff --git a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index 513e27426bc0e..fac4054d1b7ff 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -21,16 +21,15 @@ fn main() { } // CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr } -// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } // CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr } // CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } +// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } // CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024] -// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 35] // CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0 // CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1 -// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1 -// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments +// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 // CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 // CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 @@ -43,34 +42,61 @@ fn main() { // CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8 -// CHECK-NEXT: %x.addr = alloca ptr, align 8 -// CHECK-NEXT: store ptr %x, ptr %x.addr, align 8 -// CHECK-NEXT: %1 = load ptr, ptr %x.addr, align 8 -// CHECK-NEXT: %2 = getelementptr inbounds float, ptr %1, i32 0 +// 
CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8 // CHECK: call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false) +// CHECK-NEXT: %1 = getelementptr inbounds float, ptr %x, i32 0 // CHECK-NEXT: call void @__tgt_register_lib(ptr %EmptyDesc) // CHECK-NEXT: call void @__tgt_init_all_rtls() -// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %2 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %x, ptr %2, align 8 +// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 // CHECK-NEXT: store ptr %1, ptr %3, align 8 -// CHECK-NEXT: %4 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: store ptr %2, ptr %4, align 8 -// CHECK-NEXT: %5 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: store i64 1024, ptr %5, align 8 -// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 -// CHECK-NEXT: %7 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: %8 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %6, ptr %7, ptr %8, ptr @.offload_maptypes.1, ptr null, ptr null) -// CHECK-NEXT: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x) -// CHECK-NEXT: %9 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 -// CHECK-NEXT: %10 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 -// CHECK-NEXT: %11 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 -// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %9, ptr %10, ptr %11, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: %4 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: store i64 1024, ptr %4, align 8 +// CHECK-NEXT: %5 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %7 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %5, ptr %6, ptr %7, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: %8 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0 +// CHECK-NEXT: store i32 3, ptr %8, align 4 +// CHECK-NEXT: %9 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1 +// CHECK-NEXT: store i32 1, ptr %9, align 4 +// CHECK-NEXT: %10 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2 +// CHECK-NEXT: store ptr %5, ptr %10, align 8 +// CHECK-NEXT: %11 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3 +// CHECK-NEXT: store ptr %6, ptr %11, align 8 +// CHECK-NEXT: %12 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4 +// CHECK-NEXT: store ptr %7, ptr %12, align 8 +// CHECK-NEXT: %13 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5 +// CHECK-NEXT: store ptr @.offload_maptypes.1, ptr %13, align 8 +// CHECK-NEXT: %14 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6 +// CHECK-NEXT: store ptr null, ptr 
%14, align 8 +// CHECK-NEXT: %15 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7 +// CHECK-NEXT: store ptr null, ptr %15, align 8 +// CHECK-NEXT: %16 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8 +// CHECK-NEXT: store i64 0, ptr %16, align 8 +// CHECK-NEXT: %17 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9 +// CHECK-NEXT: store i64 0, ptr %17, align 8 +// CHECK-NEXT: %18 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10 +// CHECK-NEXT: store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %18, align 4 +// CHECK-NEXT: %19 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11 +// CHECK-NEXT: store [3 x i32] [i32 256, i32 0, i32 0], ptr %19, align 4 +// CHECK-NEXT: %20 = getelementptr inbounds %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12 +// CHECK-NEXT: store i32 0, ptr %20, align 4 +// CHECK-NEXT: %21 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) +// CHECK-NEXT: %22 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %23 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %24 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %22, ptr %23, ptr %24, ptr @.offload_maptypes.1, ptr null, ptr null) // CHECK-NEXT: call void @__tgt_unregister_lib(ptr %EmptyDesc) // CHECK: store ptr %x, ptr %0, align 8 // CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0) // CHECK: ret void // CHECK-NEXT: } +// CHECK: Function Attrs: nounwind +// CHECK: declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) + #[unsafe(no_mangle)] #[inline(never)] pub fn kernel_1(x: &mut [f32; 256]) { From 0f05703ed77b31dc4045b9a1bdedad4818fe04a0 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 31 Aug 2025 15:17:35 -0700 Subject: [PATCH 2/2] model offload C++ structs through Rust structs --- .../src/builder/gpu_offload.rs | 171 ++++++++++-------- 1 file changed, 96 insertions(+), 75 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index eae9034e3c608..559180de3fe55 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -18,7 +18,7 @@ pub(crate) fn handle_gpu_code<'ll>( // The offload memory transfer type for each kernel let mut memtransfer_types = vec![]; let mut region_ids = vec![]; - let offload_entry_ty = add_tgt_offload_entry(&cx); + let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); for num in 0..9 { let kernel = cx.get_function(&format!("kernel_{num}")); if let Some(kernel) = kernel { @@ -52,7 +52,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm // FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be // offloaded was defined. 
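// (For reference: clang encodes the ident_t source location as
// ";<file>;<function>;<line>;<column>;;", so a filled-in @0 would presumably read
// ";lib.rs;kernel_1;0;0;;".)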
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { - // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 let unknown_txt = ";unknown;unknown;0;0;;"; let c_entry_name = CString::new(unknown_txt).unwrap(); let c_val = c_entry_name.as_bytes_with_nul(); @@ -77,15 +76,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { at_one } -pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { - let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); - let tptr = cx.type_ptr(); - let ti64 = cx.type_i64(); - let ti32 = cx.type_i32(); - let ti16 = cx.type_i16(); - // For each kernel to run on the gpu, we will later generate one entry of this type. - // copied from LLVM - // typedef struct { +struct TgtOffloadEntry { // uint64_t Reserved; // uint16_t Version; // uint16_t Kind; @@ -95,21 +86,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty // uint64_t Size; Size of the entry info (0 if it is a function) // uint64_t Data; // void *AuxAddr; - // } __tgt_offload_entry; - let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; - cx.set_struct_body(offload_entry_ty, &entry_elements, false); - offload_entry_ty } -fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { - let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); - let tptr = cx.type_ptr(); - let ti64 = cx.type_i64(); - let ti32 = cx.type_i32(); - let tarr = cx.type_array(ti32, 3); +impl TgtOffloadEntry { + pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { + let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let ti16 = cx.type_i16(); + // For each kernel to run on the gpu, we will later generate one entry of this type. + // copied from LLVM + let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; + cx.set_struct_body(offload_entry_ty, &entry_elements, false); + offload_entry_ty + } + + fn new<'ll>( + cx: &'ll SimpleCx<'_>, + region_id: &'ll Value, + llglobal: &'ll Value, + ) -> Vec<&'ll Value> { + let reserved = cx.get_const_i64(0); + let version = cx.get_const_i16(1); + let kind = cx.get_const_i16(1); + let flags = cx.get_const_i32(0); + let size = cx.get_const_i64(0); + let data = cx.get_const_i64(0); + let aux_addr = cx.const_null(cx.type_ptr()); + vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr] + } +} - // Taken from the LLVM APITypes.h declaration: - //struct KernelArgsTy { +// Taken from the LLVM APITypes.h declaration: +struct KernelArgsTy { // uint32_t Version = 0; // Version of this struct for ABI compatibility. // uint32_t NumArgs = 0; // Number of arguments in each input pointer. // void **ArgBasePtrs = @@ -120,8 +130,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { // void **ArgNames = nullptr; // Name of the data for debugging, possibly null. // void **ArgMappers = nullptr; // User-defined mappers, possibly null. // uint64_t Tripcount = - // 0; // Tripcount for the teams / distribute loop, 0 otherwise. - // struct { + // 0; // Tripcount for the teams / distribute loop, 0 otherwise. + // struct { // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause. // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA. 
// uint64_t Unused : 62; @@ -131,12 +141,53 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { // // The number of threads (for x,y,z dimension). // uint32_t ThreadLimit[3] = {0, 0, 0}; // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested. - //}; - let kernel_elements = - vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; +} + +impl KernelArgsTy { + const OFFLOAD_VERSION: u64 = 3; + const FLAGS: u64 = 0; + const TRIPCOUNT: u64 = 0; + fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type { + let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let tarr = cx.type_array(ti32, 3); + + let kernel_elements = + vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; + + cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); + kernel_arguments_ty + } - cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); - kernel_arguments_ty + fn new<'ll>( + cx: &'ll SimpleCx<'_>, + num_args: u64, + memtransfer_types: &[&'ll Value], + geps: [&'ll Value; 3], + ) -> [(Align, &'ll Value); 13] { + let four = Align::from_bytes(4).expect("4 Byte alignment should work"); + let eight = Align::EIGHT; + let mut values = vec![]; + values.push((four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION))); + values.push((four, cx.get_const_i32(num_args))); + values.push((eight, geps[0])); + values.push((eight, geps[1])); + values.push((eight, geps[2])); + values.push((eight, memtransfer_types[0])); + // The next two are debug infos. FIXME(offload): set them + values.push((eight, cx.const_null(cx.type_ptr()))); + values.push((eight, cx.const_null(cx.type_ptr()))); + values.push((eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT))); + values.push((eight, cx.get_const_i64(KernelArgsTy::FLAGS))); + let ti32 = cx.type_i32(); + let ci32_0 = cx.get_const_i32(0); + values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0]))); + values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0]))); + values.push((four, cx.get_const_i32(0))); + values.try_into().expect("tgt_kernel_arguments construction failed") + } } fn gen_tgt_data_mappers<'ll>( @@ -242,19 +293,10 @@ fn gen_define_handling<'ll>( let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); llvm::set_alignment(llglobal, Align::ONE); llvm::set_section(llglobal, c".llvm.rodata.offloading"); - - // Not actively used yet, for calling real kernels let name = format!(".offloading.entry.kernel_{num}"); // See the __tgt_offload_entry documentation above. 
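// (TgtOffloadEntry::new above fills these as Reserved = 0, Version = 1, Kind = 1
// (presumably OFK_OpenMP in LLVM's OffloadBinary.h), Flags = 0, and Size = 0,
// since each entry describes a function rather than a global.)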
- let reserved = cx.get_const_i64(0); - let version = cx.get_const_i16(1); - let kind = cx.get_const_i16(1); - let flags = cx.get_const_i32(0); - let size = cx.get_const_i64(0); - let data = cx.get_const_i64(0); - let aux_addr = cx.const_null(cx.type_ptr()); - let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]; + let elems = TgtOffloadEntry::new(&cx, region_id, llglobal); let initializer = crate::common::named_struct(offload_entry_ty, &elems); let c_name = CString::new(name).unwrap(); @@ -316,7 +358,7 @@ fn gen_call_handling<'ll>( let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc"); cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false); - let tgt_kernel_decl = gen_tgt_kernel_global(&cx); + let tgt_kernel_decl = KernelArgsTy::new_decl(&cx); let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); let main_fn = cx.get_function("main"); @@ -404,19 +446,19 @@ fn gen_call_handling<'ll>( a1: &'ll Value, a2: &'ll Value, a4: &'ll Value, - ) -> (&'ll Value, &'ll Value, &'ll Value) { + ) -> [&'ll Value; 3] { let i32_0 = cx.get_const_i32(0); let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]); let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]); let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]); - (gep1, gep2, gep3) + [gep1, gep2, gep3] } fn generate_mapper_call<'a, 'll>( builder: &mut SBuilder<'a, 'll>, cx: &'ll SimpleCx<'ll>, - geps: (&'ll Value, &'ll Value, &'ll Value), + geps: [&'ll Value; 3], o_type: &'ll Value, fn_to_call: &'ll Value, fn_ty: &'ll Type, @@ -427,7 +469,7 @@ fn gen_call_handling<'ll>( let i64_max = cx.get_const_i64(u64::MAX); let num_args = cx.get_const_i32(num_args); let args = - vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr]; + vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr]; builder.call(fn_ty, fn_to_call, &args, None); } @@ -436,36 +478,20 @@ fn gen_call_handling<'ll>( let o = memtransfer_types[0]; let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t); + let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps); // Step 3) - let mut values = vec![]; - let offload_version = cx.get_const_i32(3); - values.push((4, offload_version)); - values.push((4, cx.get_const_i32(num_args))); - values.push((8, geps.0)); - values.push((8, geps.1)); - values.push((8, geps.2)); - values.push((8, memtransfer_types[0])); - // The next two are debug infos. FIXME(offload) set them - values.push((8, cx.const_null(cx.type_ptr()))); - values.push((8, cx.const_null(cx.type_ptr()))); - values.push((8, cx.get_const_i64(0))); - values.push((8, cx.get_const_i64(0))); - let ti32 = cx.type_i32(); - let ci32_0 = cx.get_const_i32(0); - values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0]))); - values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0]))); - values.push((4, cx.get_const_i32(0))); - + // Here we fill the KernelArgsTy, see the documentation above for (i, value) in values.iter().enumerate() { let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); - builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap()); + builder.store(value.1, ptr, value.0); } let args = vec![ s_ident_t, - // MAX == -1 - cx.get_const_i64(u64::MAX), + // FIXME(offload) give users a way to select which GPU to use. 
+ cx.get_const_i64(u64::MAX), // MAX == -1. + // FIXME(offload): Don't hardcode the number of teams and threads in the future. cx.get_const_i32(2097152), cx.get_const_i32(256), region_ids[0], @@ -480,19 +506,14 @@ fn gen_call_handling<'ll>( } // Step 4) - //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; - let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t); builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); drop(builder); + // FIXME(offload): The issue is that right now we add a call to the GPU version of the function, + // and then delete the call to the CPU version. In the future, we should use an intrinsic which + // directly resolves to a call to the GPU version. unsafe { llvm::LLVMDeleteFunction(called) }; - - // With this we generated the following begin and end mappers. We could easily generate the - // update mapper in an update. - // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null) - // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null) - // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null) }
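For reference, the host-side sequence that gen_call_handling emits (Steps 1 through 4) corresponds roughly to the hand-written sketch below. The __tgt_* signatures mirror the function types declared in gpu_offload.rs (generate_launcher and gen_tgt_data_mappers); parameter names are descriptive guesses, the constants are the ones hardcoded in this patch, and the sketch illustrates the generated IR rather than code that ships anywhere:

use std::ffi::c_void;
use std::ptr;

unsafe extern "C" {
    fn __tgt_register_lib(desc: *mut c_void);
    fn __tgt_unregister_lib(desc: *mut c_void);
    fn __tgt_init_all_rtls();
    fn __tgt_target_data_begin_mapper(
        ident: *mut c_void, device_id: i64, num_args: i32,
        base_ptrs: *mut *mut c_void, ptrs: *mut *mut c_void, sizes: *mut i64,
        map_types: *mut i64, names: *mut *mut c_void, mappers: *mut *mut c_void,
    );
    fn __tgt_target_data_end_mapper(
        ident: *mut c_void, device_id: i64, num_args: i32,
        base_ptrs: *mut *mut c_void, ptrs: *mut *mut c_void, sizes: *mut i64,
        map_types: *mut i64, names: *mut *mut c_void, mappers: *mut *mut c_void,
    );
    fn __tgt_target_kernel(
        ident: *mut c_void, device_id: i64, num_teams: i32, thread_limit: i32,
        region_id: *mut c_void, kernel_args: *mut c_void,
    ) -> i32;
}

/// What the IR generated around `kernel_1(x: &mut [f32; 256])` boils down to.
unsafe fn launch_kernel_1(ident: *mut c_void, region_id: *mut c_void, x: *mut f32) {
    unsafe {
        let mut bin_desc = [0u8; 32]; // zeroed %struct.__tgt_bin_desc
        let mut base_ptrs = [x as *mut c_void];
        let mut ptrs = [x as *mut c_void];
        let mut sizes = [1024_i64]; // 256 * size_of::<f32>()
        let mut map_types = [35_i64]; // 1 (MapTo) + 2 (MapFrom) + 32

        // Step 1) register the (still empty) binary descriptor.
        __tgt_register_lib(bin_desc.as_mut_ptr() as *mut c_void);
        __tgt_init_all_rtls();
        // Step 2) move data to the GPU; device_id -1 selects the default device.
        __tgt_target_data_begin_mapper(
            ident, -1, 1, base_ptrs.as_mut_ptr(), ptrs.as_mut_ptr(),
            sizes.as_mut_ptr(), map_types.as_mut_ptr(), ptr::null_mut(), ptr::null_mut(),
        );
        // Step 3) launch. The real code passes a filled %struct.__tgt_kernel_arguments
        // (version 3, the arrays above, NumTeams {2097152,0,0}, ThreadLimit {256,0,0});
        // a null stand-in is used here to keep the sketch short.
        let kernel_args = ptr::null_mut();
        __tgt_target_kernel(ident, -1, 2097152, 256, region_id, kernel_args);
        // Step 4) move data back and unregister.
        __tgt_target_data_end_mapper(
            ident, -1, 1, base_ptrs.as_mut_ptr(), ptrs.as_mut_ptr(),
            sizes.as_mut_ptr(), map_types.as_mut_ptr(), ptr::null_mut(), ptr::null_mut(),
        );
        __tgt_unregister_lib(bin_desc.as_mut_ptr() as *mut c_void);
    }
}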