diff --git a/docs/applications/advanced/external-models.md b/docs/applications/advanced/external-models.md index ea767666d0..b68a781b49 100644 --- a/docs/applications/advanced/external-models.md +++ b/docs/applications/advanced/external-models.md @@ -23,7 +23,7 @@ $ aws s3 cp model.zip s3://your-bucket/model.zip - kind: api name: my-api external_model: - path: s3://your-bucket/model.zip + path: s3://my-bucket/model.zip region: us-west-2 compute: replicas: 5 diff --git a/docs/applications/resources/apis.md b/docs/applications/resources/apis.md index 5c657600f1..febdd85241 100644 --- a/docs/applications/resources/apis.md +++ b/docs/applications/resources/apis.md @@ -9,8 +9,8 @@ Serve models at scale and use them to build smarter applications. name: # API name (required) model: # reference to a Cortex model (this or external_model must be specified) external_model: # (this or model must be specified) - path: # path to a zipped model dir - region: # S3 region (default: us-west-2) + path: # path to a zipped model dir (e.g. s3://my-bucket/model.zip) + region: # S3 region (default: us-west-2) compute: replicas: # number of replicas to launch (default: 1) cpu: # CPU request (default: Null) diff --git a/docs/applications/resources/constants.md b/docs/applications/resources/constants.md index c4a06d788c..d93923ba11 100644 --- a/docs/applications/resources/constants.md +++ b/docs/applications/resources/constants.md @@ -8,7 +8,11 @@ Constants represent literal values which can be used in other Cortex resources. - kind: constant name: # constant name (required) type: # the type of the constant (optional, will be inferred from value if not specified) - value: # a literal value (required) + value: # a literal value (this or external must be specified) + external: # (this or value must be specified) + path: # path to a JSON object (e.g. 
s3://my-bucket/constant.json) + region: # S3 region (default: us-west-2) + ``` See [Data Types](data-types.md) for details about output types and values. diff --git a/docs/applications/resources/environments.md b/docs/applications/resources/environments.md index 4bdcd392b1..39b20f5512 100644 --- a/docs/applications/resources/environments.md +++ b/docs/applications/resources/environments.md @@ -26,7 +26,7 @@ Transfer data at scale from data warehouses like S3 into the Cortex environment. data: type: csv # file type (required) path: s3a:/// # S3 is currently supported (required) - region: us-west-2 # S3 region (default: us-west-2) + region: us-west-2 # S3 region (default: us-west-2) drop_null: # drop any rows that contain at least 1 null value (default: false) csv_config: # optional configuration that can be provided schema: @@ -65,7 +65,7 @@ csv_config: data: type: parquet # file type (required) path: s3a:/// # S3 is currently supported (required) - region: us-west-2 # S3 region (default: us-west-2) + region: us-west-2 # S3 region (default: us-west-2) drop_null: # drop any rows that contain at least 1 null value (default: false) schema: - parquet_column_name: # name of the column in the parquet file (required) diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index 87048dd85c..152e2ab837 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -119,7 +119,7 @@ func (api *API) Validate() error { } if ok, err := aws.IsS3FileExternal(bucket, key, api.ExternalModel.Region); err != nil || !ok { - return errors.Wrap(ErrorExternalModelNotFound(api.ExternalModel.Path), Identify(api), ExternalModelKey, PathKey) + return errors.Wrap(ErrorExternalNotFound(api.ExternalModel.Path), Identify(api), ExternalModelKey, PathKey) } } diff --git a/pkg/operator/api/userconfig/config_key.go b/pkg/operator/api/userconfig/config_key.go index f4ea733b48..df744f2634 100644 --- a/pkg/operator/api/userconfig/config_key.go 
+++ b/pkg/operator/api/userconfig/config_key.go @@ -27,6 +27,7 @@ const ( PathKey = "path" OutputTypeKey = "output_type" TagsKey = "tags" + ExternalKey = "external" // input schema options OptionalOptKey = "_optional" diff --git a/pkg/operator/api/userconfig/constants.go b/pkg/operator/api/userconfig/constants.go index 19ae4138ac..6d51146f43 100644 --- a/pkg/operator/api/userconfig/constants.go +++ b/pkg/operator/api/userconfig/constants.go @@ -17,6 +17,7 @@ limitations under the License. package userconfig import ( + "github.com/cortexlabs/cortex/pkg/lib/aws" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/operator/api/resource" @@ -26,9 +27,10 @@ type Constants []*Constant type Constant struct { ResourceFields - Type OutputSchema `json:"type" yaml:"type"` - Value interface{} `json:"value" yaml:"value"` - Tags Tags `json:"tags" yaml:"tags"` + Type OutputSchema `json:"type" yaml:"type"` + Value interface{} `json:"value" yaml:"value"` + Tags Tags `json:"tags" yaml:"tags"` + External *ExternalConstant `json:"external" yaml:"external"` } var constantValidation = &cr.StructValidation{ @@ -50,14 +52,43 @@ var constantValidation = &cr.StructValidation{ { StructField: "Value", InterfaceValidation: &cr.InterfaceValidation{ - Required: true, + Required: false, }, }, + { + StructField: "External", + StructValidation: externalConstantFieldValidation, + }, tagsFieldValidation, typeFieldValidation, }, } +type ExternalConstant struct { + Path string `json:"path" yaml:"path"` + Region string `json:"region" yaml:"region"` +} + +var externalConstantFieldValidation = &cr.StructValidation{ + DefaultNil: true, + StructFieldValidations: []*cr.StructFieldValidation{ + { + StructField: "Path", + StringValidation: &cr.StringValidation{ + Validator: cr.GetS3PathValidator(), + Required: true, + }, + }, + { + StructField: "Region", + StringValidation: &cr.StringValidation{ + Default: aws.DefaultS3Region, + 
AllowedValues: aws.S3Regions.Slice(), + }, + }, + }, +} + func (constants Constants) Validate() error { for _, constant := range constants { if err := constant.Validate(); err != nil { @@ -79,6 +110,25 @@ func (constants Constants) Validate() error { } func (constant *Constant) Validate() error { + if constant.External == nil && constant.Value == nil { + return errors.Wrap(ErrorSpecifyOnlyOneMissing(ValueKey, ExternalKey), Identify(constant)) + } + + if constant.External != nil && constant.Value != nil { + return errors.Wrap(ErrorSpecifyOnlyOne(ValueKey, ExternalKey), Identify(constant)) + } + + if constant.External != nil { + bucket, key, err := aws.SplitS3Path(constant.External.Path) + if err != nil { + return errors.Wrap(err, Identify(constant), ExternalKey, PathKey) + } + + if ok, err := aws.IsS3FileExternal(bucket, key, constant.External.Region); err != nil || !ok { + return errors.Wrap(ErrorExternalNotFound(constant.External.Path), Identify(constant), ExternalKey, PathKey) + } + } + if constant.Type != nil { castedValue, err := CastOutputValue(constant.Value, constant.Type) if err != nil { diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 87334d0fa0..3c7151e637 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -75,7 +75,7 @@ const ( ErrEnvSchemaMismatch ErrExtraResourcesWithExternalAPIs ErrImplDoesNotExist - ErrExternalModelNotFound + ErrExternalNotFound ) var errorKinds = []string{ @@ -125,10 +125,10 @@ var errorKinds = []string{ "err_env_schema_mismatch", "err_extra_resources_with_external_a_p_is", "err_impl_does_not_exist", - "err_external_model_not_found", + "err_external_not_found", } -var _ = [1]int{}[int(ErrExternalModelNotFound)-(len(errorKinds)-1)] // Ensure list length matches +var _ = [1]int{}[int(ErrExternalNotFound)-(len(errorKinds)-1)] // Ensure list length matches func (t ErrorKind) String() string { return errorKinds[t] @@ -578,9 +578,9 @@ func 
ErrorImplDoesNotExist(path string) error { } } -func ErrorExternalModelNotFound(path string) error { +func ErrorExternalNotFound(path string) error { return Error{ - Kind: ErrExternalModelNotFound, + Kind: ErrExternalNotFound, message: fmt.Sprintf("%s: file not found or inaccessible", path), } } diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index 50f9e94d86..7e8bad3ece 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -441,7 +441,12 @@ def populate_values(self, input, input_schema, preserve_column_refs): if util.is_resource_ref(input): res_name = util.get_resource_ref(input) if res_name in self.constants: - const_val = self.constants[res_name]["value"] + if self.constants[res_name]["value"] is not None: # falsy literals are valid + const_val = self.constants[res_name]["value"] + elif self.constants[res_name]["external"]: + const_val = self.storage.get_json_external( + self.constants[res_name]["external"]["path"] + ) try: return self.populate_values(const_val, input_schema, preserve_column_refs) except CortexException as e: diff --git a/pkg/workloads/lib/storage/s3.py b/pkg/workloads/lib/storage/s3.py index cd8dea6d58..a438984e2f 100644 --- a/pkg/workloads/lib/storage/s3.py +++ b/pkg/workloads/lib/storage/s3.py @@ -114,9 +114,12 @@ def _get_matching_s3_keys_generator(self, prefix="", suffix=""): def _upload_string_to_s3(self, string, key): self.s3.put_object(Bucket=self.bucket, Key=key, Body=string) - def _read_bytes_from_s3(self, key, allow_missing=False): + def _read_bytes_from_s3(self, key, allow_missing=False, ext_bucket=None): try: - byte_array = self.s3.get_object(Bucket=self.bucket, Key=key)["Body"].read() + bucket = self.bucket + if ext_bucket is not None: + bucket = ext_bucket + byte_array = self.s3.get_object(Bucket=bucket, Key=key)["Body"].read() except self.s3.exceptions.NoSuchKey as e: if allow_missing: return None @@ -190,3 +193,13 @@ def download_file_external(self, s3_path, local_path): return local_path except Exception as e: raise 
CortexException("bucket " + bucket, "key " + key) from e + + def get_json_external(self, s3_path): + bucket, key = self.deconstruct_s3_path(s3_path) # outside try: handler uses both + try: + obj = self._read_bytes_from_s3(key, ext_bucket=bucket) + if obj is None: + return None + return json.loads(obj.decode("utf-8")) + except Exception as e: + raise CortexException("bucket " + bucket, "key " + key) from e diff --git a/pkg/workloads/spark_job/test/unit/spark_util_test.py b/pkg/workloads/spark_job/test/unit/spark_util_test.py index 9a8bf11233..38622a0eb5 100644 --- a/pkg/workloads/spark_job/test/unit/spark_util_test.py +++ b/pkg/workloads/spark_job/test/unit/spark_util_test.py @@ -669,7 +669,7 @@ def test_read_parquet_infer_invalid(spark, write_parquet_file, ctx_obj, get_cont }, }, { - "data": [("1", 0.1, "yolo"), ("1", 1.0, "yolo"), ("1", 1.1, "yolo")], + "data": [("1", 0.1, "a"), ("1", 1.0, "a"), ("1", 1.1, "a")], "schema": StructType( [ StructField("a_str", StringType()),