diff --git a/cmd/integration/commands/stages.go b/cmd/integration/commands/stages.go index 8027683579a..f774ac0d83e 100644 --- a/cmd/integration/commands/stages.go +++ b/cmd/integration/commands/stages.go @@ -1083,7 +1083,7 @@ func newSync(ctx context.Context, db kv.RwDB, miningConfig *params.MiningConfig) cfg.Miner = *miningConfig } - sync, err := stages2.NewStagedSync(context.Background(), logger, db, p2p.Config{}, cfg, downloadServer, tmpdir, txPool, txPoolP2PServer, nil) + sync, err := stages2.NewStagedSync(context.Background(), logger, db, p2p.Config{}, cfg, chainConfig.TerminalTotalDifficulty, downloadServer, tmpdir, txPool, txPoolP2PServer, nil) if err != nil { panic(err) } diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go index 66da8212bf2..6cf8bbacab4 100644 --- a/core/rawdb/accessors_chain.go +++ b/core/rawdb/accessors_chain.go @@ -976,3 +976,34 @@ func WritePendingEpoch(tx kv.RwTx, blockNum uint64, blockHash common.Hash, trans copy(k[8:], blockHash[:]) return tx.Put(kv.PendingEpoch, k, transitionProof) } + +// Transitioned returns true if the block number comes after POS transition +func Transitioned(db kv.Getter, blockNum uint64) (trans bool, err error) { + data, err := db.GetOne(kv.TransitionBlockKey, []byte(kv.TransitionBlockKey)) + if err != nil { + return false, fmt.Errorf("failed ReadTd: %w", err) + } + if len(data) == 0 { + return false, nil + } + return blockNum > binary.BigEndian.Uint64(data), nil +} + +// MarkTreansition sets transition to proof-of-stake from the block number +func MarkTransition(db kv.StatelessRwTx, blockNum uint64) error { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, blockNum) + // If we already transitioned then we do not update the transition + marked, err := db.Has(kv.TransitionBlockKey, []byte(kv.TransitionBlockKey)) + if err != nil { + return err + } + + if marked { + return nil + } + if err := db.Put(kv.TransitionBlockKey, []byte(kv.TransitionBlockKey), data); err != nil { + return fmt.Errorf("failed to store block total difficulty: %w", err) + } + return nil +} diff --git a/core/rawdb/accessors_chain_test.go b/core/rawdb/accessors_chain_test.go index 1394bfa2967..ea7022b14b0 100644 --- a/core/rawdb/accessors_chain_test.go +++ b/core/rawdb/accessors_chain_test.go @@ -416,6 +416,29 @@ func TestBlockReceiptStorage(t *testing.T) { } } +// Tests that transitions is handled correctly +func TestTransition(t *testing.T) { + _, tx := memdb.NewTestTx(t) + require := require.New(t) + transitionBlock := uint64(1000) + + isTrans, err := Transitioned(tx, 1500) + require.NoError(err) + require.False(isTrans) + isTrans, err = Transitioned(tx, 20) + require.NoError(err) + require.False(isTrans) + + require.NoError(MarkTransition(tx, transitionBlock)) + + isTrans, err = Transitioned(tx, 1500) + require.NoError(err) + require.True(isTrans) + isTrans, err = Transitioned(tx, 20) + require.NoError(err) + require.False(isTrans) +} + func checkReceiptsRLP(have, want types.Receipts) error { if len(have) != len(want) { return fmt.Errorf("receipts sizes mismatch: have %d, want %d", len(have), len(want)) diff --git a/eth/backend.go b/eth/backend.go index 67bb28e2be6..c930cd8bde2 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -463,7 +463,7 @@ func New(stack *node.Node, config *ethconfig.Config, logger log.Logger) (*Ethere return nil, err } - backend.stagedSync, err = stages2.NewStagedSync(backend.downloadCtx, backend.logger, backend.chainDB, stack.Config().P2P, *config, backend.downloadServer, tmpdir, backend.txPool, backend.txPoolP2PServer, backend.notifications.Accumulator) + backend.stagedSync, err = stages2.NewStagedSync(backend.downloadCtx, backend.logger, backend.chainDB, stack.Config().P2P, *config, chainConfig.TerminalTotalDifficulty, backend.downloadServer, tmpdir, backend.txPool, backend.txPoolP2PServer, backend.notifications.Accumulator) if err != nil { return nil, err } diff --git a/eth/stagedsync/default_stages.go b/eth/stagedsync/default_stages.go index 7f1816c0dbf..435a61eb913 100644 --- a/eth/stagedsync/default_stages.go +++ b/eth/stagedsync/default_stages.go @@ -13,6 +13,7 @@ func DefaultStages(ctx context.Context, headers HeadersCfg, blockHashCfg BlockHashesCfg, bodies BodiesCfg, + difficulty DifficultyCfg, senders SendersCfg, exec ExecuteBlockCfg, trans TranspileCfg, @@ -69,6 +70,19 @@ func DefaultStages(ctx context.Context, return PruneBodiesStage(p, tx, bodies, ctx) }, }, + { + ID: stages.TotalDifficulty, + Description: "Compute total difficulty", + Forward: func(firstCycle bool, badBlockUnwind bool, s *StageState, u Unwinder, tx kv.RwTx) error { + return SpawnDifficultyStage(s, tx, difficulty, ctx) + }, + Unwind: func(firstCycle bool, u *UnwindState, s *StageState, tx kv.RwTx) error { + return UnwindDifficultyStage(u, tx, ctx) + }, + Prune: func(firstCycle bool, p *PruneState, tx kv.RwTx) error { + return PruneDifficultyStage(p, tx, ctx) + }, + }, { ID: stages.Senders, Description: "Recover senders from tx signatures", diff --git a/eth/stagedsync/stage_difficulty.go b/eth/stagedsync/stage_difficulty.go new file mode 100644 index 00000000000..ac466b4b466 --- /dev/null +++ b/eth/stagedsync/stage_difficulty.go @@ -0,0 +1,155 @@ +package stagedsync + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "math/big" + + "github.com/ledgerwatch/erigon-lib/etl" + "github.com/ledgerwatch/erigon-lib/kv" + "github.com/ledgerwatch/erigon/core/rawdb" + "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/eth/stagedsync/stages" + "github.com/ledgerwatch/erigon/rlp" +) + +type DifficultyCfg struct { + tmpDir string + terminalTotalDifficulty *big.Int + db kv.RwDB +} + +func StageDifficultyCfg(db kv.RwDB, tmpDir string, terminalTotalDifficulty *big.Int) DifficultyCfg { + return DifficultyCfg{ + db: db, + tmpDir: tmpDir, + terminalTotalDifficulty: terminalTotalDifficulty, + } +} + +func SpawnDifficultyStage(s *StageState, tx kv.RwTx, cfg DifficultyCfg, ctx context.Context) (err error) { + useExternalTx := tx != nil + + if !useExternalTx { + var err error + tx, err = cfg.db.BeginRw(context.Background()) + if err != nil { + return err + } + defer tx.Rollback() + } + + quit := ctx.Done() + headNumber, err := stages.GetStageProgress(tx, stages.Headers) + if err != nil { + return fmt.Errorf("getting headers progress: %w", err) + } + + td := big.NewInt(0) + if s.BlockNumber > 0 { + td, err = rawdb.ReadTd(tx, rawdb.ReadHeaderByNumber(tx, s.BlockNumber).Hash(), s.BlockNumber) + if err != nil { + return err + } + } + // If the chain does not have a proof of stake config or has reached terminalTotalDifficulty then we can skip this stage + if cfg.terminalTotalDifficulty == nil || td.Cmp(cfg.terminalTotalDifficulty) >= 0 { + if err = s.Update(tx, headNumber); err != nil { + return err + } + if !useExternalTx { + if err = tx.Commit(); err != nil { + return err + } + } + return nil + } + + startKey := make([]byte, 8) + binary.BigEndian.PutUint64(startKey, s.BlockNumber) + + header := new(types.Header) + if err := etl.Transform( + s.LogPrefix(), + tx, + kv.Headers, + kv.HeaderTD, + cfg.tmpDir, + func(k []byte, v []byte, next etl.ExtractNextFunc) error { + if len(k) != 40 { + return nil + } + + blockNum := binary.BigEndian.Uint64(k) + canonical, err := rawdb.ReadCanonicalHash(tx, blockNum) + if err != nil { + return err + } + + if !bytes.Equal(k[8:], canonical[:]) { + return nil + } + if err := rlp.Decode(bytes.NewReader(v), header); err != nil { + return err + } + + td.Add(td, header.Difficulty) + + if header.Eip3675 { + return nil + } + + if td.Cmp(cfg.terminalTotalDifficulty) > 0 { + return rawdb.MarkTransition(tx, blockNum) + } + data, err := rlp.EncodeToBytes(td) + if err != nil { + return fmt.Errorf("failed to RLP encode block total difficulty: %w", err) + } + return next(k, k, data) + }, + etl.IdentityLoadFunc, + etl.TransformArgs{ + ExtractStartKey: startKey, + Quit: quit, + }, + ); err != nil { + return err + } + if err = s.Update(tx, headNumber); err != nil { + return err + } + if !useExternalTx { + if err = tx.Commit(); err != nil { + return err + } + } + return nil +} + +func UnwindDifficultyStage(u *UnwindState, tx kv.RwTx, ctx context.Context) (err error) { + useExternalTx := tx != nil + + if err = u.Done(tx); err != nil { + return fmt.Errorf(" reset: %w", err) + } + if !useExternalTx { + if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to write db commit: %w", err) + } + } + return nil +} + +func PruneDifficultyStage(p *PruneState, tx kv.RwTx, ctx context.Context) (err error) { + useExternalTx := tx != nil + + if !useExternalTx { + if err = tx.Commit(); err != nil { + return err + } + } + return nil +} diff --git a/eth/stagedsync/stage_difficulty_test.go b/eth/stagedsync/stage_difficulty_test.go new file mode 100644 index 00000000000..7fa22a1fe53 --- /dev/null +++ b/eth/stagedsync/stage_difficulty_test.go @@ -0,0 +1,245 @@ +package stagedsync + +import ( + "context" + "math/big" + "testing" + + "github.com/ledgerwatch/erigon-lib/kv/memdb" + "github.com/ledgerwatch/erigon/core/rawdb" + "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/eth/stagedsync/stages" + "github.com/stretchr/testify/assert" +) + +func TestDifficultyComputation(t *testing.T) { + // We need a Database with the following requirements: + // 3 Headers + // 3 Canonical Hashes + ctx, assert := context.Background(), assert.New(t) + db := memdb.New() + tx, _ := db.BeginRw(ctx) + // Create the 3 headers, body is irrelevant we just need to have difficulty + var header1, header2, header3 types.Header + // First header + header1.Difficulty = big.NewInt(10) + header1.Number = big.NewInt(1) + // Second Header + header2.Difficulty = big.NewInt(30) + header2.Number = big.NewInt(2) + // Third Header + header3.Difficulty = big.NewInt(314) + header3.Number = big.NewInt(3) + // Insert the headers into the db + rawdb.WriteHeader(tx, &header1) + rawdb.WriteHeader(tx, &header2) + rawdb.WriteHeader(tx, &header3) + // Canonical hashes + rawdb.WriteCanonicalHash(tx, header1.Hash(), header1.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header2.Hash(), header2.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header3.Hash(), header3.Number.Uint64()) + // save progress for headers + _ = stages.SaveStageProgress(tx, stages.Headers, 3) + // Code + err := SpawnDifficultyStage(&StageState{BlockNumber: 0, ID: stages.TotalDifficulty}, tx, StageDifficultyCfg(db, "", big.NewInt(1000)), ctx) + assert.NoError(err) + // Asserts + actual_td, err := rawdb.ReadTd(tx, header1.Hash(), header1.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(10), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header2.Hash(), header2.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(40), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header3.Hash(), header3.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(354), actual_td.Uint64(), "Wrong total difficulty") +} + +func TestDifficultyComputationNonCanonical(t *testing.T) { + // We need a Database with the following requirements: + // 3 Headers + // 3 Canonical Hashes + ctx, assert := context.Background(), assert.New(t) + db := memdb.New() + tx, _ := db.BeginRw(ctx) + + // Create the 3 headers, body is irrelevant we just need to have difficulty + var header1, header2, noncanonicalHeader2, header3 types.Header + // First header + header1.Difficulty = big.NewInt(10) + header1.Number = big.NewInt(1) + // Second Header + header2.Difficulty = big.NewInt(30) + header2.Number = big.NewInt(2) + noncanonicalHeader2.Difficulty = big.NewInt(50) + noncanonicalHeader2.Number = big.NewInt(2) + // Third Header + header3.Difficulty = big.NewInt(314) + header3.Number = big.NewInt(3) + // Insert the headers into the db + rawdb.WriteHeader(tx, &header1) + rawdb.WriteHeader(tx, &header2) + rawdb.WriteHeader(tx, &noncanonicalHeader2) + rawdb.WriteHeader(tx, &header3) + // Canonical hashes + rawdb.WriteCanonicalHash(tx, header1.Hash(), header1.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header2.Hash(), header2.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header3.Hash(), header3.Number.Uint64()) + // save progress for headers + _ = stages.SaveStageProgress(tx, stages.Headers, 3) + // Code + err := SpawnDifficultyStage(&StageState{BlockNumber: 0, ID: stages.TotalDifficulty}, tx, StageDifficultyCfg(db, "", big.NewInt(1000)), ctx) + assert.NoError(err) + // Asserts + actual_td, err := rawdb.ReadTd(tx, header1.Hash(), header1.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(10), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header2.Hash(), header2.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(40), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header3.Hash(), header3.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(354), actual_td.Uint64(), "Wrong total difficulty") +} + +func TestDifficultyProgress(t *testing.T) { + // We need a Database with the following requirements: + // 3 Headers + // 3 Canonical Hashes + ctx, assert := context.Background(), assert.New(t) + db := memdb.New() + tx, _ := db.BeginRw(ctx) + // Create the 3 headers, body is irrelevant we just need to have difficulty + var header1, header2, noncanonicalHeader2, header3 types.Header + // First header + header1.Difficulty = big.NewInt(10) + header1.Number = big.NewInt(1) + // Second Header + header2.Difficulty = big.NewInt(30) + header2.Number = big.NewInt(2) + noncanonicalHeader2.Difficulty = big.NewInt(50) + noncanonicalHeader2.Number = big.NewInt(2) + // Third Header + header3.Difficulty = big.NewInt(314) + header3.Number = big.NewInt(3) + // Insert the headers into the db + rawdb.WriteHeader(tx, &header1) + _ = rawdb.WriteTd(tx, header1.Hash(), header1.Number.Uint64(), big.NewInt(10)) + rawdb.WriteHeader(tx, &header2) + rawdb.WriteHeader(tx, &noncanonicalHeader2) + rawdb.WriteHeader(tx, &header3) + // Canonical hashes + rawdb.WriteCanonicalHash(tx, header1.Hash(), header1.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header2.Hash(), header2.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header3.Hash(), header3.Number.Uint64()) + // save progress for headers + _ = stages.SaveStageProgress(tx, stages.Headers, 3) + _ = stages.SaveStageProgress(tx, stages.TotalDifficulty, 1) + // Code + err := SpawnDifficultyStage(&StageState{BlockNumber: 0, ID: stages.TotalDifficulty}, tx, StageDifficultyCfg(db, "", big.NewInt(1000)), ctx) + assert.NoError(err) + // Asserts + actual_td, err := rawdb.ReadTd(tx, header1.Hash(), header1.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(10), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header2.Hash(), header2.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(40), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header3.Hash(), header3.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(354), actual_td.Uint64(), "Wrong total difficulty") +} + +// If we do not have proof-of-stake config then the stage does nothing at all. +func TestDifficultyNoTerminalDifficulty(t *testing.T) { + // We need a Database with the following requirements: + // 3 Headers + // 3 Canonical Hashes + ctx, assert := context.Background(), assert.New(t) + db := memdb.New() + tx, _ := db.BeginRw(ctx) + // Create the 3 headers, body is irrelevant we just need to have difficulty + var header1, header2, header3 types.Header + // First header + header1.Difficulty = big.NewInt(10) + header1.Number = big.NewInt(1) + // Second Header + header2.Difficulty = big.NewInt(30) + header2.Number = big.NewInt(2) + // Third Header + header3.Difficulty = big.NewInt(314) + header3.Number = big.NewInt(3) + // Insert the headers into the db + rawdb.WriteHeader(tx, &header1) + rawdb.WriteHeader(tx, &header2) + rawdb.WriteHeader(tx, &header3) + // Canonical hashes + rawdb.WriteCanonicalHash(tx, header1.Hash(), header1.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header2.Hash(), header2.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header3.Hash(), header3.Number.Uint64()) + // Code + err := SpawnDifficultyStage(&StageState{BlockNumber: 0, ID: stages.TotalDifficulty}, tx, StageDifficultyCfg(db, "", nil), ctx) + assert.NoError(err) + // Asserts + actual_td, err := rawdb.ReadTd(tx, header1.Hash(), header1.Number.Uint64()) + assert.NoError(err) + assert.True(actual_td == nil, "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header2.Hash(), header2.Number.Uint64()) + assert.NoError(err) + assert.True(actual_td == nil, "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header3.Hash(), header3.Number.Uint64()) + assert.NoError(err) + assert.True(actual_td == nil, "Wrong total difficulty") +} + +// We do not need to compute total difficulty after terminal difficulty. +func TestDifficultyGreaterThanTerminalDifficulty(t *testing.T) { + // We need a Database with the following requirements: + // 3 Headers + // 3 Canonical Hashes + ctx, assert := context.Background(), assert.New(t) + db := memdb.New() + tx, _ := db.BeginRw(ctx) + // Create the 3 headers, body is irrelevant we just need to have difficulty + var header1, header2, header3 types.Header + // First header + header1.Difficulty = big.NewInt(10) + header1.Number = big.NewInt(1) + // Second Header + header2.Difficulty = big.NewInt(990) + header2.Number = big.NewInt(2) + // Third Header + header3.Difficulty = big.NewInt(314) + header3.Number = big.NewInt(3) + // Insert the headers into the db + rawdb.WriteHeader(tx, &header1) + rawdb.WriteHeader(tx, &header2) + rawdb.WriteHeader(tx, &header3) + // Canonical hashes + rawdb.WriteCanonicalHash(tx, header1.Hash(), header1.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header2.Hash(), header2.Number.Uint64()) + rawdb.WriteCanonicalHash(tx, header3.Hash(), header3.Number.Uint64()) + // Code + err := SpawnDifficultyStage(&StageState{BlockNumber: 0, ID: stages.TotalDifficulty}, tx, StageDifficultyCfg(db, "", big.NewInt(1000)), ctx) + assert.NoError(err) + // Asserts + actual_td, err := rawdb.ReadTd(tx, header1.Hash(), header1.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(10), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header2.Hash(), header2.Number.Uint64()) + assert.NoError(err) + assert.Equalf(uint64(1000), actual_td.Uint64(), "Wrong total difficulty") + + actual_td, err = rawdb.ReadTd(tx, header3.Hash(), header3.Number.Uint64()) + assert.NoError(err) + assert.True(actual_td == nil, "Wrong total difficulty") +} diff --git a/eth/stagedsync/stage_headers.go b/eth/stagedsync/stage_headers.go index 35e4cefa6da..4e5b25ed6d6 100644 --- a/eth/stagedsync/stage_headers.go +++ b/eth/stagedsync/stage_headers.go @@ -122,6 +122,15 @@ func HeadersForward( prevProgress := headerProgress Loop: for !stopped { + + isTrans, err := rawdb.Transitioned(tx, headerProgress) + if err != nil { + return err + } + + if isTrans { + break + } currentTime := uint64(time.Now().Unix()) req, penalties := cfg.hd.RequestMoreHeaders(currentTime) if req != nil { @@ -156,9 +165,11 @@ Loop: } // Load headers into the database var inSync bool - if inSync, err = cfg.hd.InsertHeaders(headerInserter.FeedHeaderFunc(tx), logPrefix, logEvery.C); err != nil { + + if inSync, err = cfg.hd.InsertHeaders(headerInserter.FeedHeaderFunc(tx), cfg.chainConfig.TerminalTotalDifficulty, logPrefix, logEvery.C); err != nil { return err } + announces := cfg.hd.GrabAnnounces() if len(announces) > 0 { cfg.announceNewHashes(ctx, announces) @@ -210,6 +221,7 @@ Loop: } // We do not print the following line if the stage was interrupted log.Info(fmt.Sprintf("[%s] Processed", logPrefix), "highest inserted", headerInserter.GetHighest(), "age", common.PrettyAge(time.Unix(int64(headerInserter.GetHighestTimestamp()), 0))) + return nil } @@ -292,6 +304,17 @@ func HeadersUnwind(u *UnwindState, s *StageState, tx kv.RwTx, cfg HeadersCfg, te var maxTd big.Int var maxHash common.Hash var maxNum uint64 = 0 + // unwind the merge + isTrans, err := rawdb.Transitioned(tx, u.UnwindPoint) + if err != nil { + return err + } + + if cfg.chainConfig.TerminalTotalDifficulty != nil && !isTrans { + if err := tx.Delete(kv.TransitionBlockKey, []byte(kv.TransitionBlockKey), nil); err != nil { + return err + } + } if test { // If we are not in the test, we can do searching for the heaviest chain in the next cycle // Find header with biggest TD tdCursor, cErr := tx.Cursor(kv.HeaderTD) @@ -342,6 +365,10 @@ func HeadersUnwind(u *UnwindState, s *StageState, tx kv.RwTx, cfg HeadersCfg, te if err = s.Update(tx, maxNum); err != nil { return err } + // When we forward sync, total difficulty is updated within headers processing + if err = stages.SaveStageProgress(tx, stages.TotalDifficulty, maxNum); err != nil { + return err + } } if !useExternalTx { if err := tx.Commit(); err != nil { diff --git a/eth/stagedsync/stages/stages.go b/eth/stagedsync/stages/stages.go index 0e4616c8841..3f4a91c895a 100644 --- a/eth/stagedsync/stages/stages.go +++ b/eth/stagedsync/stages/stages.go @@ -32,6 +32,7 @@ var ( Headers SyncStage = "Headers" // Headers are downloaded, their Proof-Of-Work validity and chaining is verified BlockHashes SyncStage = "BlockHashes" // Headers Number are written, fills blockHash => number bucket Bodies SyncStage = "Bodies" // Block bodies are downloaded, TxHash and UncleHash are getting verified + TotalDifficulty SyncStage = "TotalDifficulty" // TotalDifficulty for each block is calculated. Senders SyncStage = "Senders" // "From" recovered from signatures, bodies re-written Execution SyncStage = "Execution" // Executing each block w/o buildinf a trie Translation SyncStage = "Translation" // Translation each marked for translation contract (from EVM to TEVM) diff --git a/go.mod b/go.mod index 658056304d5..30b9ca1f00b 100644 --- a/go.mod +++ b/go.mod @@ -64,4 +64,4 @@ require ( gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 gopkg.in/olebedev/go-duktape.v3 v3.0.0-20200619000410-60c24ae608a6 pgregory.net/rapid v0.4.7 -) +) \ No newline at end of file diff --git a/turbo/stages/headerdownload/header_algo_test.go b/turbo/stages/headerdownload/header_algo_test.go index 0f97b0e4d3f..d28448f99ec 100644 --- a/turbo/stages/headerdownload/header_algo_test.go +++ b/turbo/stages/headerdownload/header_algo_test.go @@ -48,10 +48,10 @@ func TestInserter1(t *testing.T) { ParentHash: h1Hash, } h2Hash := h2.Hash() - if err = hi.FeedHeader(tx, &h1, h1Hash, 1); err != nil { + if err = hi.FeedHeader(tx, &h1, h1Hash, 1, nil); err != nil { t.Errorf("feed empty header 1: %v", err) } - if err = hi.FeedHeader(tx, &h2, h2Hash, 2); err != nil { + if err = hi.FeedHeader(tx, &h2, h2Hash, 2, nil); err != nil { t.Errorf("feed empty header 2: %v", err) } } diff --git a/turbo/stages/headerdownload/header_algos.go b/turbo/stages/headerdownload/header_algos.go index f65a2a1f945..3dcbdb8276e 100644 --- a/turbo/stages/headerdownload/header_algos.go +++ b/turbo/stages/headerdownload/header_algos.go @@ -591,7 +591,7 @@ func (hd *HeaderDownload) RequestSkeleton() *HeaderRequest { // InsertHeaders attempts to insert headers into the database, verifying them first // It returns true in the first return value if the system is "in sync" -func (hd *HeaderDownload) InsertHeaders(hf func(header *types.Header, hash common.Hash, blockHeight uint64) error, logPrefix string, logChannel <-chan time.Time) (bool, error) { +func (hd *HeaderDownload) InsertHeaders(hf func(header *types.Header, hash common.Hash, blockHeight uint64, terminalTotalDifficulty *big.Int) error, terminalTotalDifficulty *big.Int, logPrefix string, logChannel <-chan time.Time) (bool, error) { hd.lock.Lock() defer hd.lock.Unlock() var linksInFuture []*Link // Here we accumulate links that fail validation as "in the future" @@ -635,7 +635,9 @@ func (hd *HeaderDownload) InsertHeaders(hf func(header *types.Header, hash commo delete(hd.links, link.hash) continue } - if err := hf(link.header, link.hash, link.blockHeight); err != nil { + + // Check if transition to proof-of-stake happened + if err := hf(link.header, link.hash, link.blockHeight, terminalTotalDifficulty); err != nil { return false, err } if link.blockHeight > hd.highestInDb { @@ -717,18 +719,18 @@ func (hd *HeaderDownload) addHeaderAsLink(header *types.Header, persisted bool) return link } -func (hi *HeaderInserter) FeedHeaderFunc(db kv.StatelessRwTx) func(header *types.Header, hash common.Hash, blockHeight uint64) error { - return func(header *types.Header, hash common.Hash, blockHeight uint64) error { - return hi.FeedHeader(db, header, hash, blockHeight) +func (hi *HeaderInserter) FeedHeaderFunc(db kv.StatelessRwTx) func(header *types.Header, hash common.Hash, blockHeight uint64, terminalTotalDifficulty *big.Int) error { + return func(header *types.Header, hash common.Hash, blockHeight uint64, terminalTotalDifficulty *big.Int) error { + return hi.FeedHeader(db, header, hash, blockHeight, terminalTotalDifficulty) } - } -func (hi *HeaderInserter) FeedHeader(db kv.StatelessRwTx, header *types.Header, hash common.Hash, blockHeight uint64) error { +func (hi *HeaderInserter) FeedHeader(db kv.StatelessRwTx, header *types.Header, hash common.Hash, blockHeight uint64, terminalTotalDifficulty *big.Int) error { if hash == hi.prevHash { // Skip duplicates return nil } + if oldH := rawdb.ReadHeader(db, hash, blockHeight); oldH != nil { // Already inserted, skip return nil @@ -814,9 +816,16 @@ func (hi *HeaderInserter) FeedHeader(db kv.StatelessRwTx, header *types.Header, if err = rawdb.WriteTd(db, hash, blockHeight, td); err != nil { return fmt.Errorf("[%s] failed to WriteTd: %w", hi.logPrefix, err) } + if err = db.Put(kv.Headers, dbutils.HeaderKey(blockHeight, hash), data); err != nil { return fmt.Errorf("[%s] failed to store header: %w", hi.logPrefix, err) } + + if terminalTotalDifficulty != nil && td.Cmp(terminalTotalDifficulty) >= 0 { + if err = rawdb.MarkTransition(db, blockHeight); err != nil { + return err + } + } hi.prevHash = hash return nil } diff --git a/turbo/stages/mock_sentry.go b/turbo/stages/mock_sentry.go index fb6ec0f9125..36b5d6847e9 100644 --- a/turbo/stages/mock_sentry.go +++ b/turbo/stages/mock_sentry.go @@ -286,7 +286,7 @@ func MockWithEverything(t *testing.T, gspec *core.Genesis, key *ecdsa.PrivateKey cfg.BodyDownloadTimeoutSeconds, *mock.ChainConfig, cfg.BatchSize, - ), stagedsync.StageSendersCfg(mock.DB, mock.ChainConfig, mock.tmpdir, prune), stagedsync.StageExecuteBlocksCfg( + ), stagedsync.StageDifficultyCfg(mock.DB, mock.tmpdir, nil), stagedsync.StageSendersCfg(mock.DB, mock.ChainConfig, mock.tmpdir, prune), stagedsync.StageExecuteBlocksCfg( mock.DB, prune, cfg.BatchSize, diff --git a/turbo/stages/stageloop.go b/turbo/stages/stageloop.go index 84a340cdcdf..7448ff24865 100644 --- a/turbo/stages/stageloop.go +++ b/turbo/stages/stageloop.go @@ -220,6 +220,7 @@ func NewStagedSync( db kv.RwDB, p2pCfg p2p.Config, cfg ethconfig.Config, + terminalTotalDifficulty *big.Int, controlServer *download.ControlServerImpl, tmpdir string, txPool *core.TxPool, @@ -245,7 +246,7 @@ func NewStagedSync( cfg.BodyDownloadTimeoutSeconds, *controlServer.ChainConfig, cfg.BatchSize, - ), stagedsync.StageSendersCfg(db, controlServer.ChainConfig, tmpdir, cfg.Prune), stagedsync.StageExecuteBlocksCfg( + ), stagedsync.StageDifficultyCfg(db, tmpdir, terminalTotalDifficulty), stagedsync.StageSendersCfg(db, controlServer.ChainConfig, tmpdir, cfg.Prune), stagedsync.StageExecuteBlocksCfg( db, cfg.Prune, cfg.BatchSize,