Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
sudo: false
language: go
go:
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
Expand Down
93 changes: 0 additions & 93 deletions benchmark_go18_test.go

This file was deleted.

77 changes: 77 additions & 0 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ package mysql

import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"fmt"
"math"
"runtime"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -240,3 +243,77 @@ func BenchmarkInterpolation(b *testing.B) {
}
}
}

func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))

tb := (*TB)(b)
stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
defer stmt.Close()

b.SetParallelism(p)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
var got string
for pb.Next() {
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Fatalf("query = %q; want one", got)
}
}
})
}

func BenchmarkQueryContext(b *testing.B) {
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
defer db.Close()
for _, p := range []int{1, 2, 3, 4} {
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
benchmarkQueryContext(b, db, p)
})
}
}

func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))

tb := (*TB)(b)
stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
defer stmt.Close()

b.SetParallelism(p)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := stmt.ExecContext(ctx); err != nil {
b.Fatal(err)
}
}
})
}

func BenchmarkExecContext(b *testing.B) {
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
defer db.Close()
for _, p := range []int{1, 2, 3, 4} {
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
benchmarkQueryContext(b, db, p)
})
}
}
193 changes: 193 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package mysql

import (
"context"
"database/sql"
"database/sql/driver"
"io"
"net"
Expand Down Expand Up @@ -459,3 +461,194 @@ func (mc *mysqlConn) finish() {
case <-mc.closech:
}
}

// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}

if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()

if err = mc.writeCommandPacket(comPing); err != nil {
return
}

return mc.readResultOK()
}

// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()

if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
level, err := mapIsolationLevel(opts.Isolation)
if err != nil {
return nil, err
}
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
if err != nil {
return nil, err
}
}

return mc.begin(opts.ReadOnly)
}

func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}

if err := mc.watchCancel(ctx); err != nil {
return nil, err
}

rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}

func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}

if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()

return mc.Exec(query, dargs)
}

func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}

stmt, err := mc.Prepare(query)
mc.finish()
if err != nil {
return nil, err
}

select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}
return stmt, nil
}

func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}

if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}

rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}

func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}

if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
defer stmt.mc.finish()

return stmt.Exec(dargs)
}

func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
if ctx.Done() == nil {
return nil
}

mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
if mc.watcher == nil {
return nil
}

mc.watcher <- ctx

return nil
}

func (mc *mysqlConn) startWatcher() {
watcher := make(chan mysqlContext, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx mysqlContext
select {
case ctx = <-watcher:
case <-mc.closech:
return
}

select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}

func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}

// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
return nil
}
Loading