@@ -653,6 +653,84 @@ pub(crate) unsafe fn llvm_optimize(
653
653
None
654
654
} ;
655
655
656
+ fn handle_offload ( m : & llvm:: Module , llcx : & llvm:: Context , old_fn : & llvm:: Value ) {
657
+ dbg ! ( & old_fn) ;
658
+ unsafe { llvm:: LLVMRustOffloadWrapper ( m, old_fn) } ;
659
+ //unsafe {llvm::LLVMDumpModule(m);}
660
+ //unsafe {
661
+ // // Get the old function type
662
+ // let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn);
663
+ // dbg!(&old_fn_ty);
664
+ // let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty);
665
+ // dbg!(&old_param_count);
666
+
667
+ // // Get the old parameter types
668
+ // let mut old_param_types = Vec::with_capacity(old_param_count as usize);
669
+ // llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr());
670
+ // old_param_types.set_len(old_param_count as usize);
671
+
672
+ // // Create the new parameter list, with ptr as the first argument
673
+ // let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
674
+ // let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
675
+ // new_param_types.push(ptr_ty);
676
+ // for old_param in old_param_types {
677
+ // new_param_types.push(old_param);
678
+ // }
679
+ // dbg!(&new_param_types);
680
+
681
+ // // Create the new function type
682
+ // let ret_ty = llvm::LLVMGetReturnType(old_fn_ty);
683
+ // let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0);
684
+ // dbg!(&new_fn_ty);
685
+
686
+ // // Create the new function
687
+ // let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
688
+ // //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap();
689
+ // let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name);
690
+ // let new_fn_cstr = CString::new(new_fn_name).unwrap();
691
+ // let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty);
692
+ // dbg!(&new_fn);
693
+ // let a0 = llvm::LLVMGetParam(new_fn, 0);
694
+ // llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len());
695
+ // dbg!(&new_fn);
696
+
697
+ // // Move basic blocks
698
+ // let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn);
699
+ // //dbg!(&bb);
700
+ // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
701
+ // //while !bb.is_null() {
702
+ // // let next = llvm::LLVMGetNextBasicBlock(bb);
703
+ // // llvm::LLVMAppendExistingBasicBlock(new_fn, bb);
704
+ // // bb = next;
705
+ // //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
706
+ // let old_n = llvm::LLVMCountParams(old_fn);
707
+ // for i in 0..old_n {
708
+ // let old_arg = llvm::LLVMGetParam(old_fn, i);
709
+ // let new_arg = llvm::LLVMGetParam(new_fn, i + 1);
710
+ // llvm::LLVMReplaceAllUsesWith(old_arg, new_arg);
711
+ // }
712
+
713
+ // // Copy linkage and visibility
714
+ // //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn));
715
+ // //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn));
716
+
717
+ // // Replace all uses of old_fn with new_fn (RAUW)
718
+ // llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
719
+
720
+ // // Optionally, remove the old function
721
+ // llvm::LLVMDeleteFunction(old_fn);
722
+ //}
723
+ }
724
+
725
+ for num in 0 ..9 {
726
+ let name = format ! ( "kernel_{num}" ) ;
727
+ let c_name = CString :: new ( name) . unwrap ( ) ;
728
+ if let Some ( kernel) = unsafe { llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) , c_name. as_ptr ( ) ) } {
729
+ dbg ! ( "found offload kernel asfd" ) ;
730
+ handle_offload ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , kernel) ;
731
+ }
732
+ }
733
+
656
734
let mut llvm_profiler = cgcx
657
735
. prof
658
736
. llvm_recording_enabled ( )
0 commit comments