Skip to content

Commit 7bf35c1

Browse files
authored
Merge pull request #72 from databendcloud/fix/uint64-max
fix: value out of range for type uint64
2 parents ea1c2fe + af3fbb1 commit 7bf35c1

File tree

9 files changed

+222
-82
lines changed

9 files changed

+222
-82
lines changed

cmd/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func main() {
9494
panic(err)
9595
}
9696
// adjust batch size according to source db table
97-
cfgCopy.BatchSize = src.AdjustBatchSizeAccordingToSourceDbTable()
97+
cfgCopy.BatchSize = int64(src.AdjustBatchSizeAccordingToSourceDbTable())
9898
w := worker.NewWorker(&cfgCopy, fmt.Sprintf("%s.%s", db, table), ig, src)
9999
w.Run(ctx)
100100
}

config/conf_test.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
"sourceUser": "root",
55
"sourcePass": "123456",
66
"sourceDB": "mydb",
7-
"sourceTable": "t1",
8-
"sourceQuery": "select * from mydb.t1",
7+
"sourceTable": "test_table",
8+
"sourceQuery": "select * from mydb.test_table",
99
"sourceWhereCondition": "id > 0",
1010
"sourceSplitKey": "id",
1111
"sourceSplitTimeKey": "",
1212
"timeSplitUnit": "minute",
13-
"databendDSN": "http://databend:databend@localhost:8000",
14-
"databendTable": "testSync.t1",
15-
"batchSize": 2,
13+
"databendDSN": "http://databend:databend@localhost:8009",
14+
"databendTable": "testSync.test_table",
15+
"batchSize": 20000,
1616
"batchMaxInterval": 30,
1717
"userStage": "~",
18-
"deleteAfterSync": false,
18+
"deleteAfterSync": true,
1919
"maxThread": 10
2020
}

source/mysql.go

Lines changed: 95 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package source
22

33
import (
44
"database/sql"
5+
"database/sql/driver"
56
"fmt"
67
"log"
78
"regexp"
9+
"strconv"
810
"strings"
911
"time"
1012

@@ -40,25 +42,25 @@ func NewMysqlSource(cfg *config.Config) (*MysqlSource, error) {
4042

4143
// AdjustBatchSizeAccordingToSourceDbTable has a concept called s, s = (maxKey - minKey) / sourceTableRowCount
4244
// if s == 1 it means the data is uniform in the table, if s is much bigger than 1, it means the data is not uniform in the table
43-
func (s *MysqlSource) AdjustBatchSizeAccordingToSourceDbTable() int64 {
45+
func (s *MysqlSource) AdjustBatchSizeAccordingToSourceDbTable() uint64 {
4446
minSplitKey, maxSplitKey, err := s.GetMinMaxSplitKey()
4547
if err != nil {
46-
return s.cfg.BatchSize
48+
return uint64(s.cfg.BatchSize)
4749
}
4850
sourceTableRowCount, err := s.GetSourceReadRowsCount()
4951
if err != nil {
50-
return s.cfg.BatchSize
52+
return uint64(s.cfg.BatchSize)
5153
}
5254
rangeSize := maxSplitKey - minSplitKey + 1
5355
switch {
5456
case int64(sourceTableRowCount) <= s.cfg.BatchSize:
5557
return rangeSize
56-
case rangeSize/int64(sourceTableRowCount) >= 10:
57-
return s.cfg.BatchSize * 5
58-
case rangeSize/int64(sourceTableRowCount) >= 100:
59-
return s.cfg.BatchSize * 20
58+
case rangeSize/uint64(sourceTableRowCount) >= 10:
59+
return uint64(s.cfg.BatchSize * 5)
60+
case rangeSize/uint64(sourceTableRowCount) >= 100:
61+
return uint64(s.cfg.BatchSize * 20)
6062
default:
61-
return s.cfg.BatchSize
63+
return uint64(s.cfg.BatchSize)
6264
}
6365
}
6466

@@ -74,28 +76,41 @@ func (s *MysqlSource) GetSourceReadRowsCount() (int, error) {
7476
return rowCount, nil
7577
}
7678

77-
func (s *MysqlSource) GetMinMaxSplitKey() (int64, int64, error) {
78-
rows, err := s.db.Query(fmt.Sprintf("select min(%s), max(%s) from %s.%s WHERE %s", s.cfg.SourceSplitKey,
79-
s.cfg.SourceSplitKey, s.cfg.SourceDB, s.cfg.SourceTable, s.cfg.SourceWhereCondition))
79+
func (s *MysqlSource) GetMinMaxSplitKey() (uint64, uint64, error) {
80+
query := fmt.Sprintf("SELECT MIN(%s), MAX(%s) FROM %s.%s WHERE %s",
81+
s.cfg.SourceSplitKey, s.cfg.SourceSplitKey,
82+
s.cfg.SourceDB, s.cfg.SourceTable, s.cfg.SourceWhereCondition)
83+
84+
rows, err := s.db.Query(query)
8085
if err != nil {
8186
return 0, 0, err
8287
}
8388
defer rows.Close()
8489

85-
var minSplitKey, maxSplitKey sql.NullInt64
90+
var minSplitKey, maxSplitKey interface{}
8691
for rows.Next() {
8792
err = rows.Scan(&minSplitKey, &maxSplitKey)
8893
if err != nil {
8994
return 0, 0, err
9095
}
9196
}
9297

93-
// Check if minSplitKey and maxSplitKey are valid (not NULL)
94-
if !minSplitKey.Valid || !maxSplitKey.Valid {
98+
// 处理 NULL
99+
if minSplitKey == nil || maxSplitKey == nil {
95100
return 0, 0, nil
96101
}
97102

98-
return minSplitKey.Int64, maxSplitKey.Int64, nil
103+
min64, err := toUint64(minSplitKey)
104+
if err != nil {
105+
return 0, 0, fmt.Errorf("failed to convert min value: %w", err)
106+
}
107+
108+
max64, err := toUint64(maxSplitKey)
109+
if err != nil {
110+
return 0, 0, fmt.Errorf("failed to convert max value: %w", err)
111+
}
112+
113+
return min64, max64, nil
99114
}
100115

101116
func (s *MysqlSource) GetMinMaxTimeSplitKey() (string, string, error) {
@@ -117,6 +132,7 @@ func (s *MysqlSource) GetMinMaxTimeSplitKey() (string, string, error) {
117132
}
118133

119134
func (s *MysqlSource) DeleteAfterSync() error {
135+
logrus.Infof("DeleteAfterSync: %v", s.cfg.DeleteAfterSync)
120136
if !s.cfg.DeleteAfterSync {
121137
return nil
122138
}
@@ -126,6 +142,8 @@ func (s *MysqlSource) DeleteAfterSync() error {
126142
return err
127143
}
128144

145+
logrus.Infof("dbTables: %v", dbTables)
146+
129147
for db, tables := range dbTables {
130148
for _, table := range tables {
131149
count, err := s.GetSourceReadRowsCount()
@@ -188,7 +206,9 @@ func (s *MysqlSource) QueryTableData(threadNum int, conditionSql string) ([][]in
188206
switch columnType.DatabaseTypeName() {
189207
case "INT", "SMALLINT", "TINYINT", "MEDIUMINT", "BIGINT":
190208
scanArgs[i] = new(sql.NullInt64)
191-
case "UNSIGNED INT", "UNSIGNED TINYINT", "UNSIGNED MEDIUMINT", "UNSIGNED BIGINT":
209+
case "UNSIGNED BIGINT":
210+
scanArgs[i] = new(NullUint64)
211+
case "UNSIGNED INT", "UNSIGNED TINYINT", "UNSIGNED MEDIUMINT":
192212
scanArgs[i] = new(sql.NullInt64)
193213
case "FLOAT", "DOUBLE":
194214
scanArgs[i] = new(sql.NullFloat64)
@@ -244,6 +264,12 @@ func (s *MysqlSource) QueryTableData(threadNum int, conditionSql string) ([][]in
244264
} else {
245265
row[i] = nil
246266
}
267+
case *NullUint64:
268+
if v.Valid {
269+
row[i] = v.Uint64
270+
} else {
271+
row[i] = nil
272+
}
247273
case *sql.NullBool:
248274
if v.Valid {
249275
row[i] = v.Bool
@@ -375,5 +401,58 @@ func (s *MysqlSource) GetDbTablesAccordingToSourceDbTables() (map[string][]strin
375401
allDbTables[db] = append(allDbTables[db], tables...)
376402
}
377403
}
404+
if s.cfg.SourceDB != "" && s.cfg.SourceTable != "" {
405+
allDbTables[s.cfg.SourceDB] = append(allDbTables[s.cfg.SourceDB], s.cfg.SourceTable)
406+
}
378407
return allDbTables, nil
379408
}
409+
410+
// NullUint64 represents a uint64 that may be null.
411+
type NullUint64 struct {
412+
Uint64 uint64
413+
Valid bool // Valid is true if Uint64 is not NULL
414+
}
415+
416+
// Scan implements the Scanner interface.
417+
func (n *NullUint64) Scan(value interface{}) error {
418+
if value == nil {
419+
n.Uint64, n.Valid = 0, false
420+
return nil
421+
}
422+
423+
n.Valid = true
424+
switch v := value.(type) {
425+
case uint64:
426+
n.Uint64 = v
427+
case int64:
428+
if v < 0 {
429+
// 处理溢出的情况
430+
n.Uint64 = uint64(v)
431+
} else {
432+
n.Uint64 = uint64(v)
433+
}
434+
case []byte:
435+
var err error
436+
n.Uint64, err = strconv.ParseUint(string(v), 10, 64)
437+
if err != nil {
438+
return err
439+
}
440+
case string:
441+
var err error
442+
n.Uint64, err = strconv.ParseUint(v, 10, 64)
443+
if err != nil {
444+
return err
445+
}
446+
default:
447+
return fmt.Errorf("cannot scan type %T into NullUint64", value)
448+
}
449+
return nil
450+
}
451+
452+
// Value implements the driver Valuer interface.
453+
func (n NullUint64) Value() (driver.Value, error) {
454+
if !n.Valid {
455+
return nil, nil
456+
}
457+
return n.Uint64, nil
458+
}

source/oracle.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,25 @@ type OracleSource struct {
2323
statsRecorder *DatabendSourceStatsRecorder
2424
}
2525

26-
func (p *OracleSource) AdjustBatchSizeAccordingToSourceDbTable() int64 {
26+
func (p *OracleSource) AdjustBatchSizeAccordingToSourceDbTable() uint64 {
2727
minSplitKey, maxSplitKey, err := p.GetMinMaxSplitKey()
2828
if err != nil {
29-
return p.cfg.BatchSize
29+
return uint64(p.cfg.BatchSize)
3030
}
3131
sourceTableRowCount, err := p.GetSourceReadRowsCount()
3232
if err != nil {
33-
return p.cfg.BatchSize
33+
return uint64(p.cfg.BatchSize)
3434
}
3535
rangeSize := maxSplitKey - minSplitKey + 1
3636
switch {
3737
case int64(sourceTableRowCount) <= p.cfg.BatchSize:
3838
return rangeSize
39-
case rangeSize/int64(sourceTableRowCount) >= 10:
40-
return p.cfg.BatchSize * 5
41-
case rangeSize/int64(sourceTableRowCount) >= 100:
42-
return p.cfg.BatchSize * 20
39+
case rangeSize/uint64(sourceTableRowCount) >= 10:
40+
return uint64(p.cfg.BatchSize * 5)
41+
case rangeSize/uint64(sourceTableRowCount) >= 100:
42+
return uint64(p.cfg.BatchSize * 20)
4343
default:
44-
return p.cfg.BatchSize
44+
return uint64(p.cfg.BatchSize)
4545
}
4646
}
4747

@@ -111,32 +111,45 @@ func (p *OracleSource) GetSourceReadRowsCount() (int, error) {
111111
return rowCount, nil
112112
}
113113

114-
func (p *OracleSource) GetMinMaxSplitKey() (int64, int64, error) {
114+
func (p *OracleSource) GetMinMaxSplitKey() (uint64, uint64, error) {
115115
err := p.SwitchDatabase()
116116
if err != nil {
117117
return 0, 0, err
118118
}
119-
rows, err := p.db.Query(fmt.Sprintf("select COALESCE(min(%s),0), COALESCE(max(%s),0) from %s.%s WHERE %s",
120-
p.cfg.SourceSplitKey, p.cfg.SourceSplitKey, p.cfg.SourceDB, p.cfg.SourceTable, p.cfg.SourceWhereCondition))
119+
120+
query := fmt.Sprintf("SELECT COALESCE(MIN(%s), 0), COALESCE(MAX(%s), 0) FROM %s.%s WHERE %s",
121+
p.cfg.SourceSplitKey, p.cfg.SourceSplitKey,
122+
p.cfg.SourceDB, p.cfg.SourceTable, p.cfg.SourceWhereCondition)
123+
124+
rows, err := p.db.Query(query)
121125
if err != nil {
122126
return 0, 0, err
123127
}
124128
defer rows.Close()
125129

126-
var minSplitKey, maxSplitKey sql.NullInt64
130+
var minSplitKey, maxSplitKey interface{}
127131
for rows.Next() {
128132
err = rows.Scan(&minSplitKey, &maxSplitKey)
129133
if err != nil {
130134
return 0, 0, err
131135
}
132136
}
133137

134-
// Check if minSplitKey and maxSplitKey are valid (not NULL)
135-
if !minSplitKey.Valid || !maxSplitKey.Valid {
138+
if minSplitKey == nil || maxSplitKey == nil {
136139
return 0, 0, nil
137140
}
138141

139-
return minSplitKey.Int64, maxSplitKey.Int64, nil
142+
min64, err := toUint64(minSplitKey)
143+
if err != nil {
144+
return 0, 0, fmt.Errorf("failed to convert min value: %w", err)
145+
}
146+
147+
max64, err := toUint64(maxSplitKey)
148+
if err != nil {
149+
return 0, 0, fmt.Errorf("failed to convert max value: %w", err)
150+
}
151+
152+
return min64, max64, nil
140153
}
141154

142155
func (p *OracleSource) GetMinMaxTimeSplitKey() (string, string, error) {

0 commit comments

Comments
 (0)