Skip to content

Commit cabf6ff

Browse files
am17an authored and pwilkin committed
CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (ggml-org#16577)
1 parent ac97d9b commit cabf6ff

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2882,7 +2882,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28822882
}
28832883

28842884
//if rms norm is the B operand, then we don't handle broadcast
2885-
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2885+
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
28862886
return false;
28872887
}
28882888

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2686,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
26862686

26872687
// if rms_norm is the B operand, then we don't handle broadcast
26882688
if (rms_norm == mul->src[1] &&
2689-
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2689+
!ggml_are_same_shape(mul->src[0], rms_norm)) {
26902690
return false;
26912691
}
26922692

0 commit comments

Comments
 (0)