Skip to content

Commit a00c2cf

Browse files
Merge branch 'main' of https://github.com/oracle-samples/gorm-oracle into fixing-bulk-merge
2 parents 52c7e00 + e195ff2 commit a00c2cf

17 files changed

+595
-179
lines changed

oracle/clause_builder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func ReturningClauseBuilder(c clause.Clause, builder clause.Builder) {
165165
var dest interface{}
166166
if stmt.Schema != nil {
167167
if field := findFieldByDBName(stmt.Schema, column.Name); field != nil {
168-
dest = createTypedDestination(field.FieldType)
168+
dest = createTypedDestination(field)
169169
} else {
170170
dest = new(string) // Default to string for unknown fields
171171
}

oracle/common.go

Lines changed: 92 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,30 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field {
100100
}
101101

102102
// Create typed destination for OUT parameters
103-
func createTypedDestination(fieldType reflect.Type) interface{} {
104-
// Handle pointer types
105-
if fieldType.Kind() == reflect.Ptr {
106-
fieldType = fieldType.Elem()
103+
func createTypedDestination(f *schema.Field) interface{} {
104+
if f == nil {
105+
var s string
106+
return &s
107107
}
108108

109-
// Type-safe handling for known GORM types and SQL null types
110-
switch fieldType {
111-
case reflect.TypeOf(gorm.DeletedAt{}):
109+
ft := f.FieldType
110+
for ft.Kind() == reflect.Ptr {
111+
ft = ft.Elem()
112+
}
113+
114+
if ft == reflect.TypeOf(gorm.DeletedAt{}) {
112115
return new(sql.NullTime)
113-
case reflect.TypeOf(time.Time{}):
116+
}
117+
if ft == reflect.TypeOf(time.Time{}) {
118+
if !f.NotNull { // nullable column => keep NULLs
119+
return new(sql.NullTime)
120+
}
114121
return new(time.Time)
122+
}
123+
124+
switch ft {
125+
case reflect.TypeOf(sql.NullTime{}):
126+
return new(sql.NullTime)
115127
case reflect.TypeOf(sql.NullInt64{}):
116128
return new(sql.NullInt64)
117129
case reflect.TypeOf(sql.NullInt32{}):
@@ -120,33 +132,28 @@ func createTypedDestination(fieldType reflect.Type) interface{} {
120132
return new(sql.NullFloat64)
121133
case reflect.TypeOf(sql.NullBool{}):
122134
return new(sql.NullBool)
123-
case reflect.TypeOf(sql.NullTime{}):
124-
return new(sql.NullTime)
125135
}
126136

127-
// Handle primitive types by Kind
128-
switch fieldType.Kind() {
137+
switch ft.Kind() {
138+
case reflect.String:
139+
return new(string)
140+
141+
case reflect.Bool:
142+
return new(int64)
143+
129144
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
130-
return new(int64) // Oracle returns NUMBER as int64
131-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
145+
return new(int64)
146+
147+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
132148
return new(uint64)
149+
133150
case reflect.Float32, reflect.Float64:
134-
return new(float64) // Oracle returns FLOAT as float64
135-
case reflect.Bool:
136-
return new(int64) // Oracle NUMBER(1) for boolean
137-
case reflect.String:
138-
return new(string)
139-
case reflect.Struct:
140-
// For time.Time specifically
141-
if fieldType == reflect.TypeOf(time.Time{}) {
142-
return new(time.Time)
143-
}
144-
// For other structs, use string as safe fallback
145-
return new(string)
146-
default:
147-
// For unknown types, use string as safe fallback
148-
return new(string)
151+
return new(float64)
149152
}
153+
154+
// Fallback
155+
var s string
156+
return &s
150157
}
151158

152159
// Convert values for Oracle-specific types
@@ -182,7 +189,7 @@ func convertValue(val interface{}) interface{} {
182189

183190
// Convert Oracle values back to Go types
184191
func convertFromOracleToField(value interface{}, field *schema.Field) interface{} {
185-
if value == nil {
192+
if value == nil || field == nil {
186193
return nil
187194
}
188195

@@ -194,7 +201,6 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
194201

195202
var converted interface{}
196203

197-
// Handle special types first using type-safe comparisons
198204
switch targetType {
199205
case reflect.TypeOf(gorm.DeletedAt{}):
200206
if nullTime, ok := value.(sql.NullTime); ok {
@@ -203,7 +209,31 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
203209
converted = gorm.DeletedAt{}
204210
}
205211
case reflect.TypeOf(time.Time{}):
206-
converted = value
212+
switch vv := value.(type) {
213+
case time.Time:
214+
converted = vv
215+
case sql.NullTime:
216+
if vv.Valid {
217+
converted = vv.Time
218+
} else {
219+
// DB returned NULL
220+
if isPtr {
221+
return nil // -> *time.Time(nil)
222+
}
223+
// non-pointer time.Time: represent NULL as zero time
224+
return time.Time{}
225+
}
226+
default:
227+
converted = value
228+
}
229+
230+
case reflect.TypeOf(sql.NullTime{}):
231+
if nullTime, ok := value.(sql.NullTime); ok {
232+
converted = nullTime
233+
} else {
234+
converted = sql.NullTime{}
235+
}
236+
207237
case reflect.TypeOf(sql.NullInt64{}):
208238
if nullInt, ok := value.(sql.NullInt64); ok {
209239
converted = nullInt
@@ -228,48 +258,24 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
228258
} else {
229259
converted = sql.NullBool{}
230260
}
231-
case reflect.TypeOf(sql.NullTime{}):
232-
if nullTime, ok := value.(sql.NullTime); ok {
233-
converted = nullTime
234-
} else {
235-
converted = sql.NullTime{}
236-
}
237261
default:
238-
// Handle primitive types
262+
// primitives and everything else
239263
converted = convertPrimitiveType(value, targetType)
240264
}
241265

242-
// Handle pointer types
243-
if isPtr && converted != nil {
244-
if isZeroValueForPointer(converted, targetType) {
266+
// Pointer targets: nil for "zero-ish", else allocate and set.
267+
if isPtr {
268+
if isZeroFor(targetType, converted) {
245269
return nil
246270
}
247271
ptr := reflect.New(targetType)
248272
ptr.Elem().Set(reflect.ValueOf(converted))
249-
converted = ptr.Interface()
273+
return ptr.Interface()
250274
}
251275

252276
return converted
253277
}
254278

255-
// Helper function to check if a value should be treated as nil for pointer fields
256-
func isZeroValueForPointer(value interface{}, targetType reflect.Type) bool {
257-
v := reflect.ValueOf(value)
258-
if !v.IsValid() || v.Kind() != targetType.Kind() {
259-
return false
260-
}
261-
262-
switch targetType.Kind() {
263-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
264-
return v.Int() == 0
265-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
266-
return v.Uint() == 0
267-
case reflect.Float32, reflect.Float64:
268-
return v.Float() == 0.0
269-
}
270-
return false
271-
}
272-
273279
// Helper function to handle primitive type conversions
274280
func convertPrimitiveType(value interface{}, targetType reflect.Type) interface{} {
275281
switch targetType.Kind() {
@@ -442,3 +448,28 @@ func isNullValue(value interface{}) bool {
442448
return false
443449
}
444450
}
451+
452+
func isZeroFor(t reflect.Type, v interface{}) bool {
453+
if v == nil {
454+
return true
455+
}
456+
rv := reflect.ValueOf(v)
457+
if !rv.IsValid() {
458+
return true
459+
}
460+
// exact type match?
461+
if rv.Type() == t {
462+
// special-case time.Time
463+
if t == reflect.TypeOf(time.Time{}) {
464+
return rv.Interface().(time.Time).IsZero()
465+
}
466+
// generic zero check
467+
z := reflect.Zero(t)
468+
return reflect.DeepEqual(rv.Interface(), z.Interface())
469+
}
470+
// If types differ (e.g., sql.NullTime), treat invalid as zero
471+
if nt, ok := v.(sql.NullTime); ok {
472+
return !nt.Valid
473+
}
474+
return false
475+
}

oracle/create.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
513513
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
514514
for _, column := range allColumns {
515515
if field := findFieldByDBName(schema, column); field != nil {
516-
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
516+
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
517517
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
518518
writeQuotedIdentifier(&plsqlBuilder, column)
519519
plsqlBuilder.WriteString("; END IF;\n")
@@ -625,7 +625,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
625625
quotedColumn := columnBuilder.String()
626626

627627
if field := findFieldByDBName(schema, column); field != nil {
628-
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
628+
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
629629
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
630630
rowIdx, outParamIndex+1, rowIdx+1, quotedColumn))
631631
outParamIndex++

oracle/delete.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
278278
for _, column := range allColumns {
279279
field := findFieldByDBName(schema, column)
280280
if field != nil {
281-
dest := createTypedDestination(field.FieldType)
281+
dest := createTypedDestination(field)
282282
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
283283

284284
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))

oracle/update.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
544544
for _, column := range allColumns {
545545
field := findFieldByDBName(schema, column)
546546
if field != nil {
547-
dest := createTypedDestination(field.FieldType)
547+
dest := createTypedDestination(field)
548548
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
549549
}
550550
}

tests/associations_many2many_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,6 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
370370
}
371371

372372
func TestDuplicateMany2ManyAssociation(t *testing.T) {
373-
t.Skip()
374373
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
375374
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
376375
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
@@ -434,7 +433,6 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) {
434433
}
435434

436435
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
437-
t.Skip()
438436
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
439437
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
440438
ID: 1,

tests/generics_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,6 @@ func TestGenericsJoins(t *testing.T) {
422422
}
423423

424424
func TestGenericsNestedJoins(t *testing.T) {
425-
t.Skip()
426425
users := []User{
427426
{
428427
Name: "generics-nested-joins-1",

tests/joins_test.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func TestJoinWithSoftDeleted(t *testing.T) {
259259

260260
var user1 User
261261
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID)
262-
if user1.NamedPet == nil || user1.Account.ID == 0 {
262+
if user1.NamedPet.ID == 0 || user1.Account.ID == 0 {
263263
t.Fatalf("joins NamedPet and Account should not empty:%v", user1)
264264
}
265265

@@ -268,17 +268,17 @@ func TestJoinWithSoftDeleted(t *testing.T) {
268268

269269
var user2 User
270270
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID)
271-
if user2.NamedPet == nil || user2.Account.ID != 0 {
272-
t.Fatalf("joins Account should not empty:%v", user2)
271+
if user2.NamedPet.ID == 0 || user2.Account.ID != 0 {
272+
t.Fatalf("joins Account should be empty:%v", user2)
273273
}
274274

275275
// NamedPet should empty
276276
DB.Delete(&user1.NamedPet)
277277

278278
var user3 User
279279
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID)
280-
if user3.NamedPet != nil || user2.Account.ID != 0 {
281-
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
280+
if user3.NamedPet.ID != 0 || user3.Account.ID != 0 {
281+
t.Fatalf("joins NamedPet and Account should be empty:%v", user3)
282282
}
283283
}
284284

@@ -383,8 +383,6 @@ func TestJoinArgsWithDB(t *testing.T) {
383383
}
384384

385385
func TestNestedJoins(t *testing.T) {
386-
t.Skip()
387-
388386
users := []User{
389387
{
390388
Name: "nested-joins-1",
@@ -424,7 +422,7 @@ func TestNestedJoins(t *testing.T) {
424422
Joins("Manager.NamedPet.Toy").
425423
Joins("NamedPet").
426424
Joins("NamedPet.Toy").
427-
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
425+
Find(&users2, "\"users\".\"id\" IN ?", userIDs).Error; err != nil {
428426
t.Fatalf("Failed to load with joins, got error: %v", err)
429427
} else if len(users2) != len(users) {
430428
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))

0 commit comments

Comments
 (0)