Skip to content

Commit a57114b

Browse files
committed
bugfix: cast to alias in generated defaults in pointer
also unconditionally cast when using references since we dont know their source type otherwise if you have a `string` field, and provide a `type Value string`-valued default, you get an error trying to assign an alias to a string
1 parent 0e3af0b commit a57114b

File tree

2 files changed

+91
-36
lines changed

2 files changed

+91
-36
lines changed

examples/defaulter-gen/_output_tests/marker/zz_generated.go

Lines changed: 18 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/defaulter-gen/generators/defaulter.go

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -442,16 +442,16 @@ func newCallTreeForType(existingDefaulters, newDefaulters defaulterFuncMap) *cal
442442
}
443443
}
444444

445-
func resolveTypeAndDepth(t *types.Type) (*types.Type, int) {
445+
func resolveTypeAndDepth(t *types.Type) (*types.Type, []*types.Type) {
446446
var prev *types.Type
447-
depth := 0
447+
var depth []*types.Type
448448
for prev != t {
449449
prev = t
450450
if t.Kind == types.Alias {
451451
t = t.Underlying
452452
} else if t.Kind == types.Pointer {
453453
t = t.Elem
454-
depth += 1
454+
depth = append(depth, t)
455455
}
456456
}
457457
return t, depth
@@ -526,15 +526,16 @@ func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLin
526526
var defaultString string
527527
if len(defaultMap) == 1 {
528528
defaultString = defaultMap[0]
529+
} else if len(defaultMap) > 1 {
530+
klog.Fatalf("Found more than one default tag for %v", t.Kind)
529531
}
530532

531-
t, depth := resolveTypeAndDepth(t)
532-
if depth > 0 && defaultString == "" {
533+
baseT, depth := resolveTypeAndDepth(t)
534+
if len(depth) > 0 && defaultString == "" {
533535
defaultString = getNestedDefault(t)
534536
}
535-
if len(defaultMap) > 1 {
536-
klog.Fatalf("Found more than one default tag for %v", t.Kind)
537-
} else if len(defaultMap) == 0 {
537+
538+
if len(defaultString) == 0 {
538539
return node
539540
}
540541
var symbolReference types.Name
@@ -547,7 +548,7 @@ func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLin
547548
}
548549

549550
omitEmpty := strings.Contains(reflect.StructTag(tags).Get("json"), "omitempty")
550-
if enforced, err := mustEnforceDefault(t, depth, omitEmpty); err != nil {
551+
if enforced, err := mustEnforceDefault(baseT, len(depth), omitEmpty); err != nil {
551552
klog.Fatal(err)
552553
} else if enforced != nil {
553554
if defaultValue != nil {
@@ -568,8 +569,9 @@ func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLin
568569
node.markerOnly = true
569570
}
570571

571-
node.defaultIsPrimitive = t.IsPrimitive()
572-
node.defaultType = t.String()
572+
node.defaultIsPrimitive = baseT.IsPrimitive()
573+
node.defaultType = baseT.String()
574+
node.defaultTopLevelType = t
573575
node.defaultValue.InlineConstant = defaultString
574576
node.defaultValue.SymbolReference = symbolReference
575577
node.defaultDepth = depth
@@ -893,11 +895,15 @@ type callNode struct {
893895
// +default="foo"
894896
// Field *string
895897
// }
896-
defaultDepth int
898+
defaultDepth []*types.Type
897899

898900
// defaultType is the type of the default value.
899901
// Only populated if defaultIsPrimitive is true
900902
defaultType string
903+
904+
// defaultTopLevelType is the final type the value should resolve to
905+
// This is in constrast with default type, which resolves aliases and pointers.
906+
defaultTopLevelType *types.Type
901907
}
902908

903909
type defaultValue struct {
@@ -996,8 +1002,9 @@ func (n *callNode) writeDefaulter(varName string, index string, isVarPointer boo
9961002
"defaultValue": n.defaultValue.Resolved(),
9971003
"varName": varName,
9981004
"index": index,
999-
"varDepth": n.defaultDepth,
1005+
"varDepth": len(n.defaultDepth),
10001006
"varType": n.defaultType,
1007+
"varTopType": n.defaultTopLevelType,
10011008
}
10021009

10031010
variablePlaceholder := ""
@@ -1021,15 +1028,54 @@ func (n *callNode) writeDefaulter(varName string, index string, isVarPointer boo
10211028
if n.defaultIsPrimitive {
10221029
// If the default value is a primitive when the assigned type is a pointer
10231030
// keep using the address-of operator on the primitive value until the types match
1024-
if n.defaultDepth > 0 {
1025-
sw.Do(fmt.Sprintf("if %s == nil {\n", variablePlaceholder), args)
1026-
sw.Do("var ptrVar$.varDepth$ $.varType$ = $.defaultValue$\n", args)
1027-
// We iterate until a depth of 1 instead of 0 because the following line
1028-
// `if $.varName$ == &ptrVar1` accounts for 1 level already
1029-
for i := n.defaultDepth; i > 1; i-- {
1030-
sw.Do("ptrVar$.ptri$ := &ptrVar$.i$\n", generator.Args{"i": fmt.Sprintf("%d", i), "ptri": fmt.Sprintf("%d", (i - 1))})
1031+
if len(n.defaultDepth) > 0 {
1032+
// If the destination is a pointer, the last element in
1033+
// defaultDepth is the element type of the bottommost pointer:
1034+
// the base type of our default value.
1035+
destElemType := n.defaultDepth[len(n.defaultDepth)-1]
1036+
pointerArgs := args.With("baseElemType", destElemType)
1037+
1038+
sw.Do(fmt.Sprintf("if %s == nil {\n", variablePlaceholder), pointerArgs)
1039+
if len(n.defaultValue.InlineConstant) > 0 {
1040+
// If default value is a literal then it can be assigned via var stmt
1041+
sw.Do("var ptrVar$.varDepth$ $.baseElemType|raw$ = $.defaultValue$\n", pointerArgs)
1042+
} else {
1043+
// If default value is not a literal then it may need to be casted
1044+
// to the base type of the destination pointer
1045+
sw.Do("ptrVar$.varDepth$ := $.baseElemType|raw$($.defaultValue$)\n", pointerArgs)
1046+
}
1047+
1048+
for i := len(n.defaultDepth); i >= 1; i-- {
1049+
dest := fmt.Sprintf("ptrVar%d", i-1)
1050+
assignment := ":="
1051+
if i == 1 {
1052+
// Last assignment is into the storage destination
1053+
dest = variablePlaceholder
1054+
assignment = "="
1055+
}
1056+
1057+
sourceType := "*" + destElemType.String()
1058+
if i == len(n.defaultDepth) {
1059+
// Initial value is not a pointer
1060+
sourceType = destElemType.String()
1061+
}
1062+
destElemType = n.defaultDepth[i-1]
1063+
1064+
// Cannot include `dest` into args since its value may be
1065+
// `variablePlaceholder` which is a template, not a value
1066+
elementArgs := pointerArgs.WithArgs(generator.Args{
1067+
"assignment": assignment,
1068+
"source": fmt.Sprintf("ptrVar%d", i),
1069+
"destElemType": destElemType,
1070+
})
1071+
1072+
// Skip cast if type is exact match
1073+
if destElemType.String() == sourceType {
1074+
sw.Do(fmt.Sprintf("%v $.assignment$ &$.source$\n", dest), elementArgs)
1075+
} else {
1076+
sw.Do(fmt.Sprintf("%v $.assignment$ (*$.destElemType|raw$)(&$.source$)\n", dest), elementArgs)
1077+
}
10311078
}
1032-
sw.Do(fmt.Sprintf("%s = &ptrVar1", variablePlaceholder), args)
10331079
} else {
10341080
// For primitive types, nil checks cannot be used and the zero value must be determined
10351081
defaultZero, err := getTypeZeroValue(n.defaultType)
@@ -1039,7 +1085,12 @@ func (n *callNode) writeDefaulter(varName string, index string, isVarPointer boo
10391085
args["defaultZero"] = defaultZero
10401086

10411087
sw.Do(fmt.Sprintf("if %s == $.defaultZero$ {\n", variablePlaceholder), args)
1042-
sw.Do(fmt.Sprintf("%s = $.defaultValue$", variablePlaceholder), args)
1088+
1089+
if len(n.defaultValue.InlineConstant) > 0 {
1090+
sw.Do(fmt.Sprintf("%s = $.defaultValue$", variablePlaceholder), args)
1091+
} else {
1092+
sw.Do(fmt.Sprintf("%s = $.varTopType|raw$($.defaultValue$)", variablePlaceholder), args)
1093+
}
10431094
}
10441095
} else {
10451096
sw.Do(fmt.Sprintf("if %s == nil {\n", variablePlaceholder), args)

0 commit comments

Comments
 (0)