Skip to content
6 changes: 6 additions & 0 deletions pkg/generate/ack/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ package ack

import (
"path/filepath"
"sort"
"strings"
ttpl "text/template"

ackgenconfig "github.com/aws-controllers-k8s/code-generator/pkg/config"
"github.com/aws-controllers-k8s/code-generator/pkg/fieldpath"
"github.com/aws-controllers-k8s/code-generator/pkg/generate/code"
"github.com/aws-controllers-k8s/code-generator/pkg/generate/templateset"
"github.com/aws-controllers-k8s/code-generator/pkg/model"
Expand Down Expand Up @@ -60,6 +62,9 @@ var (
"ResourceExceptionCode": func(r *ackmodel.CRD, httpStatusCode int) string {
return r.ExceptionCode(httpStatusCode)
},
"ConstructFieldPath": func(targetFieldPath string) *fieldpath.Path {
return fieldpath.FromString(targetFieldPath)
},
"GoCodeSetExceptionMessageCheck": func(r *ackmodel.CRD, httpStatusCode int) string {
return code.CheckExceptionMessage(r.Config(), r, httpStatusCode)
},
Expand Down Expand Up @@ -259,6 +264,7 @@ func Controller(
for serviceName := range referencedServiceNamesMap {
referencedServiceNames = append(referencedServiceNames, serviceName)
}
sort.Strings(referencedServiceNames)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume this is to ensure determinism?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was seeing the imports and schemas jumping around in my git diff

cmdVars := &templateCmdVars{
metaVars,
snakeCasedCRDNames,
Expand Down
115 changes: 90 additions & 25 deletions pkg/generate/code/resource_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,26 @@ import (
)

// ReferenceFieldsValidation produces the go code to validate reference field and
// corresponding identifier field.
// corresponding identifier field. Iterates through all references within
// slices, if necessary.
// for _, iter0 := range ko.Spec.Routes {
// if iter0.GatewayRef != nil && iter0.GatewayID != nil {
// return ackerr.ResourceReferenceAndIDNotSupportedFor("Routes.GatewayID", "Routes.GatewayRef")
// }
// }
// Sample code:
// if ko.Spec.APIRef != nil && ko.Spec.APIID != nil {
// return ackerr.ResourceReferenceAndIDNotSupportedFor("APIID", "APIRef")
// }
// if ko.Spec.APIRef == nil && ko.Spec.APIID == nil {
// return ackerr.ResourceReferenceOrIDRequiredFor("APIID", "APIRef")
// }
// return ackerr.ResourceReferenceAndIDNotSupportedFor("APIID", "APIRef")
// }
// if ko.Spec.APIRef == nil && ko.Spec.APIID == nil {
// return ackerr.ResourceReferenceOrIDRequiredFor("APIID", "APIRef")
// }
func ReferenceFieldsValidation(
crd *model.CRD,
sourceVarName string,
indentLevel int,
) string {
out := ""
fieldAccessPrefix := fmt.Sprintf("%s%s", sourceVarName,
crd.Config().PrefixConfig.SpecField)
// Sorted fieldnames are used for consistent code-generation
for _, fieldName := range crd.SortedFieldNames() {
field := crd.Fields[fieldName]
Expand All @@ -47,20 +51,48 @@ func ReferenceFieldsValidation(
fp := fieldpath.FromString(field.Path)
// remove fieldName from fieldPath before adding nil checks
fp.Pop()
fieldNamePrefix := crd.Config().PrefixConfig.SpecField

// prefix of the field path for referencing in the model
fieldNamePrefix := ""
// prefix of the field path for the generated code
pathVarPrefix := fmt.Sprintf("%s%s", sourceVarName, crd.Config().PrefixConfig.SpecField)

// this loop outputs a nil-guard for each level of nested field path
// or an iterator for any level that is a slice
fieldDepth := 0
for fp.Size() > 0 {
fIndent = strings.Repeat("\t", fIndentLevel)
fieldNamePrefix = fmt.Sprintf("%s.%s", fieldNamePrefix, fp.PopFront())
out += fmt.Sprintf("%sif %s%s != nil {\n", fIndent, sourceVarName, fieldNamePrefix)
currentField := fp.PopFront()

if fieldNamePrefix == "" {
fieldNamePrefix = currentField
} else {
fieldNamePrefix = fmt.Sprintf("%s.%s", fieldNamePrefix, currentField)
}
pathVarPrefix = fmt.Sprintf("%s.%s", pathVarPrefix, currentField)

fieldConfig, ok := crd.Fields[fieldNamePrefix]
if !ok {
panic(fmt.Sprintf("CRD %s has no Field with path %s", crd.Kind, fieldNamePrefix))
}

if fieldConfig.ShapeRef.Shape.Type == "list" {
out += fmt.Sprintf("%sfor _, iter%d := range %s {\n", fIndent, fieldDepth, pathVarPrefix)
// reset the path variable name
pathVarPrefix = fmt.Sprintf("iter%d", fieldDepth)
} else {
out += fmt.Sprintf("%sif %s != nil {\n", fIndent, pathVarPrefix)
}

fIndentLevel++
fieldDepth++
}

fIndent = strings.Repeat("\t", fIndentLevel)
// Validation to make sure both target field and reference are
// not present at the same time in desired resource
out += fmt.Sprintf("%sif %s.%s != nil"+
" && %s.%s != nil {\n", fIndent, fieldAccessPrefix,
field.ReferenceFieldPath(), fieldAccessPrefix, field.Path)
" && %s.%s != nil {\n", fIndent, pathVarPrefix, field.GetReferenceFieldName().Camel, pathVarPrefix, field.Names.Camel)
out += fmt.Sprintf("%s\treturn "+
"ackerr.ResourceReferenceAndIDNotSupportedFor(\"%s\", \"%s\")\n",
fIndent, field.Path, field.ReferenceFieldPath())
Expand All @@ -78,8 +110,8 @@ func ReferenceFieldsValidation(
// field is present in the resource
if field.IsRequired() {
out += fmt.Sprintf("%sif %s.%s == nil &&"+
" %s.%s == nil {\n", fIndent, fieldAccessPrefix,
field.ReferenceFieldPath(), fieldAccessPrefix, field.Path)
" %s.%s == nil {\n", fIndent, pathVarPrefix,
field.ReferenceFieldPath(), pathVarPrefix, field.Path)
out += fmt.Sprintf("%s\treturn "+
"ackerr.ResourceReferenceOrIDRequiredFor(\"%s\", \"%s\")\n",
fIndent, field.Path, field.ReferenceFieldPath())
Expand All @@ -94,32 +126,65 @@ func ReferenceFieldsValidation(
// a non-nil reference field is present in a resource. This checks helps in deciding
// whether ACK.ReferencesResolved condition should be added to resource status
// Sample Code:
// if ko.Spec.Routes != nil {
// for _, iter35 := range ko.Spec.Routes {
// if iter35.GatewayRef != nil {
// return true
// }
// }
// }
// return false || (ko.Spec.APIRef != nil)
func ReferenceFieldsPresent(
crd *model.CRD,
sourceVarName string,
) string {
out := "false"
iteratorsOut := ""
returnOut := "return false"
fieldAccessPrefix := fmt.Sprintf("%s%s", sourceVarName,
crd.Config().PrefixConfig.SpecField)
// Sorted fieldnames are used for consistent code-generation
for _, fieldName := range crd.SortedFieldNames() {
for fieldIndex, fieldName := range crd.SortedFieldNames() {
field := crd.Fields[fieldName]
if field.HasReference() {
out += " || ("
fp := fieldpath.FromString(field.Path)
// remove fieldName from fieldPath before adding nil checks
// for nested fieldPath
fp.Pop()
fieldNamePrefix := ""
for fp.Size() > 0 {
fieldNamePrefix = fmt.Sprintf("%s.%s", fieldNamePrefix, fp.PopFront())
out += fmt.Sprintf("%s%s != nil && ", fieldAccessPrefix, fieldNamePrefix)

// Determine whether the field is nested
if fp.Size() > 0 {
// Determine whether the field is inside a slice
parentField, ok := crd.Fields[fp.String()]
if !ok {
panic(fmt.Sprintf("CRD %s has no Field with path %s", crd.Kind, fp.String()))
}

if parentField.ShapeRef.Shape.Type == "list" {
iteratorsOut += fmt.Sprintf("if %s {\n", nestedStructNilCheck(*fp.Copy(), fieldAccessPrefix))
iteratorsOut += fmt.Sprintf("\tfor _, iter%d := range %s.%s {\n", fieldIndex, fieldAccessPrefix, parentField.Path)
iteratorsOut += fmt.Sprintf("\t\tif iter%d.%s != nil {\n", fieldIndex, field.GetReferenceFieldName().Camel)
iteratorsOut += fmt.Sprintf("\t\t\treturn true\n")
iteratorsOut += fmt.Sprintf("\t\t}\n")
iteratorsOut += fmt.Sprintf("\t}\n")
iteratorsOut += fmt.Sprintf("}\n")
continue
}
}
out += fmt.Sprintf("%s.%s != nil", fieldAccessPrefix,

nilCheck := nestedStructNilCheck(*fp.Copy(), fieldAccessPrefix) + " && " + fmt.Sprintf("%s.%s != nil", fieldAccessPrefix,
field.ReferenceFieldPath())
out += ")"
returnOut += " || (" + strings.TrimPrefix(nilCheck, " && ") + ")"
}
}
return out
return iteratorsOut + returnOut
}

func nestedStructNilCheck(path fieldpath.Path, fieldAccessPrefix string) string {
out := ""
fieldNamePrefix := ""
for path.Size() > 0 {
fieldNamePrefix = fmt.Sprintf("%s.%s", fieldNamePrefix, path.PopFront())
out += fmt.Sprintf("%s%s != nil && ", fieldAccessPrefix, fieldNamePrefix)
}
return strings.TrimSuffix(out, " && ")
}
38 changes: 34 additions & 4 deletions pkg/generate/code/resource_reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func Test_ReferenceFieldsPresent_NoReferenceConfig(t *testing.T) {

crd := testutil.GetCRDByName(t, g, "Api")
require.NotNil(crd)
expected := "false"
expected := "return false"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I left out the "return" part to make these conditionals reusable in any "if" block... However now i can see that it's not possible to support all the types with that limitation.

No objections to the change but i thought of commenting the initial intention of not adding the "return" word.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I was wondering why that was left out. Hopefully if anything else needs that logic, it can just use the method directly, itself. Nice forward thinking, though

assert.Equal(expected, code.ReferenceFieldsPresent(crd, "ko"))
}

Expand All @@ -133,7 +133,7 @@ func Test_ReferenceFieldsPresent_SingleReference(t *testing.T) {

crd := testutil.GetCRDByName(t, g, "Integration")
require.NotNil(crd)
expected := "false || (ko.Spec.APIRef != nil)"
expected := "return false || (ko.Spec.APIRef != nil)"
assert.Equal(expected, code.ReferenceFieldsPresent(crd, "ko"))
}

Expand All @@ -150,7 +150,7 @@ func Test_ReferenceFieldsPresent_SliceOfReferences(t *testing.T) {
// just to test code generation for slices of reference
crd := testutil.GetCRDByName(t, g, "VpcLink")
require.NotNil(crd)
expected := "false || (ko.Spec.SecurityGroupRefs != nil) || (ko.Spec.SubnetRefs != nil)"
expected := "return false || (ko.Spec.SecurityGroupRefs != nil) || (ko.Spec.SubnetRefs != nil)"
assert.Equal(expected, code.ReferenceFieldsPresent(crd, "ko"))
}

Expand All @@ -165,6 +165,36 @@ func Test_ReferenceFieldsPresent_NestedReference(t *testing.T) {

crd := testutil.GetCRDByName(t, g, "Authorizer")
require.NotNil(crd)
expected := "false || (ko.Spec.JWTConfiguration != nil && ko.Spec.JWTConfiguration.IssuerRef != nil)"
expected := "return false || (ko.Spec.JWTConfiguration != nil && ko.Spec.JWTConfiguration.IssuerRef != nil)"
assert.Equal(expected, code.ReferenceFieldsPresent(crd, "ko"))
}

func Test_ReferenceFieldsPresent_NestedSliceOfStructsReference(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

g := testutil.NewModelForServiceWithOptions(t, "ec2",
&testutil.TestingModelOptions{
GeneratorConfigFile: "generator-with-nested-reference.yaml",
})

crd := testutil.GetCRDByName(t, g, "RouteTable")
require.NotNil(crd)
expected :=
`if ko.Spec.Routes != nil {
for _, iter32 := range ko.Spec.Routes {
if iter32.GatewayRef != nil {
return true
}
}
}
if ko.Spec.Routes != nil {
for _, iter35 := range ko.Spec.Routes {
if iter35.NATGatewayRef != nil {
return true
}
}
}
return false || (ko.Spec.VPCRef != nil)`
assert.Equal(expected, code.ReferenceFieldsPresent(crd, "ko"))
}
1 change: 1 addition & 0 deletions pkg/model/crd.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ func (r *CRD) ReferencedServiceNames() (serviceNames []string) {
for serviceName, _ := range serviceNamesMap {
serviceNames = append(serviceNames, serviceName)
}
sort.Strings(serviceNames)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return serviceNames
}

Expand Down
13 changes: 11 additions & 2 deletions pkg/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,16 @@ func updateTypeDefAttributeWithReference(fieldPath string, tdefs []*TypeDef, crd
// the beginning of field path and leave rest of nested member names as is.
// Ex: ResourcesVpcConfig.SecurityGroupIDs will become VPCConfigRequest.SecurityGroupIDs
// for Cluster resource in eks-controller.
specFieldShapeName := specField.ShapeRef.ShapeName
specFieldShapeRef := specField.ShapeRef
specFieldShapeName := specFieldShapeRef.ShapeName
switch shapeType := specFieldShapeRef.Shape.Type; shapeType {
case "list":
specFieldShapeName = specField.ShapeRef.Shape.MemberRef.ShapeName
specFieldShapeRef = &specField.ShapeRef.Shape.MemberRef
case "map":
specFieldShapeName = specField.ShapeRef.Shape.ValueRef.ShapeName
specFieldShapeRef = &specField.ShapeRef.Shape.ValueRef
}
fieldShapePath := strings.Replace(fieldPath, specFieldName, specFieldShapeName, 1)
fsp := ackfp.FromString(fieldShapePath)

Expand All @@ -509,7 +518,7 @@ func updateTypeDefAttributeWithReference(fieldPath string, tdefs []*TypeDef, crd
// "fieldName" as attribute. To add a corresponding reference for "fieldName"
// , we will add new attribute in TypeDef for "parentFieldName".
parentFieldName := fsp.Back()
parentFieldShapeRef := fsp.ShapeRef(specField.ShapeRef)
parentFieldShapeRef := fsp.ShapeRef(specFieldShapeRef)
if parentFieldShapeRef == nil {
panic(fmt.Sprintf("Unable to find a shape member with name %s"+
" to add a reference for %s", parentFieldName, fieldPath))
Expand Down
31 changes: 31 additions & 0 deletions pkg/model/model_ec2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
package model_test

import (
"strings"
"testing"

"github.com/aws-controllers-k8s/code-generator/pkg/model"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -152,3 +155,31 @@ func TestEC2_Volume(t *testing.T) {
// field
assert.NotNil(testutil.GetTypeDefByName(t, g, "VolumeAttachment"))
}

func TestEC2_NestedReference(t *testing.T) {
assert := assert.New(t)

g := testutil.NewModelForServiceWithOptions(t, "ec2", &testutil.TestingModelOptions{
GeneratorConfigFile: "generator-with-nested-reference.yaml",
})

tds, err := g.GetTypeDefs()
assert.Nil(err)
assert.NotNil(tds)

var createRouteInputTD *model.TypeDef

for _, td := range tds {
if td != nil && strings.EqualFold(td.Names.Original, "createRouteInput") {
createRouteInputTD = td
break
}
}
assert.NotNil(t, createRouteInputTD)
gatewayIdAttr := createRouteInputTD.GetAttribute("GatewayId")
gatewayRefAttr := createRouteInputTD.GetAttribute("GatewayRef")

assert.Equal("GatewayID", gatewayIdAttr.Names.Camel)
assert.Equal("GatewayRef", gatewayRefAttr.Names.Camel)
assert.Equal("*ackv1alpha1.AWSResourceReferenceWrapper", gatewayRefAttr.GoType)
}
Loading