diff --git a/eth/api_trace.go b/eth/api_trace.go new file mode 100644 index 000000000000..2f27b9166400 --- /dev/null +++ b/eth/api_trace.go @@ -0,0 +1,46 @@ +package eth + +import ( + "context" + "encoding/json" + + "github.com/ethereum/go-ethereum/eth/tracers" + "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/rpc" +) + +type TraceAPI struct { + backend *EthAPIBackend + tracerAPI *tracers.API +} + +func NewTraceAPI(b *EthAPIBackend) *TraceAPI { + return &TraceAPI{ + backend: b, + tracerAPI: tracers.NewAPI(b), + } +} + +// CallMany simulate a series of transactions in latest block +func (api *TraceAPI) CallMany(ctx context.Context, txs []ethapi.TransactionArgs) (map[string]interface{}, error) { + // get latest block number + latestBlockNumOrHash := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) + // prepare stateDiff tracer + tracerName := "stateDiffTracer" + config := tracers.TraceCallConfig{ + TraceConfig: tracers.TraceConfig{ + Tracer: &tracerName, + TracerConfig: json.RawMessage("{\"onlyTopCall\": false, \"withLog\": false}"), + }, + } + // trace + traceResult, err := api.tracerAPI.TraceCallMany(ctx, txs, latestBlockNumOrHash, &config) + if err != nil { + return nil, err + } + result := map[string]interface{}{ + "blockNumber": latestBlockNumOrHash.BlockNumber.String(), + "traceResult": traceResult, + } + return result, nil +} diff --git a/eth/backend.go b/eth/backend.go index cc555f0c23fd..1d27dcc8f771 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -316,6 +316,9 @@ func (s *Ethereum) APIs() []rpc.API { }, { Namespace: "net", Service: s.netRPCService, + }, { + Namespace: "trace", + Service: NewTraceAPI(s.APIBackend), }, }...) } diff --git a/eth/tracers/api.go b/eth/tracers/api.go index 740a38ab9fbf..805cb4b88885 100644 --- a/eth/tracers/api.go +++ b/eth/tracers/api.go @@ -861,6 +861,65 @@ func (api *API) TraceTransaction(ctx context.Context, hash common.Hash, config * return api.traceTx(ctx, msg, txctx, vmctx, statedb, config) } +func (api *API) TraceCallMany(ctx context.Context, args []ethapi.TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, config *TraceCallConfig) ([]interface{}, error) { + // Try to retrieve the specified block + var ( + err error + block *types.Block + ) + if hash, ok := blockNrOrHash.Hash(); ok { + block, err = api.blockByHash(ctx, hash) + } else if number, ok := blockNrOrHash.Number(); ok { + if number == rpc.PendingBlockNumber { + return nil, errors.New("tracing on top of pending is not supported") + } + block, err = api.blockByNumber(ctx, number) + } else { + return nil, errors.New("invalid arguments; neither block nor hash specified") + } + if err != nil { + return nil, err + } + // try to recompute the state + reexec := defaultTraceReexec + if config != nil && config.Reexec != nil { + reexec = *config.Reexec + } + statedb, release, err := api.backend.StateAtBlock(ctx, block, reexec, nil, true, false) + if err != nil { + return nil, err + } + defer release() + vmctx := core.NewEVMBlockContext(block.Header(), api.chainContext(ctx), nil) + // Apply the customization rules if required. + if config != nil { + if err := config.StateOverrides.Apply(statedb); err != nil { + return nil, err + } + config.BlockOverrides.Apply(&vmctx) + } + var traceConfig *TraceConfig + if config != nil { + traceConfig = &config.TraceConfig + } + + // loop over all the transactions and trace internal calls + result := []interface{}{} + for _, arg := range args { + msg, err := arg.ToMessage(api.backend.RPCGasCap(), block.BaseFee()) + if err != nil { + return nil, err + } + res, err := api.traceTx(ctx, msg, new(Context), vmctx, statedb, traceConfig) + if err != nil { + return nil, err + } + // append all results + result = append(result, res) + } + return result, nil +} + // TraceCall lets you trace a given eth_call. It collects the structured logs // created during the execution of EVM if the given transaction was added on // top of the provided block and returns them as a JSON object. @@ -912,7 +971,6 @@ func (api *API) TraceCall(ctx context.Context, args ethapi.TransactionArgs, bloc if err != nil { return nil, err } - var traceConfig *TraceConfig if config != nil { traceConfig = &config.TraceConfig diff --git a/eth/tracers/native/statediff.go b/eth/tracers/native/statediff.go new file mode 100644 index 000000000000..19137997ecf5 --- /dev/null +++ b/eth/tracers/native/statediff.go @@ -0,0 +1,290 @@ +package native + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/eth/tracers" +) + +func init() { + tracers.DefaultDirectory.Register("stateDiffTracer", newStateTracer, false) +} + +type diff[T any] struct { + before T + after T +} + +type accountDiff struct { + balanceDelta *big.Int + nonceDelta int + storage map[common.Hash]diff[common.Hash] + code diff[[]byte] +} + +// StateDiffLogger implements Tracer interface +type StateDiffTracer struct { + accounts map[common.Address]accountDiff + env *vm.EVM + tracer *callTracer +} + +func newStateTracer(ctx *tracers.Context, cfg json.RawMessage) (tracers.Tracer, error) { + t, err := newCallTracer(ctx, cfg) + if err != nil { + return nil, err + } + return &StateDiffTracer{ + tracer: t.(*callTracer), + accounts: make(map[common.Address]accountDiff), + }, nil +} + +func (l *StateDiffTracer) CaptureTxStart(gasLimit uint64) { + l.tracer.CaptureTxStart(gasLimit) +} + +func (l *StateDiffTracer) CaptureTxEnd(restGas uint64) { + l.tracer.CaptureTxEnd(restGas) + callFrame := l.tracer.callstack[0] + caller := callFrame.From + used := callFrame.GasUsed + // record gas used here instead of capture whenever gas is used, because need to consider intrinsic gas + l.recordBalanceChange(caller, big.NewInt(-int64(used))) + // additional nonce increment when first call is not CREATE + if callFrame.Type != vm.CREATE { + l.recordNonceIncrese(caller) + } +} + +func (l *StateDiffTracer) CaptureStart(env *vm.EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { + l.env = env + l.tracer.CaptureStart(env, from, to, create, input, gas, value) + if create { + // record noce increment + l.recordNonceIncrese(from) + } +} + +func (l *StateDiffTracer) CaptureEnd(output []byte, gasUsed uint64, err error) { + l.tracer.CaptureEnd(output, gasUsed, err) + callframe := l.tracer.callstack[0] + // Note: do not record gasUsed here. All gas used value is recorded in TxEnd + + opType := callframe.Type + switch opType { + case vm.CREATE, vm.CREATE2, vm.CALL: + if opType == vm.CREATE || opType == vm.CREATE2 { + // record the code + contract := *callframe.To + l.recordCode(contract, l.env.StateDB.GetCode(contract)) + } + // ether transfer + value := callframe.Value + if value != nil { + from := callframe.From + to := *callframe.To + l.recordBalanceChange(from, big.NewInt(0).Neg(value)) + l.recordBalanceChange(to, value) + } + } +} + +func (l *StateDiffTracer) CaptureEnter(typ vm.OpCode, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { + l.tracer.CaptureEnter(typ, from, to, input, gas, value) + if typ == vm.CREATE || typ == vm.CREATE2 { + // record noce increment + l.recordNonceIncrese(from) + } +} + +func (l *StateDiffTracer) CaptureExit(output []byte, gasUsed uint64, err error) { + l.tracer.CaptureExit(output, gasUsed, err) + // retrieve the last callframe in last callstack + lastCallStack := l.tracer.callstack[len(l.tracer.callstack)-1].Calls + callframe := lastCallStack[len(lastCallStack)-1] + // Note: do not record gasUsed here. All gas used value is recorded in TxEnd + + opType := callframe.Type + switch opType { + case vm.CREATE, vm.CREATE2, vm.CALL: + if opType == vm.CREATE || opType == vm.CREATE2 { + // record the code + contract := *callframe.To + l.recordCode(contract, callframe.Input) + } + // ether transfer + value := callframe.Value + if value != nil { + from := callframe.From + to := *callframe.To + l.recordBalanceChange(from, big.NewInt(0).Neg(value)) + l.recordBalanceChange(to, value) + } + case vm.SELFDESTRUCT: + // destruct this contract. code is empty and balance is zero + contract := *callframe.To + l.recordCode(contract, []byte{}) + l.recordBalanceChange(contract, big.NewInt(0).Neg(l.env.StateDB.GetBalance(contract))) + } +} + +func (l *StateDiffTracer) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) { + l.tracer.CaptureFault(pc, op, gas, cost, scope, depth, err) +} + +func (l *StateDiffTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + if op == vm.SSTORE { + contract := scope.Contract + stack := scope.Stack + stackLen := len(stack.Data()) + if stackLen >= 2 { + value := common.Hash(stack.Data()[stackLen-2].Bytes32()) + address := common.Hash(stack.Data()[stackLen-1].Bytes32()) + // record storage change + l.recordStorage(contract.Address(), address, value) + } + } +} + +func (l *StateDiffTracer) GetResult() (json.RawMessage, error) { + stateDiffResult := map[string]accountReport{} + for addr, diff := range l.accounts { + stateDiffResult[addr.Hex()] = l.report(addr, diff) + } + result := map[string]interface{}{ + // only stateDiff result is supported now + "stateDiff": stateDiffResult, + } + return json.Marshal(result) +} + +func (l *StateDiffTracer) Stop(err error) { + l.tracer.Stop(err) +} + +func (l *StateDiffTracer) tryInitAccDiff(addr common.Address) bool { + if _, ok := l.accounts[addr]; !ok { + l.accounts[addr] = accountDiff{ + balanceDelta: big.NewInt(0), + storage: make(map[common.Hash]diff[common.Hash]), + code: diff[[]byte]{nil, nil}, + } + return true + } + return false +} + +func (l *StateDiffTracer) recordNonceIncrese(addr common.Address) { + l.tryInitAccDiff(addr) + diff := l.accounts[addr] + diff.nonceDelta++ + l.accounts[addr] = diff +} + +func (l *StateDiffTracer) recordCode(addr common.Address, code []byte) { + isInit := l.tryInitAccDiff(addr) + diff := l.accounts[addr] + if isInit { + // init non-nil code before change + beforeCode := l.env.StateDB.GetCode(addr) + if beforeCode == nil { + beforeCode = []byte{} + } + diff.code.before = beforeCode + } + + diff.code.after = code + l.accounts[addr] = diff +} + +func (l *StateDiffTracer) recordStorage(addr common.Address, key, after common.Hash) { + isInit := l.tryInitAccDiff(addr) + value := l.accounts[addr].storage[key] + value.after = after + if isInit { + // take only the initial value + value.before = l.env.StateDB.GetState(addr, key) + } + l.accounts[addr].storage[key] = value +} + +// update balance +func (l *StateDiffTracer) recordBalanceChange(addr common.Address, delta *big.Int) { + l.tryInitAccDiff(addr) + diff := l.accounts[addr] + diff.balanceDelta.Add(diff.balanceDelta, delta) + l.accounts[addr] = diff +} + +type accountReport struct { + Balance any `json:"balance"` + Nonce any `json:"nonce"` + Code any `json:"code"` + Storage map[string]fromTo `json:"storage"` +} +type fromTo struct { + From string `json:"from"` + To string `json:"to"` +} + +func (l *StateDiffTracer) report(addr common.Address, a accountDiff) accountReport { + result := accountReport{ + Balance: "=", + Nonce: "=", + Code: "=", + Storage: make(map[string]fromTo), + } + // balance + if a.balanceDelta != nil && a.balanceDelta.Sign() != 0 { + delta := a.balanceDelta + current := l.env.StateDB.GetBalance(addr) + result.Balance = fromTo{ + // from = current - delta. transform to hex + From: fmt.Sprintf("0x%x", big.NewInt(0).Sub(current, delta).Text(16)), + To: fmt.Sprintf("0x%x", current.Text(16)), + } + } + // nonce + if a.nonceDelta != 0 { + current := l.env.StateDB.GetNonce(addr) + result.Nonce = fromTo{ + // in hex + From: fmt.Sprintf("0x%x", current-uint64(a.nonceDelta)), + To: fmt.Sprintf("0x%x", current), + } + } + // code + if a.code.before != nil || a.code.after != nil { + before, after := "", "" + if a.code.before != nil { + before = hex.EncodeToString(a.code.before) + } + if a.code.after != nil { + after = hex.EncodeToString(a.code.after) + } + if before != after { + result.Code = fromTo{ + From: before, + To: after, + } + } + } + // storage + for k, v := range a.storage { + before := v.before.Hex() + after := v.after.Hex() + if before != after { + result.Storage[k.Hex()] = fromTo{ + From: before, + To: after, + } + } + } + return result +}