99#include " misc/Interval.h"
1010#include " IntStream.h"
1111
12- #include " support/StringUtils .h"
12+ #include " support/Utf8 .h"
1313#include " support/CPPUtils.h"
1414
1515#include " ANTLRInputStream.h"
@@ -35,28 +35,37 @@ ANTLRInputStream::ANTLRInputStream(std::istream &stream): ANTLRInputStream() {
3535 load (stream);
3636}
3737
38- void ANTLRInputStream::load (const std::string &input) {
39- load (input.data (), input.size ());
38+ void ANTLRInputStream::load (const std::string &input, bool lenient ) {
39+ load (input.data (), input.size (), lenient );
4040}
4141
42- void ANTLRInputStream::load (const char *data, size_t length) {
42+ void ANTLRInputStream::load (const char *data, size_t length, bool lenient ) {
4343 // Remove the UTF-8 BOM if present.
4444 const char *bom = " \xef\xbb\xbf " ;
45- if (length >= 3 && strncmp (data, bom, 3 ) == 0 )
46- _data = antlrcpp::utf8_to_utf32 (data + 3 , data + length);
47- else
48- _data = antlrcpp::utf8_to_utf32 (data, data + length);
45+ if (length >= 3 && strncmp (data, bom, 3 ) == 0 ) {
46+ data += 3 ;
47+ length -= 3 ;
48+ }
49+ if (lenient) {
50+ _data = Utf8::lenientDecode (std::string_view (data, length));
51+ } else {
52+ auto maybe_utf32 = Utf8::strictDecode (std::string_view (data, length));
53+ if (!maybe_utf32.has_value ()) {
54+ throw IllegalArgumentException (" UTF-8 string contains an illegal byte sequence" );
55+ }
56+ _data = std::move (maybe_utf32).value ();
57+ }
4958 p = 0 ;
5059}
5160
52- void ANTLRInputStream::load (std::istream &stream) {
61+ void ANTLRInputStream::load (std::istream &stream, bool lenient ) {
5362 if (!stream.good () || stream.eof ()) // No fail, bad or EOF.
5463 return ;
5564
5665 _data.clear ();
5766
5867 std::string s ((std::istreambuf_iterator<char >(stream)), std::istreambuf_iterator<char >());
59- load (s.data (), s.length ());
68+ load (s.data (), s.length (), lenient );
6069}
6170
6271void ANTLRInputStream::reset () {
@@ -144,7 +153,11 @@ std::string ANTLRInputStream::getText(const Interval &interval) {
144153 return " " ;
145154 }
146155
147- return antlrcpp::utf32_to_utf8 (_data.substr (start, count));
156+ auto maybe_utf8 = Utf8::strictEncode (std::u32string_view (_data).substr (start, count));
157+ if (!maybe_utf8.has_value ()) {
158+ throw IllegalArgumentException (" Input stream contains invalid Unicode code points" );
159+ }
160+ return std::move (maybe_utf8).value ();
148161}
149162
150163std::string ANTLRInputStream::getSourceName () const {
@@ -155,7 +168,11 @@ std::string ANTLRInputStream::getSourceName() const {
155168}
156169
157170std::string ANTLRInputStream::toString () const {
158- return antlrcpp::utf32_to_utf8 (_data);
171+ auto maybe_utf8 = Utf8::strictEncode (_data);
172+ if (!maybe_utf8.has_value ()) {
173+ throw IllegalArgumentException (" Input stream contains invalid Unicode code points" );
174+ }
175+ return std::move (maybe_utf8).value ();
159176}
160177
161178void ANTLRInputStream::InitializeInstanceFields () {
0 commit comments