- 
                Notifications
    You must be signed in to change notification settings 
- Fork 44
[WIP] faster sum #356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[WIP] faster sum #356
Conversation
| Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/fast_sum.jl b/fast_sum.jl
index e7cba52..fec6952 100644
--- a/fast_sum.jl
+++ b/fast_sum.jl
@@ -12,7 +12,7 @@ function sum_columns_subgroup(X, result, M, N)
     end
 
     partial = 0.0f0
-    for row = row_thread:row_stride:M
+    for row in row_thread:row_stride:M
         idx = (col - 1) * M + row  # column-major layout
         partial += X[idx]
     end
@@ -32,9 +32,9 @@ function sum_columns_subgroup(X, result, M, N)
 
     # Only one thread writes result
     if lane == 1
-       Atomix.@atomic result[col] += partial
+        Atomix.@atomic result[col] += partial
     end
-    nothing
+    return nothing
 end
 
 
diff --git a/lib/intrinsics/src/atomic.jl b/lib/intrinsics/src/atomic.jl
index 08e71e8..6d46c2e 100644
--- a/lib/intrinsics/src/atomic.jl
+++ b/lib/intrinsics/src/atomic.jl
@@ -58,12 +58,14 @@ end
 end
 
 for gentype in [Float32, Float64], as in atomic_memory_types
-@eval begin
+    @eval begin
 
-@device_function atomic_add!(p::LLVMPtr{$gentype,$as}, val::$gentype) =
-    @builtin_ccall("atomic_add", $gentype,
-                   (LLVMPtr{$gentype,$as}, $gentype), p, val)
-end
+        @device_function atomic_add!(p::LLVMPtr{$gentype, $as}, val::$gentype) =
+            @builtin_ccall(
+            "atomic_add", $gentype,
+            (LLVMPtr{$gentype, $as}, $gentype), p, val
+        )
+    end
 end
 
 
diff --git a/lib/intrinsics/src/work_item.jl b/lib/intrinsics/src/work_item.jl
index d9919f2..3f6446b 100644
--- a/lib/intrinsics/src/work_item.jl
+++ b/lib/intrinsics/src/work_item.jl
@@ -39,41 +39,47 @@ end
 export sub_group_shuffle, sub_group_shuffle_xor
 
 for (jltype, llvmtype, julia_type_str) in [
-        (Int8,    "i8",    :Int8),
-        (UInt8,   "i8",    :UInt8),
-        (Int16,   "i16",   :Int16),
-        (UInt16,  "i16",   :UInt16),
-        (Int32,   "i32",   :Int32),
-        (UInt32,  "i32",   :UInt32),
-        (Int64,   "i64",   :Int64),
-        (UInt64,  "i64",   :UInt64),
-        (Float16, "half",  :Float16),
+        (Int8, "i8", :Int8),
+        (UInt8, "i8", :UInt8),
+        (Int16, "i16", :Int16),
+        (UInt16, "i16", :UInt16),
+        (Int32, "i32", :Int32),
+        (UInt32, "i32", :UInt32),
+        (Int64, "i64", :Int64),
+        (UInt64, "i64", :UInt64),
+        (Float16, "half", :Float16),
         (Float32, "float", :Float32),
-        (Float64, "double",:Float64)
+        (Float64, "double", :Float64),
     ]
     @eval begin
         export sub_group_shuffle, sub_group_shuffle_xor
         function sub_group_shuffle(x::$jltype, idx::Integer)
-            Base.llvmcall(
-                $("""
-                declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32)
-                define $llvmtype @entry($llvmtype %val, i32 %idx) #0 {
-                    %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx)
-                    ret $llvmtype %res
-                }
-                attributes #0 = { alwaysinline }
-                """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(idx))
+            return Base.llvmcall(
+                $(
+                    """
+                    declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32)
+                    define $llvmtype @entry($llvmtype %val, i32 %idx) #0 {
+                        %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx)
+                        ret $llvmtype %res
+                    }
+                    attributes #0 = { alwaysinline }
+                    """, "entry",
+                ), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(idx)
+            )
         end
         function sub_group_shuffle_xor(x::$jltype, mask::Integer)
-            Base.llvmcall(
-                $("""
-                declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32)
-                define $llvmtype @entry($llvmtype %val, i32 %mask) #0 {
-                    %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask)
-                    ret $llvmtype %res
-                }
-                attributes #0 = { alwaysinline }
-                """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(mask))
+            return Base.llvmcall(
+                $(
+                    """
+                    declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32)
+                    define $llvmtype @entry($llvmtype %val, i32 %mask) #0 {
+                        %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask)
+                        ret $llvmtype %res
+                    }
+                    attributes #0 = { alwaysinline }
+                    """, "entry",
+                ), $julia_type_str, Tuple{$julia_type_str, Int32}, x, Int32(mask)
+            )
         end
     end
 end | 
| Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@           Coverage Diff           @@
##           master     #356   +/-   ##
=======================================
  Coverage   78.86%   78.86%           
=======================================
  Files          12       12           
  Lines         672      672           
=======================================
  Hits          530      530           
  Misses        142      142
☔ View full report in Codecov by Sentry.
🚀 New features to boost your workflow:
 | 
| 
 That's great to hear, at least. I presume that the cartesian indexing introduced by the more complicated  | 
ref #352
This matches the speed of the C implementation, so there seems to be no inherent overhead compared to OpenCL C:
cc @maleadt