@@ -2,9 +2,11 @@ package source
22
33import (
44 "database/sql"
5+ "database/sql/driver"
56 "fmt"
67 "log"
78 "regexp"
9+ "strconv"
810 "strings"
911 "time"
1012
@@ -40,25 +42,25 @@ func NewMysqlSource(cfg *config.Config) (*MysqlSource, error) {
4042
4143// AdjustBatchSizeAccordingToSourceDbTable has a concept called s, s = (maxKey - minKey) / sourceTableRowCount
4244// if s == 1 it means the data is uniform in the table, if s is much bigger than 1, it means the data is not uniform in the table
43- func (s * MysqlSource ) AdjustBatchSizeAccordingToSourceDbTable () int64 {
45+ func (s * MysqlSource ) AdjustBatchSizeAccordingToSourceDbTable () uint64 {
4446 minSplitKey , maxSplitKey , err := s .GetMinMaxSplitKey ()
4547 if err != nil {
46- return s .cfg .BatchSize
48+ return uint64 ( s .cfg .BatchSize )
4749 }
4850 sourceTableRowCount , err := s .GetSourceReadRowsCount ()
4951 if err != nil {
50- return s .cfg .BatchSize
52+ return uint64 ( s .cfg .BatchSize )
5153 }
5254 rangeSize := maxSplitKey - minSplitKey + 1
5355 switch {
5456 case int64 (sourceTableRowCount ) <= s .cfg .BatchSize :
5557 return rangeSize
56- case rangeSize / int64 (sourceTableRowCount ) >= 10 :
57- return s .cfg .BatchSize * 5
58- case rangeSize / int64 (sourceTableRowCount ) >= 100 :
59- return s .cfg .BatchSize * 20
58+ case rangeSize / uint64 (sourceTableRowCount ) >= 10 :
59+ return uint64 ( s .cfg .BatchSize * 5 )
60+ case rangeSize / uint64 (sourceTableRowCount ) >= 100 :
61+ return uint64 ( s .cfg .BatchSize * 20 )
6062 default :
61- return s .cfg .BatchSize
63+ return uint64 ( s .cfg .BatchSize )
6264 }
6365}
6466
@@ -74,28 +76,41 @@ func (s *MysqlSource) GetSourceReadRowsCount() (int, error) {
7476 return rowCount , nil
7577}
7678
77- func (s * MysqlSource ) GetMinMaxSplitKey () (int64 , int64 , error ) {
78- rows , err := s .db .Query (fmt .Sprintf ("select min(%s), max(%s) from %s.%s WHERE %s" , s .cfg .SourceSplitKey ,
79- s .cfg .SourceSplitKey , s .cfg .SourceDB , s .cfg .SourceTable , s .cfg .SourceWhereCondition ))
79+ func (s * MysqlSource ) GetMinMaxSplitKey () (uint64 , uint64 , error ) {
80+ query := fmt .Sprintf ("SELECT MIN(%s), MAX(%s) FROM %s.%s WHERE %s" ,
81+ s .cfg .SourceSplitKey , s .cfg .SourceSplitKey ,
82+ s .cfg .SourceDB , s .cfg .SourceTable , s .cfg .SourceWhereCondition )
83+
84+ rows , err := s .db .Query (query )
8085 if err != nil {
8186 return 0 , 0 , err
8287 }
8388 defer rows .Close ()
8489
85- var minSplitKey , maxSplitKey sql. NullInt64
90+ var minSplitKey , maxSplitKey interface {}
8691 for rows .Next () {
8792 err = rows .Scan (& minSplitKey , & maxSplitKey )
8893 if err != nil {
8994 return 0 , 0 , err
9095 }
9196 }
9297
93- // Check if minSplitKey and maxSplitKey are valid (not NULL)
94- if ! minSplitKey . Valid || ! maxSplitKey . Valid {
98+ // 处理 NULL 值
99+ if minSplitKey == nil || maxSplitKey == nil {
95100 return 0 , 0 , nil
96101 }
97102
98- return minSplitKey .Int64 , maxSplitKey .Int64 , nil
103+ min64 , err := toUint64 (minSplitKey )
104+ if err != nil {
105+ return 0 , 0 , fmt .Errorf ("failed to convert min value: %w" , err )
106+ }
107+
108+ max64 , err := toUint64 (maxSplitKey )
109+ if err != nil {
110+ return 0 , 0 , fmt .Errorf ("failed to convert max value: %w" , err )
111+ }
112+
113+ return min64 , max64 , nil
99114}
100115
101116func (s * MysqlSource ) GetMinMaxTimeSplitKey () (string , string , error ) {
@@ -117,6 +132,7 @@ func (s *MysqlSource) GetMinMaxTimeSplitKey() (string, string, error) {
117132}
118133
119134func (s * MysqlSource ) DeleteAfterSync () error {
135+ logrus .Infof ("DeleteAfterSync: %v" , s .cfg .DeleteAfterSync )
120136 if ! s .cfg .DeleteAfterSync {
121137 return nil
122138 }
@@ -126,6 +142,8 @@ func (s *MysqlSource) DeleteAfterSync() error {
126142 return err
127143 }
128144
145+ logrus .Infof ("dbTables: %v" , dbTables )
146+
129147 for db , tables := range dbTables {
130148 for _ , table := range tables {
131149 count , err := s .GetSourceReadRowsCount ()
@@ -188,7 +206,9 @@ func (s *MysqlSource) QueryTableData(threadNum int, conditionSql string) ([][]in
188206 switch columnType .DatabaseTypeName () {
189207 case "INT" , "SMALLINT" , "TINYINT" , "MEDIUMINT" , "BIGINT" :
190208 scanArgs [i ] = new (sql.NullInt64 )
191- case "UNSIGNED INT" , "UNSIGNED TINYINT" , "UNSIGNED MEDIUMINT" , "UNSIGNED BIGINT" :
209+ case "UNSIGNED BIGINT" :
210+ scanArgs [i ] = new (NullUint64 )
211+ case "UNSIGNED INT" , "UNSIGNED TINYINT" , "UNSIGNED MEDIUMINT" :
192212 scanArgs [i ] = new (sql.NullInt64 )
193213 case "FLOAT" , "DOUBLE" :
194214 scanArgs [i ] = new (sql.NullFloat64 )
@@ -244,6 +264,12 @@ func (s *MysqlSource) QueryTableData(threadNum int, conditionSql string) ([][]in
244264 } else {
245265 row [i ] = nil
246266 }
267+ case * NullUint64 :
268+ if v .Valid {
269+ row [i ] = v .Uint64
270+ } else {
271+ row [i ] = nil
272+ }
247273 case * sql.NullBool :
248274 if v .Valid {
249275 row [i ] = v .Bool
@@ -375,5 +401,58 @@ func (s *MysqlSource) GetDbTablesAccordingToSourceDbTables() (map[string][]strin
375401 allDbTables [db ] = append (allDbTables [db ], tables ... )
376402 }
377403 }
404+ if s .cfg .SourceDB != "" && s .cfg .SourceTable != "" {
405+ allDbTables [s .cfg .SourceDB ] = append (allDbTables [s .cfg .SourceDB ], s .cfg .SourceTable )
406+ }
378407 return allDbTables , nil
379408}
409+
// NullUint64 represents a uint64 that may be NULL. It plays the same role for
// unsigned 64-bit columns that sql.NullInt64 plays for signed ones.
type NullUint64 struct {
	Uint64 uint64
	Valid  bool // Valid is true if Uint64 is not NULL
}

// Scan implements the sql.Scanner interface.
//
// Accepted inputs are nil (NULL), uint64, int64, and base-10 text as []byte or
// string. A negative int64 is reinterpreted bit-for-bit as uint64, so -1
// becomes the maximum uint64. On any error n is left in the NULL state
// (Uint64 == 0, Valid == false) so a failed scan never looks like a value.
func (n *NullUint64) Scan(value interface{}) error {
	// Start from the NULL state; Valid is raised only after a successful
	// conversion.
	n.Uint64, n.Valid = 0, false
	if value == nil {
		return nil
	}

	var (
		parsed uint64
		err    error
	)
	switch v := value.(type) {
	case uint64:
		parsed = v
	case int64:
		// Two's-complement reinterpretation: negative inputs map to values
		// above math.MaxInt64.
		parsed = uint64(v)
	case []byte:
		parsed, err = strconv.ParseUint(string(v), 10, 64)
	case string:
		parsed, err = strconv.ParseUint(v, 10, 64)
	default:
		return fmt.Errorf("cannot scan type %T into NullUint64", value)
	}
	if err != nil {
		return err
	}

	n.Uint64, n.Valid = parsed, true
	return nil
}

// Value implements the driver.Valuer interface. It returns nil for NULL and
// the raw uint64 otherwise.
//
// NOTE(review): uint64 is not one of the standard driver.Value types, so this
// relies on the driver in use accepting it — confirm for any non-MySQL driver.
func (n NullUint64) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}
	return n.Uint64, nil
}
0 commit comments