diff --git a/bindings/go/main.go b/bindings/go/main.go index bdd5385b5..bf344a1d5 100644 --- a/bindings/go/main.go +++ b/bindings/go/main.go @@ -6,10 +6,10 @@ package ckzg4844 import "C" import ( + "bytes" "encoding/hex" "errors" "fmt" - "strings" "unsafe" // So its functions are available during compilation. @@ -38,12 +38,6 @@ var ( ErrBadArgs = errors.New("bad arguments") ErrError = errors.New("unexpected error") ErrMalloc = errors.New("malloc failed") - errorMap = map[C.C_KZG_RET]error{ - C.C_KZG_OK: nil, - C.C_KZG_BADARGS: ErrBadArgs, - C.C_KZG_ERROR: ErrError, - C.C_KZG_MALLOC: ErrMalloc, - } ) /////////////////////////////////////////////////////////////////////////////// @@ -51,14 +45,18 @@ var ( /////////////////////////////////////////////////////////////////////////////// // makeErrorFromRet translates an (integral) return value, as reported -// by the C library, into a proper Go error. If there is no error, this -// will return nil. +// by the C library, into a proper Go error. This function should only be +// called when there is an error, not with C_KZG_OK. func makeErrorFromRet(ret C.C_KZG_RET) error { - err, ok := errorMap[ret] - if !ok { - panic(fmt.Sprintf("unexpected return value: %v", ret)) + switch ret { + case C.C_KZG_BADARGS: + return ErrBadArgs + case C.C_KZG_ERROR: + return ErrError + case C.C_KZG_MALLOC: + return ErrMalloc } - return err + return fmt.Errorf("unexpected error from c-library: %v", ret) } /////////////////////////////////////////////////////////////////////////////// @@ -66,50 +64,53 @@ func makeErrorFromRet(ret C.C_KZG_RET) error { /////////////////////////////////////////////////////////////////////////////// func (b *Bytes32) UnmarshalText(input []byte) error { - inputStr := string(input) - if strings.HasPrefix(inputStr, "0x") { - inputStr = strings.TrimPrefix(inputStr, "0x") + if bytes.HasPrefix(input, []byte("0x")) { + input = input[2:] + } + if len(input) != 2*len(b) { + return ErrBadArgs } - bytes, err := hex.DecodeString(inputStr) + l, err := hex.Decode(b[:], input) if err != nil { return err } - if len(bytes) != len(b) { + if l != len(b) { return ErrBadArgs } - copy(b[:], bytes) return nil } func (b *Bytes48) UnmarshalText(input []byte) error { - inputStr := string(input) - if strings.HasPrefix(inputStr, "0x") { - inputStr = strings.TrimPrefix(inputStr, "0x") + if bytes.HasPrefix(input, []byte("0x")) { + input = input[2:] } - bytes, err := hex.DecodeString(inputStr) + if len(input) != 2*len(b) { + return ErrBadArgs + } + l, err := hex.Decode(b[:], input) if err != nil { return err } - if len(bytes) != len(b) { + if l != len(b) { return ErrBadArgs } - copy(b[:], bytes) return nil } func (b *Blob) UnmarshalText(input []byte) error { - inputStr := string(input) - if strings.HasPrefix(inputStr, "0x") { - inputStr = strings.TrimPrefix(inputStr, "0x") + if bytes.HasPrefix(input, []byte("0x")) { + input = input[2:] + } + if len(input) != 2*len(b) { + return ErrBadArgs } - bytes, err := hex.DecodeString(inputStr) + l, err := hex.Decode(b[:], input) if err != nil { return err } - if len(bytes) != len(b) { + if l != len(b) { return ErrBadArgs } - copy(b[:], bytes) return nil } @@ -147,6 +148,7 @@ func LoadTrustedSetup(g1Bytes, g2Bytes []byte) error { (C.size_t)(numG2Elements)) if ret == C.C_KZG_OK { loaded = true + return nil } return makeErrorFromRet(ret) } @@ -174,6 +176,7 @@ func LoadTrustedSetupFile(trustedSetupFile string) error { C.fclose(fp) if ret == C.C_KZG_OK { loaded = true + return nil } return makeErrorFromRet(ret) } @@ -200,16 +203,24 @@ BlobToKZGCommitment is the binding for: const Blob *blob, const KZGSettings *s); */ -func BlobToKZGCommitment(blob Blob) (KZGCommitment, error) { +func BlobToKZGCommitment(blob *Blob) (KZGCommitment, error) { if !loaded { panic("trusted setup isn't loaded") } - commitment := KZGCommitment{} + if blob == nil { + return KZGCommitment{}, ErrBadArgs + } + + var commitment KZGCommitment ret := C.blob_to_kzg_commitment( (*C.KZGCommitment)(unsafe.Pointer(&commitment)), - (*C.Blob)(unsafe.Pointer(&blob)), + (*C.Blob)(unsafe.Pointer(blob)), &settings) - return commitment, makeErrorFromRet(ret) + + if ret != C.C_KZG_OK { + return KZGCommitment{}, makeErrorFromRet(ret) + } + return commitment, nil } /* @@ -222,19 +233,28 @@ ComputeKZGProof is the binding for: const Bytes32 *z_bytes, const KZGSettings *s); */ -func ComputeKZGProof(blob Blob, zBytes Bytes32) (KZGProof, Bytes32, error) { +func ComputeKZGProof(blob *Blob, zBytes Bytes32) (KZGProof, Bytes32, error) { if !loaded { panic("trusted setup isn't loaded") } - proof := KZGProof{} - y := Bytes32{} + if blob == nil { + return KZGProof{}, Bytes32{}, ErrBadArgs + } + var ( + proof KZGProof + y Bytes32 + ) ret := C.compute_kzg_proof( (*C.KZGProof)(unsafe.Pointer(&proof)), (*C.Bytes32)(unsafe.Pointer(&y)), - (*C.Blob)(unsafe.Pointer(&blob)), + (*C.Blob)(unsafe.Pointer(blob)), (*C.Bytes32)(unsafe.Pointer(&zBytes)), &settings) - return proof, y, makeErrorFromRet(ret) + + if ret != C.C_KZG_OK { + return KZGProof{}, Bytes32{}, makeErrorFromRet(ret) + } + return proof, y, nil } /* @@ -246,17 +266,24 @@ ComputeBlobKZGProof is the binding for: const Bytes48 *commitment_bytes, const KZGSettings *s); */ -func ComputeBlobKZGProof(blob Blob, commitmentBytes Bytes48) (KZGProof, error) { +func ComputeBlobKZGProof(blob *Blob, commitmentBytes Bytes48) (KZGProof, error) { if !loaded { panic("trusted setup isn't loaded") } - proof := KZGProof{} + if blob == nil { + return KZGProof{}, ErrBadArgs + } + var proof KZGProof ret := C.compute_blob_kzg_proof( (*C.KZGProof)(unsafe.Pointer(&proof)), - (*C.Blob)(unsafe.Pointer(&blob)), + (*C.Blob)(unsafe.Pointer(blob)), (*C.Bytes48)(unsafe.Pointer(&commitmentBytes)), &settings) - return proof, makeErrorFromRet(ret) + + if ret != C.C_KZG_OK { + return KZGProof{}, makeErrorFromRet(ret) + } + return proof, nil } /* @@ -282,7 +309,11 @@ func VerifyKZGProof(commitmentBytes Bytes48, zBytes, yBytes Bytes32, proofBytes (*C.Bytes32)(unsafe.Pointer(&yBytes)), (*C.Bytes48)(unsafe.Pointer(&proofBytes)), &settings) - return bool(result), makeErrorFromRet(ret) + + if ret != C.C_KZG_OK { + return false, makeErrorFromRet(ret) + } + return bool(result), nil } /* @@ -295,18 +326,26 @@ VerifyBlobKZGProof is the binding for: const Bytes48 *proof_bytes, const KZGSettings *s); */ -func VerifyBlobKZGProof(blob Blob, commitmentBytes, proofBytes Bytes48) (bool, error) { +func VerifyBlobKZGProof(blob *Blob, commitmentBytes, proofBytes Bytes48) (bool, error) { if !loaded { panic("trusted setup isn't loaded") } + if blob == nil { + return false, ErrBadArgs + } + var result C.bool ret := C.verify_blob_kzg_proof( &result, - (*C.Blob)(unsafe.Pointer(&blob)), + (*C.Blob)(unsafe.Pointer(blob)), (*C.Bytes48)(unsafe.Pointer(&commitmentBytes)), (*C.Bytes48)(unsafe.Pointer(&proofBytes)), &settings) - return bool(result), makeErrorFromRet(ret) + + if ret != C.C_KZG_OK { + return false, makeErrorFromRet(ret) + } + return bool(result), nil } /* @@ -335,5 +374,9 @@ func VerifyBlobKZGProofBatch(blobs []Blob, commitmentsBytes, proofsBytes []Bytes *(**C.Bytes48)(unsafe.Pointer(&proofsBytes)), (C.size_t)(len(blobs)), &settings) - return bool(result), makeErrorFromRet(ret) + + if ret != C.C_KZG_OK { + return false, makeErrorFromRet(ret) + } + return bool(result), nil } diff --git a/bindings/go/main_test.go b/bindings/go/main_test.go index 258883de4..99d1f3bd2 100644 --- a/bindings/go/main_test.go +++ b/bindings/go/main_test.go @@ -12,9 +12,9 @@ import ( ) func TestMain(m *testing.M) { - err := LoadTrustedSetupFile("../../src/trusted_setup.txt") - if err != nil { - panic("failed to load trusted setup") + + if err := LoadTrustedSetupFile("../../src/trusted_setup.txt"); err != nil { + panic(fmt.Sprintf("failed to load trusted setup: %v", err)) } defer FreeTrustedSetup() code := m.Run() @@ -40,13 +40,11 @@ func getRandFieldElement(seed int64) Bytes32 { return fieldElementBytes } -func getRandBlob(seed int64) Blob { - var blob Blob +func fillBlobRandom(blob *Blob, seed int64) { for i := 0; i < BytesPerBlob; i += BytesPerFieldElement { fieldElementBytes := getRandFieldElement(seed + int64(i)) copy(blob[i:i+BytesPerFieldElement], fieldElementBytes[:]) } - return blob } /////////////////////////////////////////////////////////////////////////////// @@ -84,7 +82,7 @@ func TestBlobToKZGCommitment(t *testing.T) { require.NoError(t, testFile.Close()) require.NoError(t, err) - var blob Blob + blob := new(Blob) err = blob.UnmarshalText([]byte(test.Input.Blob)) if err != nil { require.Nil(t, test.Output) @@ -124,7 +122,7 @@ func TestComputeKZGProof(t *testing.T) { require.NoError(t, testFile.Close()) require.NoError(t, err) - var blob Blob + blob := new(Blob) err = blob.UnmarshalText([]byte(test.Input.Blob)) if err != nil { require.Nil(t, test.Output) @@ -178,7 +176,7 @@ func TestComputeBlobKZGProof(t *testing.T) { require.NoError(t, testFile.Close()) require.NoError(t, err) - var blob Blob + blob := new(Blob) err = blob.UnmarshalText([]byte(test.Input.Blob)) if err != nil { require.Nil(t, test.Output) @@ -289,7 +287,7 @@ func TestVerifyBlobKZGProof(t *testing.T) { require.NoError(t, testFile.Close()) require.NoError(t, err) - var blob Blob + var blob = new(Blob) err = blob.UnmarshalText([]byte(test.Input.Blob)) if err != nil { require.Nil(t, test.Output) @@ -399,10 +397,11 @@ func Benchmark(b *testing.B) { proofs := [length]Bytes48{} fields := [length]Bytes32{} for i := 0; i < length; i++ { - blob := getRandBlob(int64(i)) - commitment, err := BlobToKZGCommitment(blob) + var blob Blob + fillBlobRandom(&blob, int64(i)) + commitment, err := BlobToKZGCommitment(&blob) require.NoError(b, err) - proof, err := ComputeBlobKZGProof(blob, Bytes48(commitment)) + proof, err := ComputeBlobKZGProof(&blob, Bytes48(commitment)) require.NoError(b, err) blobs[i] = blob @@ -413,19 +412,19 @@ func Benchmark(b *testing.B) { b.Run("BlobToKZGCommitment", func(b *testing.B) { for n := 0; n < b.N; n++ { - BlobToKZGCommitment(blobs[0]) + BlobToKZGCommitment(&blobs[0]) } }) b.Run("ComputeKZGProof", func(b *testing.B) { for n := 0; n < b.N; n++ { - ComputeKZGProof(blobs[0], fields[0]) + ComputeKZGProof(&blobs[0], fields[0]) } }) b.Run("ComputeBlobKZGProof", func(b *testing.B) { for n := 0; n < b.N; n++ { - ComputeBlobKZGProof(blobs[0], commitments[0]) + ComputeBlobKZGProof(&blobs[0], commitments[0]) } }) @@ -437,7 +436,7 @@ func Benchmark(b *testing.B) { b.Run("VerifyBlobKZGProof", func(b *testing.B) { for n := 0; n < b.N; n++ { - VerifyBlobKZGProof(blobs[0], commitments[0], proofs[0]) + VerifyBlobKZGProof(&blobs[0], commitments[0], proofs[0]) } })