@@ -653,6 +653,87 @@ pub(crate) unsafe fn llvm_optimize(
653
653
None
654
654
} ;
655
655
656
+ fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
657
+ unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
658
+ //unsafe {llvm::LLVMDumpModule(m);}
659
+ //unsafe {
660
+ // // Get the old function type
661
+ // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
662
+ // dbg!(&old_fn_ty);
663
+ // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
664
+ // dbg!(&old_param_count);
665
+
666
+ // // Get the old parameter types
667
+ // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
668
+ // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
669
+ // old_param_types.set_len(old_param_count as usize);
670
+
671
+ // // Create the new parameter list, with ptr as the first argument
672
+ // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
673
+ // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
674
+ // new_param_types.push(ptr_ty);
675
+ // for old_param in old_param_types {
676
+ // new_param_types.push(old_param);
677
+ // }
678
+ // dbg!(&new_param_types);
679
+
680
+ // // Create the new function type
681
+ // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
682
+ // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
683
+ // dbg!(&new_fn_ty);
684
+
685
+ // // Create the new function
686
+ // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
687
+ // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
688
+ // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
689
+ // let new_fn_cstr = CString::new(new_fn_name).unwrap();
690
+ // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
691
+ // dbg!(&new_fn);
692
+ // let a0 = llvm::LLVMGetParam(new_fn, 0);
693
+ // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
694
+ // dbg!(&new_fn);
695
+
696
+ // // Move basic blocks
697
+ // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
698
+ // //dbg!(&bb);
699
+ // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
700
+ // //while !bb.is_null() {
701
+ // // let next = llvm::LLVMGetNextBasicBlock(bb);
702
+ // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
703
+ // // bb = next;
704
+ // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
705
+ // let old_n = llvm::LLVMCountParams(old_fn);
706
+ // for i in 0..old_n {
707
+ // let old_arg = llvm::LLVMGetParam(old_fn, i);
708
+ // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
709
+ // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
710
+ // }
711
+
712
+ // // Copy linkage and visibility
713
+ // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
714
+ // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
715
+
716
+ // // Replace all uses of old_fn with new_fn (RAUW)
717
+ // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
718
+
719
+ // // Optionally, remove the old function
720
+ // llvm::LLVMDeleteFunction(old_fn);
721
+ //}
722
+ }
723
+
724
+ let consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ;
725
+ if consider_offload && ( cgcx. target_arch == "amdgpu" || cgcx. target_arch == "nvptx64" ) {
726
+ for num in 0 ..9 {
727
+ let name = format ! ( "kernel_{num}" ) ;
728
+ let c_name = CString :: new ( name) . unwrap ( ) ;
729
+ if let Some ( kernel) =
730
+ unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) }
731
+ {
732
+ handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
733
+ }
734
+ }
735
+ }
736
+
656
737
let mut llvm_profiler = cgcx
657
738
. prof
658
739
. llvm_recording_enabled ( )
0 commit comments