Skip to content

Remove tf_serving config, move tf_signature_key #471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/deployments/apis.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Serve models at scale.
model: <string> # path to an exported model (e.g. s3://my-bucket/exported_model)
model_format: <string> # model format, must be "tensorflow" or "onnx" (default: "onnx" if model path ends with .onnx, "tensorflow" if model path ends with .zip or is a directory)
request_handler: <string> # path to the request handler implementation file, relative to the cortex root
tf_signature_key: <string> # name of the signature def to use for prediction (required if your model has more than one signature def)
tracker:
key: <string> # json key to track if the response payload is a dictionary
model_type: <string> # model type, must be "classification" or "regression"
Expand Down
39 changes: 15 additions & 24 deletions pkg/operator/api/userconfig/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,20 @@ type APIs []*API

type API struct {
ResourceFields
Model string `json:"model" yaml:"model"`
ModelFormat ModelFormat `json:"model_format" yaml:"model_format"`
Tracker *Tracker `json:"tracker" yaml:"tracker"`
RequestHandler *string `json:"request_handler" yaml:"request_handler"`
TFServing *TFServingOptions `json:"tf_serving" yaml:"tf_serving"`
Compute *APICompute `json:"compute" yaml:"compute"`
Tags Tags `json:"tags" yaml:"tags"`
Model string `json:"model" yaml:"model"`
ModelFormat ModelFormat `json:"model_format" yaml:"model_format"`
Tracker *Tracker `json:"tracker" yaml:"tracker"`
RequestHandler *string `json:"request_handler" yaml:"request_handler"`
TFSignatureKey *string `json:"tf_signature_key" yaml:"tf_signature_key"`
Compute *APICompute `json:"compute" yaml:"compute"`
Tags Tags `json:"tags" yaml:"tags"`
}

type Tracker struct {
Key *string `json:"key" yaml:"key"`
ModelType ModelType `json:"model_type" yaml:"model_type"`
}

type TFServingOptions struct {
SignatureKey string `json:"signature_key" yaml:"signature_key"`
}

var apiValidation = &cr.StructValidation{
StructFieldValidations: []*cr.StructFieldValidation{
{
Expand Down Expand Up @@ -107,17 +103,9 @@ var apiValidation = &cr.StructValidation{
},
},
{
StructField: "TFServing",
StructValidation: &cr.StructValidation{
DefaultNil: true,
StructFieldValidations: []*cr.StructFieldValidation{
{
StructField: "SignatureKey",
StringValidation: &cr.StringValidation{
Required: true,
},
},
},
StructField: "TFSignatureKey",
StringPtrValidation: &cr.StringPtrValidation{
Required: false,
},
},
apiComputeFieldValidation,
Expand Down Expand Up @@ -198,6 +186,9 @@ func (api *API) UserConfigStr() string {
if api.RequestHandler != nil {
sb.WriteString(fmt.Sprintf("%s: %s\n", RequestHandlerKey, *api.RequestHandler))
}
if api.TFSignatureKey != nil {
sb.WriteString(fmt.Sprintf("%s: %s\n", TFSignatureKeyKey, *api.TFSignatureKey))
}
if api.Compute != nil {
sb.WriteString(fmt.Sprintf("%s:\n", ComputeKey))
sb.WriteString(s.Indent(api.Compute.UserConfigStr(), " "))
Expand Down Expand Up @@ -276,8 +267,8 @@ func (api *API) Validate(projectFileMap map[string][]byte) error {
}
}

if api.ModelFormat != TensorFlowModelFormat && api.TFServing != nil {
return errors.Wrap(ErrorTFServingOptionsForTFOnly(api.ModelFormat), Identify(api))
if api.ModelFormat != TensorFlowModelFormat && api.TFSignatureKey != nil {
return errors.Wrap(ErrorIncompatibleWithModelFormat(TFSignatureKeyKey, api.ModelFormat), Identify(api))
}

if api.RequestHandler != nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/operator/api/userconfig/config_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (
ModelKey = "model"
ModelFormatKey = "model_format"
RequestHandlerKey = "request_handler"
TFSignatureKeyKey = "tf_signature_key"

// compute
ComputeKey = "compute"
Expand Down
10 changes: 5 additions & 5 deletions pkg/operator/api/userconfig/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const (
ErrExternalNotFound
ErrONNXDoesntSupportZip
ErrInvalidTensorflowDir
ErrTFServingOptionsForTFOnly
ErrIncompatibleWithModelFormat
)

var errorKinds = []string{
Expand Down Expand Up @@ -76,7 +76,7 @@ var errorKinds = []string{
"err_tf_serving_options_for_tf_only",
}

var _ = [1]int{}[int(ErrTFServingOptionsForTFOnly)-(len(errorKinds)-1)] // Ensure list length matches
var _ = [1]int{}[int(ErrIncompatibleWithModelFormat)-(len(errorKinds)-1)] // Ensure list length matches

func (t ErrorKind) String() string {
return errorKinds[t]
Expand Down Expand Up @@ -312,9 +312,9 @@ func ErrorInvalidTensorflowDir(path string) error {
}
}

func ErrorTFServingOptionsForTFOnly(format ModelFormat) error {
func ErrorIncompatibleWithModelFormat(configKey string, format ModelFormat) error {
return Error{
Kind: ErrTFServingOptionsForTFOnly,
message: fmt.Sprintf("TensorFlow serving options were provided but the model format is %s", format.String()),
Kind: ErrIncompatibleWithModelFormat,
message: fmt.Sprintf("\"%s\" was specified, but is not supported by the %s model format", configKey, format.String()),
}
}
2 changes: 1 addition & 1 deletion pkg/operator/context/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func getAPIs(config *userconfig.Config, deploymentVersion string, projectID stri
buf.WriteString(apiConfig.ModelFormat.String())
buf.WriteString(deploymentVersion)
buf.WriteString(strings.TrimSuffix(apiConfig.Model, "/"))
buf.WriteString(s.Obj(apiConfig.TFServing))
buf.WriteString(s.Obj(apiConfig.TFSignatureKey))

if apiConfig.RequestHandler != nil {
buf.WriteString(projectID)
Expand Down
10 changes: 3 additions & 7 deletions pkg/workloads/cortex/tf_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,11 @@ def start(args):

time.sleep(5)

signature_key = None
if api.get("tf_serving") is not None and api["tf_serving"].get("signature_key") is not None:
signature_key = api["tf_serving"]["signature_key"]

key, parsed_signature = extract_signature(
local_cache["model_metadata"]["signatureDef"], signature_key
signature_key, parsed_signature = extract_signature(
local_cache["model_metadata"]["signatureDef"], api["tf_signature_key"]
)

local_cache["signature_key"] = key
local_cache["signature_key"] = signature_key
local_cache["parsed_signature"] = parsed_signature
logger.info("model_signature: {}".format(local_cache["parsed_signature"]))
serve(app, listen="*:{}".format(args.port))
Expand Down