@@ -243,7 +243,7 @@ func (it *nodeIterator) seek(prefix []byte) error {
243243 key = key [:len (key )- 1 ]
244244 // Move forward until we're just before the closest match to key.
245245 for {
246- state , parentIndex , path , err := it .peek ( bytes . HasPrefix ( key , it . path ) )
246+ state , parentIndex , path , err := it .peekSeek ( key )
247247 if err == errIteratorEnd {
248248 return errIteratorEnd
249249 } else if err != nil {
@@ -255,16 +255,21 @@ func (it *nodeIterator) seek(prefix []byte) error {
255255 }
256256}
257257
258+ // init initializes the the iterator.
259+ func (it * nodeIterator ) init () (* nodeIteratorState , error ) {
260+ root := it .trie .Hash ()
261+ state := & nodeIteratorState {node : it .trie .root , index : - 1 }
262+ if root != emptyRoot {
263+ state .hash = root
264+ }
265+ return state , state .resolve (it .trie , nil )
266+ }
267+
258268// peek creates the next state of the iterator.
259269func (it * nodeIterator ) peek (descend bool ) (* nodeIteratorState , * int , []byte , error ) {
270+ // Initialize the iterator if we've just started.
260271 if len (it .stack ) == 0 {
261- // Initialize the iterator if we've just started.
262- root := it .trie .Hash ()
263- state := & nodeIteratorState {node : it .trie .root , index : - 1 }
264- if root != emptyRoot {
265- state .hash = root
266- }
267- err := state .resolve (it .trie , nil )
272+ state , err := it .init ()
268273 return state , nil , nil , err
269274 }
270275 if ! descend {
@@ -292,6 +297,39 @@ func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, er
292297 return nil , nil , nil , errIteratorEnd
293298}
294299
300+ // peekSeek is like peek, but it also tries to skip resolving hashes by skipping
301+ // over the siblings that do not lead towards the desired seek position.
302+ func (it * nodeIterator ) peekSeek (seekKey []byte ) (* nodeIteratorState , * int , []byte , error ) {
303+ // Initialize the iterator if we've just started.
304+ if len (it .stack ) == 0 {
305+ state , err := it .init ()
306+ return state , nil , nil , err
307+ }
308+ if ! bytes .HasPrefix (seekKey , it .path ) {
309+ // If we're skipping children, pop the current node first
310+ it .pop ()
311+ }
312+
313+ // Continue iteration to the next child
314+ for len (it .stack ) > 0 {
315+ parent := it .stack [len (it .stack )- 1 ]
316+ ancestor := parent .hash
317+ if (ancestor == common.Hash {}) {
318+ ancestor = parent .parent
319+ }
320+ state , path , ok := it .nextChildAt (parent , ancestor , seekKey )
321+ if ok {
322+ if err := state .resolve (it .trie , path ); err != nil {
323+ return parent , & parent .index , path , err
324+ }
325+ return state , & parent .index , path , nil
326+ }
327+ // No more child nodes, move back up.
328+ it .pop ()
329+ }
330+ return nil , nil , nil , errIteratorEnd
331+ }
332+
295333func (st * nodeIteratorState ) resolve (tr * Trie , path []byte ) error {
296334 if hash , ok := st .node .(hashNode ); ok {
297335 resolved , err := tr .resolveHash (hash , path )
@@ -304,25 +342,38 @@ func (st *nodeIteratorState) resolve(tr *Trie, path []byte) error {
304342 return nil
305343}
306344
345+ func findChild (n * fullNode , index int , path []byte , ancestor common.Hash ) (node , * nodeIteratorState , []byte , int ) {
346+ var (
347+ child node
348+ state * nodeIteratorState
349+ childPath []byte
350+ )
351+ for ; index < len (n .Children ); index ++ {
352+ if n .Children [index ] != nil {
353+ child = n .Children [index ]
354+ hash , _ := child .cache ()
355+ state = & nodeIteratorState {
356+ hash : common .BytesToHash (hash ),
357+ node : child ,
358+ parent : ancestor ,
359+ index : - 1 ,
360+ pathlen : len (path ),
361+ }
362+ childPath = append (childPath , path ... )
363+ childPath = append (childPath , byte (index ))
364+ return child , state , childPath , index
365+ }
366+ }
367+ return nil , nil , nil , 0
368+ }
369+
307370func (it * nodeIterator ) nextChild (parent * nodeIteratorState , ancestor common.Hash ) (* nodeIteratorState , []byte , bool ) {
308371 switch node := parent .node .(type ) {
309372 case * fullNode :
310- // Full node, move to the first non-nil child.
311- for i := parent .index + 1 ; i < len (node .Children ); i ++ {
312- child := node .Children [i ]
313- if child != nil {
314- hash , _ := child .cache ()
315- state := & nodeIteratorState {
316- hash : common .BytesToHash (hash ),
317- node : child ,
318- parent : ancestor ,
319- index : - 1 ,
320- pathlen : len (it .path ),
321- }
322- path := append (it .path , byte (i ))
323- parent .index = i - 1
324- return state , path , true
325- }
373+ //Full node, move to the first non-nil child.
374+ if child , state , path , index := findChild (node , parent .index + 1 , it .path , ancestor ); child != nil {
375+ parent .index = index - 1
376+ return state , path , true
326377 }
327378 case * shortNode :
328379 // Short node, return the pointer singleton child
@@ -342,6 +393,52 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
342393 return parent , it .path , false
343394}
344395
396+ // nextChildAt is similar to nextChild, except that it targets a child as close to the
397+ // target key as possible, thus skipping siblings.
398+ func (it * nodeIterator ) nextChildAt (parent * nodeIteratorState , ancestor common.Hash , key []byte ) (* nodeIteratorState , []byte , bool ) {
399+ switch n := parent .node .(type ) {
400+ case * fullNode :
401+ // Full node, move to the first non-nil child before the desired key position
402+ child , state , path , index := findChild (n , parent .index + 1 , it .path , ancestor )
403+ if child == nil {
404+ // No more children in this fullnode
405+ return parent , it .path , false
406+ }
407+ // If the child we found is already past the seek position, just return it.
408+ if bytes .Compare (path , key ) >= 0 {
409+ parent .index = index - 1
410+ return state , path , true
411+ }
412+ // The child is before the seek position. Try advancing
413+ for {
414+ nextChild , nextState , nextPath , nextIndex := findChild (n , index + 1 , it .path , ancestor )
415+ // If we run out of children, or skipped past the target, return the
416+ // previous one
417+ if nextChild == nil || bytes .Compare (nextPath , key ) >= 0 {
418+ parent .index = index - 1
419+ return state , path , true
420+ }
421+ // We found a better child closer to the target
422+ state , path , index = nextState , nextPath , nextIndex
423+ }
424+ case * shortNode :
425+ // Short node, return the pointer singleton child
426+ if parent .index < 0 {
427+ hash , _ := n .Val .cache ()
428+ state := & nodeIteratorState {
429+ hash : common .BytesToHash (hash ),
430+ node : n .Val ,
431+ parent : ancestor ,
432+ index : - 1 ,
433+ pathlen : len (it .path ),
434+ }
435+ path := append (it .path , n .Key ... )
436+ return state , path , true
437+ }
438+ }
439+ return parent , it .path , false
440+ }
441+
345442func (it * nodeIterator ) push (state * nodeIteratorState , parentIndex * int , path []byte ) {
346443 it .path = path
347444 it .stack = append (it .stack , state )
0 commit comments