@@ -172,7 +172,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
172172 clear_if_default = False ):
173173 if is_packed :
174174 local_DecodeVarint = _DecodeVarint
175- def DecodePackedField (buffer , pos , end , message , field_dict ):
175+ def DecodePackedField (
176+ buffer , pos , end , message , field_dict , current_depth = 0
177+ ):
178+ del current_depth # unused
176179 value = field_dict .get (key )
177180 if value is None :
178181 value = field_dict .setdefault (key , new_default (message ))
@@ -191,7 +194,10 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
191194 elif is_repeated :
192195 tag_bytes = encoder .TagBytes (field_number , wire_type )
193196 tag_len = len (tag_bytes )
194- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
197+ def DecodeRepeatedField (
198+ buffer , pos , end , message , field_dict , current_depth = 0
199+ ):
200+ del current_depth # unused
195201 value = field_dict .get (key )
196202 if value is None :
197203 value = field_dict .setdefault (key , new_default (message ))
@@ -208,7 +214,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
208214 return new_pos
209215 return DecodeRepeatedField
210216 else :
211- def DecodeField (buffer , pos , end , message , field_dict ):
217+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
218+ del current_depth # unused
212219 (new_value , pos ) = decode_value (buffer , pos )
213220 if pos > end :
214221 raise _DecodeError ('Truncated message.' )
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
352359 enum_type = key .enum_type
353360 if is_packed :
354361 local_DecodeVarint = _DecodeVarint
355- def DecodePackedField (buffer , pos , end , message , field_dict ):
362+ def DecodePackedField (
363+ buffer , pos , end , message , field_dict , current_depth = 0
364+ ):
356365 """Decode serialized packed enum to its value and a new position.
357366
358367 Args:
@@ -365,6 +374,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
365374 Returns:
366375 int, new position in serialized data.
367376 """
377+ del current_depth # unused
368378 value = field_dict .get (key )
369379 if value is None :
370380 value = field_dict .setdefault (key , new_default (message ))
@@ -405,7 +415,9 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
405415 elif is_repeated :
406416 tag_bytes = encoder .TagBytes (field_number , wire_format .WIRETYPE_VARINT )
407417 tag_len = len (tag_bytes )
408- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
418+ def DecodeRepeatedField (
419+ buffer , pos , end , message , field_dict , current_depth = 0
420+ ):
409421 """Decode serialized repeated enum to its value and a new position.
410422
411423 Args:
@@ -418,6 +430,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
418430 Returns:
419431 int, new position in serialized data.
420432 """
433+ del current_depth # unused
421434 value = field_dict .get (key )
422435 if value is None :
423436 value = field_dict .setdefault (key , new_default (message ))
@@ -446,7 +459,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446459 return new_pos
447460 return DecodeRepeatedField
448461 else :
449- def DecodeField (buffer , pos , end , message , field_dict ):
462+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
450463 """Decode serialized repeated enum to its value and a new position.
451464
452465 Args:
@@ -459,6 +472,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
459472 Returns:
460473 int, new position in serialized data.
461474 """
475+ del current_depth # unused
462476 value_start_pos = pos
463477 (enum_value , pos ) = _DecodeSignedVarint32 (buffer , pos )
464478 if pos > end :
@@ -540,7 +554,10 @@ def _ConvertToUnicode(memview):
540554 tag_bytes = encoder .TagBytes (field_number ,
541555 wire_format .WIRETYPE_LENGTH_DELIMITED )
542556 tag_len = len (tag_bytes )
543- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
557+ def DecodeRepeatedField (
558+ buffer , pos , end , message , field_dict , current_depth = 0
559+ ):
560+ del current_depth # unused
544561 value = field_dict .get (key )
545562 if value is None :
546563 value = field_dict .setdefault (key , new_default (message ))
@@ -557,7 +574,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
557574 return new_pos
558575 return DecodeRepeatedField
559576 else :
560- def DecodeField (buffer , pos , end , message , field_dict ):
577+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
578+ del current_depth # unused
561579 (size , pos ) = local_DecodeVarint (buffer , pos )
562580 new_pos = pos + size
563581 if new_pos > end :
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
581599 tag_bytes = encoder .TagBytes (field_number ,
582600 wire_format .WIRETYPE_LENGTH_DELIMITED )
583601 tag_len = len (tag_bytes )
584- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
602+ def DecodeRepeatedField (
603+ buffer , pos , end , message , field_dict , current_depth = 0
604+ ):
605+ del current_depth # unused
585606 value = field_dict .get (key )
586607 if value is None :
587608 value = field_dict .setdefault (key , new_default (message ))
@@ -598,7 +619,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
598619 return new_pos
599620 return DecodeRepeatedField
600621 else :
601- def DecodeField (buffer , pos , end , message , field_dict ):
622+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
623+ del current_depth # unused
602624 (size , pos ) = local_DecodeVarint (buffer , pos )
603625 new_pos = pos + size
604626 if new_pos > end :
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
623645 tag_bytes = encoder .TagBytes (field_number ,
624646 wire_format .WIRETYPE_START_GROUP )
625647 tag_len = len (tag_bytes )
626- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
648+ def DecodeRepeatedField (
649+ buffer , pos , end , message , field_dict , current_depth = 0
650+ ):
627651 value = field_dict .get (key )
628652 if value is None :
629653 value = field_dict .setdefault (key , new_default (message ))
@@ -632,7 +656,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
632656 if value is None :
633657 value = field_dict .setdefault (key , new_default (message ))
634658 # Read sub-message.
635- pos = value .add ()._InternalParse (buffer , pos , end )
659+ current_depth += 1
660+ if current_depth > _recursion_limit :
661+ raise _DecodeError (
662+ 'Error parsing message: too many levels of nesting.'
663+ )
664+ pos = value .add ()._InternalParse (buffer , pos , end , current_depth )
665+ current_depth -= 1
636666 # Read end tag.
637667 new_pos = pos + end_tag_len
638668 if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
@@ -644,12 +674,16 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
644674 return new_pos
645675 return DecodeRepeatedField
646676 else :
647- def DecodeField (buffer , pos , end , message , field_dict ):
677+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
648678 value = field_dict .get (key )
649679 if value is None :
650680 value = field_dict .setdefault (key , new_default (message ))
651681 # Read sub-message.
652- pos = value ._InternalParse (buffer , pos , end )
682+ current_depth += 1
683+ if current_depth > _recursion_limit :
684+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
685+ pos = value ._InternalParse (buffer , pos , end , current_depth )
686+ current_depth -= 1
653687 # Read end tag.
654688 new_pos = pos + end_tag_len
655689 if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
668702 tag_bytes = encoder .TagBytes (field_number ,
669703 wire_format .WIRETYPE_LENGTH_DELIMITED )
670704 tag_len = len (tag_bytes )
671- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
705+ def DecodeRepeatedField (
706+ buffer , pos , end , message , field_dict , current_depth = 0
707+ ):
672708 value = field_dict .get (key )
673709 if value is None :
674710 value = field_dict .setdefault (key , new_default (message ))
@@ -679,18 +715,27 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
679715 if new_pos > end :
680716 raise _DecodeError ('Truncated message.' )
681717 # Read sub-message.
682- if value .add ()._InternalParse (buffer , pos , new_pos ) != new_pos :
718+ current_depth += 1
719+ if current_depth > _recursion_limit :
720+ raise _DecodeError (
721+ 'Error parsing message: too many levels of nesting.'
722+ )
723+ if (
724+ value .add ()._InternalParse (buffer , pos , new_pos , current_depth )
725+ != new_pos
726+ ):
683727 # The only reason _InternalParse would return early is if it
684728 # encountered an end-group tag.
685729 raise _DecodeError ('Unexpected end-group tag.' )
686730 # Predict that the next tag is another copy of the same repeated field.
731+ current_depth -= 1
687732 pos = new_pos + tag_len
688733 if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
689734 # Prediction failed. Return.
690735 return new_pos
691736 return DecodeRepeatedField
692737 else :
693- def DecodeField (buffer , pos , end , message , field_dict ):
738+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
694739 value = field_dict .get (key )
695740 if value is None :
696741 value = field_dict .setdefault (key , new_default (message ))
@@ -699,11 +744,14 @@ def DecodeField(buffer, pos, end, message, field_dict):
699744 new_pos = pos + size
700745 if new_pos > end :
701746 raise _DecodeError ('Truncated message.' )
702- # Read sub-message.
703- if value ._InternalParse (buffer , pos , new_pos ) != new_pos :
747+ current_depth += 1
748+ if current_depth > _recursion_limit :
749+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
750+ if value ._InternalParse (buffer , pos , new_pos , current_depth ) != new_pos :
704751 # The only reason _InternalParse would return early is if it encountered
705752 # an end-group tag.
706753 raise _DecodeError ('Unexpected end-group tag.' )
754+ current_depth -= 1
707755 return new_pos
708756 return DecodeField
709757
@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
859907 # Can't read _concrete_class yet; might not be initialized.
860908 message_type = field_descriptor .message_type
861909
862- def DecodeMap (buffer , pos , end , message , field_dict ):
910+ def DecodeMap (buffer , pos , end , message , field_dict , current_depth = 0 ):
911+ del current_depth # unused
863912 submsg = message_type ._concrete_class ()
864913 value = field_dict .get (key )
865914 if value is None :
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
941990 return pos
942991 pos = new_pos
943992
993+ DEFAULT_RECURSION_LIMIT = 100
994+ _recursion_limit = DEFAULT_RECURSION_LIMIT
995+
996+
997+ def SetRecursionLimit (new_limit ):
998+ global _recursion_limit
999+ _recursion_limit = new_limit
1000+
9441001
945- def _DecodeUnknownFieldSet (buffer , pos , end_pos = None ):
1002+ def _DecodeUnknownFieldSet (buffer , pos , end_pos = None , current_depth = 0 ):
9461003 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9471004
9481005 unknown_field_set = containers .UnknownFieldSet ()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
9521009 field_number , wire_type = wire_format .UnpackTag (tag )
9531010 if wire_type == wire_format .WIRETYPE_END_GROUP :
9541011 break
955- (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type )
1012+ (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type , current_depth )
9561013 # pylint: disable=protected-access
9571014 unknown_field_set ._add (field_number , wire_type , data )
9581015
9591016 return (unknown_field_set , pos )
9601017
9611018
962- def _DecodeUnknownField (buffer , pos , wire_type ):
1019+ def _DecodeUnknownField (buffer , pos , wire_type , current_depth = 0 ):
9631020 """Decode a unknown field. Returns the UnknownField and new position."""
9641021
9651022 if wire_type == wire_format .WIRETYPE_VARINT :
@@ -973,7 +1030,12 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9731030 data = buffer [pos :pos + size ].tobytes ()
9741031 pos += size
9751032 elif wire_type == wire_format .WIRETYPE_START_GROUP :
976- (data , pos ) = _DecodeUnknownFieldSet (buffer , pos )
1033+ print ("MMP " + str (current_depth ))
1034+ current_depth += 1
1035+ if current_depth >= _recursion_limit :
1036+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
1037+ (data , pos ) = _DecodeUnknownFieldSet (buffer , pos , None , current_depth )
1038+ current_depth -= 1
9771039 elif wire_type == wire_format .WIRETYPE_END_GROUP :
9781040 return (0 , - 1 )
9791041 else :
0 commit comments