Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions credentials/alts/internal/conn/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ import (
"net"
"reflect"
"strings"
"syscall"
"testing"
"time"

"golang.org/x/sys/unix"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/internal/grpctest"
)
Expand Down Expand Up @@ -105,6 +108,94 @@ func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (cli
return clientConn, serverConn
}

// newTCPConnPair returns a pair of conns backed by TCP over loopback.
func newTCPConnPair(rp string, clientProtected []byte, serverProtected []byte) (*conn, *conn, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this function is intended to only be used from the newly added benchmark, you could consider making it a helper function by passing testing.B as the first parameter and calling b.Helper().

Also, please consider replacing calls to panic with calls to b.Fatal or b.Fatalf.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Please consider replacing rp with recordProtocol as really short variable names make more sense where the scope of the variable is much smaller. This seems to be a reasonable large function, and rp is used way down below.

Also, clientProtected []byte, serverProtected []byte could be shortened as clientProtected, serverProtected []byte

const address = "localhost:50935"

// Start the server.
serverChan := make(chan net.Conn)
listenChan := make(chan struct{})
go func() {
listener, err := net.Listen("tcp4", address)
if err != nil {
panic(fmt.Sprintf("failed to listen: %v", err))
}
defer listener.Close()
listenChan <- struct{}{}
conn, err := listener.Accept()
if err != nil {
panic(fmt.Sprintf("failed to aceept: %v", err))
}
serverChan <- conn
}()

// Ensure the server is listening before trying to connect.
<-listenChan
clientTCP, err := net.DialTimeout("tcp4", address, 5*time.Second)
if err != nil {
return nil, nil, fmt.Errorf("failed to Dial: %w", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These calls to fmt.Errorf could also be replaced with calls to b.Fatalf and remove the error return value from this function. So, the caller of this function will not have to handle the error and call b.Fatal.

}

// Get the server-side connection returned by Accept().
var serverTCP net.Conn
select {
case serverTCP = <-serverChan:
case <-time.After(5 * time.Second):
return nil, nil, fmt.Errorf("timed out waiting for server conn")
}

// Make the connection behave a little bit like a real one by imposing
// an MTU.
clientTCP = &mtuConn{clientTCP, 1500}

// 16 arbitrary bytes.
key := []byte{
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88,
0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
}

client, err := NewConn(clientTCP, core.ClientSide, rp, key, clientProtected)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}
server, err := NewConn(serverTCP, core.ServerSide, rp, key, serverProtected)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}

return client.(*conn), server.(*conn), nil
}

// mtuConn imposes an MTU on writes. It simulates an important quality of real
// network traffic that is lost when using loopback devices. On loopback, even
// large messages (e.g. 512 KiB) when written often arrive at the receiver
// instantaneously as a single payload. By explicitly splitting such writes into
// smaller, MTU-sized paylaods we give the receiver a chance to respond to
// smaller message sizes.
type mtuConn struct {
net.Conn
mtu int
}

// Write implements net.Conn.
func (rc *mtuConn) Write(buf []byte) (int, error) {
var written int
for len(buf) > 0 {
n, err := rc.Conn.Write(buf[:min(rc.mtu, len(buf))])
written += n
if err != nil {
return written, err
}
buf = buf[n:]
}
return written, nil
}

// SyscallConn implements syscall.Conn.
func (rc *mtuConn) SycallConn() (syscall.RawConn, error) {
return rc.Conn.(syscall.Conn).SyscallConn()
}

func testPingPong(t *testing.T, rp string) {
clientConn, serverConn := newConnPair(rp, nil, nil)
clientMsg := []byte("Client Message")
Expand Down Expand Up @@ -231,6 +322,117 @@ func BenchmarkLargeMessage(b *testing.B) {
}
}

// BenchmarkTCP is a simple throughput test that sends payloads over a local TCP
// connection. Run via:
//
// go test -run="^$" -bench="BenchmarkTCP" ./credentials/alts/internal/conn
func BenchmarkTCP(b *testing.B) {
tcs := []struct {
name string
size int
}{
{"1 KiB", 1024},
{"4 KiB", 4 * 1024},
{"64 KiB", 64 * 1024},
{"512 KiB", 512 * 1024},
{"1 MiB", 1024 * 1024},
{"4 MiB", 4 * 1024 * 1024},
}
for _, tc := range tcs {
b.Run("size="+tc.name, func(b *testing.B) {
benchmarkTCP(b, tc.size)
})
}
}

// sum makes unwanted compiler optimizations in benchmarkTCP's loop less likely.
var sum int

func benchmarkTCP(b *testing.B, size int) {
// Initialize the connection.
client, server, err := newTCPConnPair(rekeyRecordProtocol, nil, nil)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the last two arguments are always expected to be nil, we could remove those parameters and pass nil to NewConn from newTCPConnPair.

if err != nil {
b.Fatalf("failed to create TCP conn pair: %v", err)
}
defer client.Close()
defer server.Close()

rcvBuf := make([]byte, size)
sndBuf := make([]byte, size)
done := make(chan struct{})
errChan := make(chan error)

// Launch a writer goroutine.
go func() {
for {
select {
case <-done:
return
default:
}
n, err := client.Write(sndBuf)
if n != size || err != nil {
errChan <- fmt.Errorf("Write() = %v, %v; want %v, <nil>", n, err, size)
return
}
// Act a bit like a real workload that can't just fill
// every buffer immediately.
time.Sleep(10 * time.Millisecond)
}
}()

// Get the initial rusage so we can measure CPU time.
var startUsage unix.Rusage
if err := unix.Getrusage(unix.RUSAGE_SELF, &startUsage); err != nil {
b.Fatalf("failed to get initial rusage: %v", err)
}

// Read as much as possible.
var rcvd uint64
for b.Loop() {
n, err := io.ReadFull(server, rcvBuf)
rcvd += uint64(n)
if n != size || err != nil {
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, size)
}
// Act a bit like a real workload and utilize received bytes.
for _, b := range rcvBuf[:n] {
sum += int(b)
}
}

// Turn off the writer.
done <- struct{}{}

// Get the ending rusage.
var endUsage unix.Rusage
if err := unix.Getrusage(unix.RUSAGE_SELF, &endUsage); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use more portable alternatives. The sys/unix package will probably only build on unix and unix-like systems.

Options that I see are:

b.Fatalf("failed to get final rusage: %v", err)
}

// Error check the writer goroutine.
select {
case err := <-errChan:
b.Fatal(err)
default:
}

// Emit extra metrics.
utime := timevalDiffUsec(&startUsage.Utime, &endUsage.Utime)
stime := timevalDiffUsec(&startUsage.Stime, &endUsage.Stime)
b.ReportMetric(float64(utime)/float64(b.N), "usr-usec/op")
b.ReportMetric(float64(stime)/float64(b.N), "sys-usec/op")
b.ReportMetric(float64(stime+utime)/float64(b.N), "cpu-usec/op")
b.ReportMetric(float64(rcvd*8/(1024*1024))/float64(b.Elapsed().Seconds()), "Mbps")
}

// timevalDiffUsec returns the difference in microseconds between start and end.
func timevalDiffUsec(start, end *unix.Timeval) int64 {
// Note: the int64 type conversion is needed because unix.Timeval uses
// 32 bit values on some architectures.
return int64(1_000_000*(end.Sec-start.Sec) + end.Usec - start.Usec)
}

func testIncorrectMsgType(t *testing.T, rp string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
Expand Down
Loading