Skip to content

Commit a8fb3b7

Browse files
authored
Merge pull request #330 from JohannesGaessler/cuda-fa-mma-5
Cuda fa mma 5
2 parents 5bbc736 + 60958f6 commit a8fb3b7

31 files changed

+2126
-1002
lines changed

.github/workflows/build.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,6 @@ jobs:
11541154
uses: hendrikmuhs/ccache-action@v1.2.16
11551155
with:
11561156
key: ${{ github.job }}
1157-
variant: sccache
11581157
evict-old-files: 1d
11591158

11601159
- name: Build
@@ -1189,7 +1188,6 @@ jobs:
11891188
uses: hendrikmuhs/ccache-action@v1.2.16
11901189
with:
11911190
key: windows-latest-cmake-hip-release
1192-
variant: sccache
11931191
evict-old-files: 1d
11941192

11951193
- name: Install

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ ifdef GGML_RPC
596596
OBJ_GGML_EXT += ggml/src/ggml-rpc.o
597597
endif # GGML_RPC
598598

599-
OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu))
599+
OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
600600
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
601601

602602
ifdef GGML_CUDA_FA_ALL_QUANTS

common/minja.hpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
693693

694694
class TemplateToken {
695695
public:
696-
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
696+
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
697697

698698
static std::string typeToString(Type t) {
699699
switch (t) {
@@ -714,6 +714,8 @@ class TemplateToken {
714714
case Type::EndFilter: return "endfilter";
715715
case Type::Generation: return "generation";
716716
case Type::EndGeneration: return "endgeneration";
717+
case Type::Break: return "break";
718+
case Type::Continue: return "continue";
717719
}
718720
return "Unknown";
719721
}
@@ -815,6 +817,22 @@ struct CommentTemplateToken : public TemplateToken {
815817
CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
816818
};
817819

820+
// Kind of loop-control statement: `break` or `continue`.
enum class LoopControlType { Break, Continue };

// Exception used to signal a `break` / `continue` statement; `control_type` records which one.
// It derives from std::runtime_error, so what() returns the message given at construction.
class LoopControlException : public std::runtime_error {
public:
    LoopControlType control_type;

    // Creates the exception with a caller-supplied message.
    LoopControlException(const std::string & message, LoopControlType control_type)
        : std::runtime_error(message), control_type(control_type) {}

    // Creates the exception with the default message "<break|continue> outside of a loop".
    // The message is built by string concatenation instead of `(std::ostringstream() << ...).str()`:
    // inserting into a temporary stream only yields the derived stream type (and therefore `.str()`)
    // on standard libraries that implement LWG 1203, so the previous form was not portable.
    LoopControlException(LoopControlType control_type)
        : std::runtime_error(std::string(control_type == LoopControlType::Continue ? "continue" : "break") + " outside of a loop"),
          control_type(control_type) {}
};
830+
831+
struct LoopControlTemplateToken : public TemplateToken {
832+
LoopControlType control_type;
833+
LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
834+
};
835+
818836
class TemplateNode {
819837
Location location_;
820838
protected:
@@ -825,6 +843,12 @@ class TemplateNode {
825843
void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
826844
try {
827845
do_render(out, context);
846+
} catch (const LoopControlException & e) {
847+
// TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
848+
std::ostringstream err;
849+
err << e.what();
850+
if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
851+
throw LoopControlException(err.str(), e.control_type);
828852
} catch (const std::exception & e) {
829853
std::ostringstream err;
830854
err << e.what();
@@ -897,6 +921,15 @@ class IfNode : public TemplateNode {
897921
}
898922
};
899923

924+
class LoopControlNode : public TemplateNode {
925+
LoopControlType control_type_;
926+
public:
927+
LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {}
928+
void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
929+
throw LoopControlException(control_type_);
930+
}
931+
};
932+
900933
class ForNode : public TemplateNode {
901934
std::vector<std::string> var_names;
902935
std::shared_ptr<Expression> iterable;
@@ -961,7 +994,12 @@ class ForNode : public TemplateNode {
961994
loop.set("last", i == (n - 1));
962995
loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
963996
loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
964-
body->render(out, loop_context);
997+
try {
998+
body->render(out, loop_context);
999+
} catch (const LoopControlException & e) {
1000+
if (e.control_type == LoopControlType::Break) break;
1001+
if (e.control_type == LoopControlType::Continue) continue;
1002+
}
9651003
}
9661004
}
9671005
};
@@ -2159,7 +2197,7 @@ class Parser {
21592197
static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
21602198
static std::regex expr_open_regex(R"(\{\{([-~])?)");
21612199
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
2162-
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
2200+
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
21632201
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
21642202
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
21652203
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
@@ -2291,6 +2329,9 @@ class Parser {
22912329
} else if (keyword == "endfilter") {
22922330
auto post_space = parseBlockClose();
22932331
tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
2332+
} else if (keyword == "break" || keyword == "continue") {
2333+
auto post_space = parseBlockClose();
2334+
tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
22942335
} else {
22952336
throw std::runtime_error("Unexpected block: " + keyword);
22962337
}
@@ -2414,6 +2455,8 @@ class Parser {
24142455
children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
24152456
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
24162457
// Ignore comments
2458+
} else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
2459+
children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
24172460
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
24182461
|| dynamic_cast<EndSetTemplateToken*>(token.get())
24192462
|| dynamic_cast<EndMacroTemplateToken*>(token.get())

examples/run/run.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) {
6565
return ret;
6666
}
6767

68+
// Formats the broken-down time `tm` with the strftime-style pattern `fmt` (via std::put_time)
// and returns the resulting text.
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
    std::ostringstream formatted;
    formatted << std::put_time(&tm, fmt);
    return formatted.str();
}
74+
6875
class Opt {
6976
public:
7077
int init(int argc, const char ** argv) {
@@ -698,6 +705,39 @@ class LlamaData {
698705
return download(url, bn, true);
699706
}
700707

708+
int s3_dl(const std::string & model, const std::string & bn) {
709+
const size_t slash_pos = model.find('/');
710+
if (slash_pos == std::string::npos) {
711+
return 1;
712+
}
713+
714+
const std::string bucket = model.substr(0, slash_pos);
715+
const std::string key = model.substr(slash_pos + 1);
716+
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
717+
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
718+
if (!access_key || !secret_key) {
719+
printe("AWS credentials not found in environment\n");
720+
return 1;
721+
}
722+
723+
// Generate AWS Signature Version 4 headers
724+
// (Implementation requires HMAC-SHA256 and date handling)
725+
// Get current timestamp
726+
const time_t now = time(nullptr);
727+
const tm tm = *gmtime(&now);
728+
const std::string date = strftime_fmt("%Y%m%d", tm);
729+
const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
730+
const std::vector<std::string> headers = {
731+
"Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
732+
"/us-east-1/s3/aws4_request",
733+
"x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
734+
};
735+
736+
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
737+
738+
return download(url, bn, true, headers);
739+
}
740+
701741
std::string basename(const std::string & path) {
702742
const size_t pos = path.find_last_of("/\\");
703743
if (pos == std::string::npos) {
@@ -738,6 +778,9 @@ class LlamaData {
738778
rm_until_substring(model_, "github:");
739779
rm_until_substring(model_, "://");
740780
ret = github_dl(model_, bn);
781+
} else if (string_starts_with(model_, "s3://")) {
782+
rm_until_substring(model_, "://");
783+
ret = s3_dl(model_, bn);
741784
} else { // ollama:// or nothing
742785
rm_until_substring(model_, "ollama.com/library/");
743786
rm_until_substring(model_, "://");

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1775,7 +1775,7 @@ extern "C" {
17751775
struct ggml_tensor * a,
17761776
int k);
17771777

1778-
#define GGML_KQ_MASK_PAD 32
1778+
#define GGML_KQ_MASK_PAD 64
17791779

17801780
// q: [n_embd, n_batch, n_head, 1]
17811781
// k: [n_embd, n_kv, n_head_kv, 1]

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
2828
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
2929

3030
file(GLOB GGML_SOURCES_CUDA "*.cu")
31-
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
31+
file(GLOB SRCS "template-instances/fattn-mma*.cu")
3232
list(APPEND GGML_SOURCES_CUDA ${SRCS})
3333
file(GLOB SRCS "template-instances/mmq*.cu")
3434
list(APPEND GGML_SOURCES_CUDA ${SRCS})

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ typedef float2 dfloat2;
148148
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
149149

150150
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
151-
#define INT8_MMA_AVAILABLE
151+
#define NEW_MMA_AVAILABLE
152152
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
153153

154154
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
159159
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
160160
}
161161

162+
// Any FP16 tensor cores are available.
162163
static constexpr bool fp16_mma_available(const int cc) {
163164
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
164165
}
165166

166-
static constexpr bool int8_mma_available(const int cc) {
167+
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
168+
static constexpr bool new_mma_available(const int cc) {
167169
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
168170
}
169171

0 commit comments

Comments
 (0)