@@ -21,9 +21,11 @@ import (
2121 "encoding/binary"
2222 "errors"
2323 "fmt"
24+ "maps"
2425 "math"
2526 "math/rand"
2627 "reflect"
28+ "slices"
2729 "strings"
2830 "sync"
2931 "testing"
@@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
557559 if err != nil {
558560 return err
559561 }
560- it := trie .NewIterator (trieIt )
562+ var (
563+ it = trie .NewIterator (trieIt )
564+ visited = make (map [common.Hash ]bool )
565+ )
561566
562567 for it .Next () {
563568 key := common .BytesToHash (s .trie .GetKey (it .Key ))
569+ visited [key ] = true
564570 if value , dirty := so .dirtyStorage [key ]; dirty {
565571 if ! cb (key , value ) {
566572 return nil
@@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
600606 checkeq ("GetCode" , state .GetCode (addr ), checkstate .GetCode (addr ))
601607 checkeq ("GetCodeHash" , state .GetCodeHash (addr ), checkstate .GetCodeHash (addr ))
602608 checkeq ("GetCodeSize" , state .GetCodeSize (addr ), checkstate .GetCodeSize (addr ))
609+ // Check newContract-flag
610+ if obj := state .getStateObject (addr ); obj != nil {
611+ checkeq ("IsNewContract" , obj .newContract , checkstate .getStateObject (addr ).newContract )
612+ }
603613 // Check storage.
604614 if obj := state .getStateObject (addr ); obj != nil {
605615 forEachStorage (state , addr , func (key , value common.Hash ) bool {
@@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
608618 forEachStorage (checkstate , addr , func (key , value common.Hash ) bool {
609619 return checkeq ("GetState(" + key .Hex ()+ ")" , checkstate .GetState (addr , key ), value )
610620 })
621+ other := checkstate .getStateObject (addr )
622+ // Check dirty storage which is not in trie
623+ if ! maps .Equal (obj .dirtyStorage , other .dirtyStorage ) {
624+ print := func (dirty map [common.Hash ]common.Hash ) string {
625+ var keys []common.Hash
626+ out := new (strings.Builder )
627+ for key := range dirty {
628+ keys = append (keys , key )
629+ }
630+ slices .SortFunc (keys , common .Hash .Cmp )
631+ for i , key := range keys {
632+ fmt .Fprintf (out , " %d. %v %v\n " , i , key , dirty [key ])
633+ }
634+ return out .String ()
635+ }
636+ return fmt .Errorf ("dirty storage err, have\n %v\n want\n %v" ,
637+ print (obj .dirtyStorage ),
638+ print (other .dirtyStorage ))
639+ }
640+ }
641+ // Check transient storage.
642+ {
643+ have := state .transientStorage
644+ want := checkstate .transientStorage
645+ eq := maps .EqualFunc (have , want ,
646+ func (a Storage , b Storage ) bool {
647+ return maps .Equal (a , b )
648+ })
649+ if ! eq {
650+ return fmt .Errorf ("transient storage differs ,have\n %v\n want\n %v" ,
651+ have .PrettyPrint (),
652+ want .PrettyPrint ())
653+ }
611654 }
612655 if err != nil {
613656 return err
614657 }
615658 }
616-
659+ if ! checkstate .accessList .Equal (state .accessList ) { // Check access lists
660+ return fmt .Errorf ("AccessLists are wrong, have \n %v\n want\n %v" ,
661+ checkstate .accessList .PrettyPrint (),
662+ state .accessList .PrettyPrint ())
663+ }
617664 if state .GetRefund () != checkstate .GetRefund () {
618665 return fmt .Errorf ("got GetRefund() == %d, want GetRefund() == %d" ,
619666 state .GetRefund (), checkstate .GetRefund ())
@@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
622669 return fmt .Errorf ("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v" ,
623670 state .GetLogs (common.Hash {}, 0 , common.Hash {}), checkstate .GetLogs (common.Hash {}, 0 , common.Hash {}))
624671 }
672+ if ! maps .Equal (state .journal .dirties , checkstate .journal .dirties ) {
673+ getKeys := func (dirty map [common.Address ]int ) string {
674+ var keys []common.Address
675+ out := new (strings.Builder )
676+ for key := range dirty {
677+ keys = append (keys , key )
678+ }
679+ slices .SortFunc (keys , common .Address .Cmp )
680+ for i , key := range keys {
681+ fmt .Fprintf (out , " %d. %v\n " , i , key )
682+ }
683+ return out .String ()
684+ }
685+ have := getKeys (state .journal .dirties )
686+ want := getKeys (checkstate .journal .dirties )
687+ return fmt .Errorf ("dirty-journal set mismatch.\n have:\n %v\n want:\n %v\n " , have , want )
688+ }
625689 return nil
626690}
627691
0 commit comments