diff --git a/Makefile b/Makefile index aae1d9fe94..bb3b68449d 100644 --- a/Makefile +++ b/Makefile @@ -154,6 +154,7 @@ ci-build-images: @./build/build-image.sh images/istio-galley istio-galley @./build/build-image.sh images/istio-pilot istio-pilot @./build/build-image.sh images/istio-proxy istio-proxy + @./build/build-image.sh images/downloader downloader ci-push-images: @./build/push-image.sh manager @@ -172,6 +173,7 @@ ci-push-images: @./build/push-image.sh istio-galley @./build/push-image.sh istio-pilot @./build/push-image.sh istio-proxy + @./build/push-image.sh downloader ci-build-cli: diff --git a/cortex.sh b/cortex.sh index c5cfa439a2..a883e8a270 100755 --- a/cortex.sh +++ b/cortex.sh @@ -161,6 +161,7 @@ export CORTEX_IMAGE_ISTIO_CITADEL="${CORTEX_IMAGE_ISTIO_CITADEL:-cortexlabs/isti export CORTEX_IMAGE_ISTIO_GALLEY="${CORTEX_IMAGE_ISTIO_GALLEY:-cortexlabs/istio-galley:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_ISTIO_PILOT="${CORTEX_IMAGE_ISTIO_PILOT:-cortexlabs/istio-pilot:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_ISTIO_PROXY="${CORTEX_IMAGE_ISTIO_PROXY:-cortexlabs/istio-proxy:$CORTEX_VERSION_STABLE}" +export CORTEX_IMAGE_DOWNLOADER="${CORTEX_IMAGE_DOWNLOADER:-cortexlabs/downloader:$CORTEX_VERSION_STABLE}" export CORTEX_ENABLE_TELEMETRY="${CORTEX_ENABLE_TELEMETRY:-""}" export CORTEX_TELEMETRY_URL="${CORTEX_TELEMETRY_URL:-"https://telemetry.cortexlabs.dev"}" @@ -220,6 +221,7 @@ function install_cortex() { -e CORTEX_IMAGE_ISTIO_GALLEY=$CORTEX_IMAGE_ISTIO_GALLEY \ -e CORTEX_IMAGE_ISTIO_PILOT=$CORTEX_IMAGE_ISTIO_PILOT \ -e CORTEX_IMAGE_ISTIO_PROXY=$CORTEX_IMAGE_ISTIO_PROXY \ + -e CORTEX_IMAGE_DOWNLOADER=$CORTEX_IMAGE_DOWNLOADER \ -e CORTEX_ENABLE_TELEMETRY=$CORTEX_ENABLE_TELEMETRY \ $CORTEX_IMAGE_MANAGER } diff --git a/dev/registry.sh b/dev/registry.sh index 3d3d85a1c2..c83509a58c 100755 --- a/dev/registry.sh +++ b/dev/registry.sh @@ -51,6 +51,7 @@ function create_registry() { aws ecr create-repository --repository-name=cortexlabs/cluster-autoscaler 
--region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/nvidia --region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/metrics-server --region=$REGISTRY_REGION || true + aws ecr create-repository --repository-name=cortexlabs/downloader --region=$REGISTRY_REGION || true } ### HELPERS ### @@ -139,6 +140,7 @@ elif [ "$cmd" = "update" ]; then fi build_and_push $ROOT/images/tf-api tf-api latest + build_and_push $ROOT/images/downloader downloader latest build_and_push $ROOT/images/onnx-serve onnx-serve latest cleanup diff --git a/docs/cluster/config.md b/docs/cluster/config.md index f1f439b151..7689b28f34 100644 --- a/docs/cluster/config.md +++ b/docs/cluster/config.md @@ -59,6 +59,7 @@ export CORTEX_IMAGE_ISTIO_PROXY="cortexlabs/istio-proxy:master" export CORTEX_IMAGE_ISTIO_PILOT="cortexlabs/istio-pilot:master" export CORTEX_IMAGE_ISTIO_CITADEL="cortexlabs/istio-citadel:master" export CORTEX_IMAGE_ISTIO_GALLEY="cortexlabs/istio-galley:master" +export CORTEX_IMAGE_DOWNLOADER="cortexlabs/downloader:master" # Flag to enable collecting error reports and usage stats. If flag is not set to either "true" or "false", you will be prompted. 
export CORTEX_ENABLE_TELEMETRY="" diff --git a/docs/cluster/development.md b/docs/cluster/development.md index b697d64e1f..67fc1f0c12 100644 --- a/docs/cluster/development.md +++ b/docs/cluster/development.md @@ -77,6 +77,7 @@ export CORTEX_IMAGE_ISTIO_PROXY="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortex export CORTEX_IMAGE_ISTIO_PILOT="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/istio-pilot:latest" export CORTEX_IMAGE_ISTIO_CITADEL="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/istio-citadel:latest" export CORTEX_IMAGE_ISTIO_GALLEY="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/istio-galley:latest" +export CORTEX_IMAGE_DOWNLOADER="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/downloader:latest" export CORTEX_ENABLE_TELEMETRY="false" ``` diff --git a/docs/deployments/packaging-models.md b/docs/deployments/packaging-models.md index 61f590231c..dc06c9ff1d 100644 --- a/docs/deployments/packaging-models.md +++ b/docs/deployments/packaging-models.md @@ -2,7 +2,7 @@ ## TensorFlow -Export your trained model and zip the model directory. An example is shown below (here is the [complete example](https://github.com/cortexlabs/cortex/blob/master/examples/iris/models/tensorflow_model.py)): +Export your trained model and upload the export directory, or checkpoint directory containing the export directory, which is usually the case if you used `estimator.train_and_evaluate`. An example is shown below (here is the [complete example](https://github.com/cortexlabs/cortex/blob/master/examples/sentiment)): ```Python import tensorflow as tf @@ -10,22 +10,27 @@ import shutil import os ... 
- -classifier = tf.estimator.Estimator( - model_fn=my_model, model_dir="iris", params={"hidden_units": [10, 10], "n_classes": 3} -) - -exporter = tf.estimator.FinalExporter("estimator", serving_input_fn, as_text=False) -train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=1000) -eval_spec = tf.estimator.EvalSpec(eval_input_fn, exporters=[exporter], name="estimator-eval") - -tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec) +OUTPUT_DIR="bert" +estimator = tf.estimator.Estimator(model_fn=model_fn...) + +# TF Serving requires a special input_fn used at serving time +def serving_input_fn(): + inputs = tf.placeholder(shape=[128], dtype=tf.int32) + features = { + "input_ids": tf.expand_dims(inputs, 0), + "input_mask": tf.expand_dims(inputs, 0), + "segment_ids": tf.expand_dims(inputs, 0), + "label_ids": tf.placeholder(shape=[0], dtype=tf.int32), + } + return tf.estimator.export.ServingInputReceiver(features=features, receiver_tensors=inputs) + +estimator.export_savedmodel(OUTPUT_DIR, serving_input_fn, strip_default_attrs=True) ``` -Upload the exported version directory to Amazon S3 using the AWS web console or CLI: +Upload the checkpoint directory to Amazon S3 using the AWS web console or CLI: ```text -$ aws s3 sync ./iris/export/estimator/156293432 s3://my-bucket/iris/156293432 +$ aws s3 sync ./bert s3://my-bucket/bert ``` Reference your model in an `api`: @@ -33,7 +38,7 @@ Reference your model in an `api`: ```yaml - kind: api name: my-api - model: s3://my-bucket/iris/156293432 + model: s3://my-bucket/bert ``` ## ONNX diff --git a/examples/image-classifier/cortex.yaml b/examples/image-classifier/cortex.yaml index 90d8217b62..b37ee58131 100644 --- a/examples/image-classifier/cortex.yaml +++ b/examples/image-classifier/cortex.yaml @@ -3,5 +3,5 @@ - kind: api name: classifier - model: s3://cortex-examples/imagenet/1566492692 + model: s3://cortex-examples/imagenet/ request_handler: imagenet.py diff --git a/images/downloader/Dockerfile
b/images/downloader/Dockerfile new file mode 100644 index 0000000000..0c87a1f6f4 --- /dev/null +++ b/images/downloader/Dockerfile @@ -0,0 +1,23 @@ +FROM ubuntu:18.04 + +ENV PYTHONPATH="/src:${PYTHONPATH}" + +RUN apt-get update -qq && apt-get install -y -q \ + python3 \ + python3-dev \ + python3-pip \ + && apt-get clean -qq && rm -rf /var/lib/apt/lists/* && \ + pip3 install --upgrade \ + pip \ + setuptools \ + && rm -rf /root/.cache/pip* + +COPY pkg/workloads/cortex/lib/requirements.txt /src/cortex/lib/requirements.txt +RUN pip3 install -r /src/cortex/lib/requirements.txt && \ + rm -rf /root/.cache/pip* + +COPY pkg/workloads/cortex/consts.py /src/cortex/ +COPY pkg/workloads/cortex/lib /src/cortex/lib +COPY pkg/workloads/cortex/downloader /src/cortex/downloader + +ENTRYPOINT ["/usr/bin/python3", "/src/cortex/downloader/download.py"] diff --git a/manager/install_cortex.sh b/manager/install_cortex.sh index 081b2d5ae6..6f001c3809 100755 --- a/manager/install_cortex.sh +++ b/manager/install_cortex.sh @@ -64,6 +64,7 @@ function setup_configmap() { --from-literal='IMAGE_ONNX_SERVE'=$CORTEX_IMAGE_ONNX_SERVE \ --from-literal='IMAGE_ONNX_SERVE_GPU'=$CORTEX_IMAGE_ONNX_SERVE_GPU \ --from-literal='IMAGE_TF_API'=$CORTEX_IMAGE_TF_API \ + --from-literal='IMAGE_DOWNLOADER'=$CORTEX_IMAGE_DOWNLOADER \ --from-literal='IMAGE_PYTHON_PACKAGER'=$CORTEX_IMAGE_PYTHON_PACKAGER \ --from-literal='IMAGE_TF_SERVE_GPU'=$CORTEX_IMAGE_TF_SERVE_GPU \ --from-literal='ENABLE_TELEMETRY'=$CORTEX_ENABLE_TELEMETRY \ diff --git a/pkg/lib/aws/aws.go b/pkg/lib/aws/aws.go index fa3ac4e76d..9dda79bfd4 100644 --- a/pkg/lib/aws/aws.go +++ b/pkg/lib/aws/aws.go @@ -31,7 +31,7 @@ import ( type Client struct { Region string Bucket string - s3Client *s3.S3 + S3 *s3.S3 stsClient *sts.STS cloudWatchLogsClient *cloudwatchlogs.CloudWatchLogs CloudWatchMetrics *cloudwatch.CloudWatch @@ -48,7 +48,7 @@ func New(region string, bucket string, withAccountID bool) (*Client, error) { awsClient := &Client{ Bucket: bucket, Region: 
region, - s3Client: s3.New(sess), + S3: s3.New(sess), stsClient: sts.New(sess), CloudWatchMetrics: cloudwatch.New(sess), cloudWatchLogsClient: cloudwatchlogs.New(sess), diff --git a/pkg/lib/aws/s3.go b/pkg/lib/aws/s3.go index b2d856cf20..c5a51239b3 100644 --- a/pkg/lib/aws/s3.go +++ b/pkg/lib/aws/s3.go @@ -70,7 +70,7 @@ func S3PathJoin(paths ...string) string { func (c *Client) IsS3File(keys ...string) (bool, error) { for _, key := range keys { - _, err := c.s3Client.HeadObject(&s3.HeadObjectInput{ + _, err := c.S3.HeadObject(&s3.HeadObjectInput{ Bucket: aws.String(c.Bucket), Key: aws.String(key), }) @@ -88,7 +88,7 @@ func (c *Client) IsS3File(keys ...string) (bool, error) { func (c *Client) IsS3Prefix(prefixes ...string) (bool, error) { for _, prefix := range prefixes { - out, err := c.s3Client.ListObjectsV2(&s3.ListObjectsV2Input{ + out, err := c.S3.ListObjectsV2(&s3.ListObjectsV2Input{ Bucket: aws.String(c.Bucket), Prefix: aws.String(prefix), }) @@ -138,7 +138,7 @@ func (c *Client) IsS3PathDir(s3Paths ...string) (bool, error) { } func (c *Client) UploadBytesToS3(data []byte, key string) error { - _, err := c.s3Client.PutObject(&s3.PutObjectInput{ + _, err := c.S3.PutObject(&s3.PutObjectInput{ Body: bytes.NewReader(data), Key: aws.String(key), Bucket: aws.String(c.Bucket), @@ -210,7 +210,7 @@ func (c *Client) ReadMsgpackFromS3(objPtr interface{}, key string) error { } func (c *Client) ReadStringFromS3(key string) (string, error) { - response, err := c.s3Client.GetObject(&s3.GetObjectInput{ + response, err := c.S3.GetObject(&s3.GetObjectInput{ Key: aws.String(key), Bucket: aws.String(c.Bucket), }) @@ -225,7 +225,7 @@ func (c *Client) ReadStringFromS3(key string) (string, error) { } func (c *Client) ReadBytesFromS3(key string) ([]byte, error) { - response, err := c.s3Client.GetObject(&s3.GetObjectInput{ + response, err := c.S3.GetObject(&s3.GetObjectInput{ Key: aws.String(key), Bucket: aws.String(c.Bucket), }) @@ -246,7 +246,7 @@ func (c *Client) ListPrefix(prefix 
string, maxResults int64) ([]*s3.Object, erro MaxKeys: aws.Int64(maxResults), } - output, err := c.s3Client.ListObjectsV2(listObjectsInput) + output, err := c.S3.ListObjectsV2(listObjectsInput) if err != nil { return nil, errors.Wrap(err, prefix) } @@ -263,7 +263,7 @@ func (c *Client) DeleteFromS3ByPrefix(prefix string, continueIfFailure bool) err var subErr error - err := c.s3Client.ListObjectsV2Pages(listObjectsInput, + err := c.S3.ListObjectsV2Pages(listObjectsInput, func(listObjectsOutput *s3.ListObjectsV2Output, lastPage bool) bool { deleteObjects := make([]*s3.ObjectIdentifier, len(listObjectsOutput.Contents)) for i, object := range listObjectsOutput.Contents { @@ -276,7 +276,7 @@ func (c *Client) DeleteFromS3ByPrefix(prefix string, continueIfFailure bool) err Quiet: aws.Bool(true), }, } - _, newSubErr := c.s3Client.DeleteObjects(deleteObjectsInput) + _, newSubErr := c.S3.DeleteObjects(deleteObjectsInput) if newSubErr != nil { subErr = newSubErr if !continueIfFailure { diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index 0a47d34472..3421c9f742 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -18,8 +18,11 @@ package userconfig import ( "fmt" + "path/filepath" + "strconv" "strings" + "github.com/aws/aws-sdk-go/service/s3" "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/aws" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" @@ -147,6 +150,45 @@ func IsValidTensorFlowS3Directory(path string, awsClient *aws.Client) bool { return true } +func GetTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, error) { + if IsValidTensorFlowS3Directory(path, awsClient) { + return path, nil + } + + bucket, prefix, err := aws.SplitS3Path(path) + if err != nil { + return "", err + } + + resp, _ := awsClient.S3.ListObjects(&s3.ListObjectsInput{ + Bucket: &bucket, + Prefix: &prefix, + }) + + highestVersion := int64(0) + var highestPath string + 
for _, key := range resp.Contents { + if !strings.HasSuffix(*key.Key, "saved_model.pb") { + continue + } + + keyParts := strings.Split(*key.Key, "/") + versionStr := keyParts[len(keyParts)-1] + version, err := strconv.ParseInt(versionStr, 10, 64) + if err != nil { + version = 0 + } + + possiblePath := "s3://" + filepath.Join(bucket, filepath.Join(keyParts[:len(keyParts)-1]...)) + if version >= highestVersion && IsValidTensorFlowS3Directory(possiblePath, awsClient) { + highestVersion = version + highestPath = possiblePath + } + } + + return highestPath, nil +} + func (api *API) UserConfigStr() string { var sb strings.Builder sb.WriteString(api.ResourceFields.UserConfigStr()) @@ -190,27 +232,31 @@ func (api *API) Validate() error { return err } - switch api.ModelFormat { - case ONNXModelFormat: + switch { + case api.ModelFormat == ONNXModelFormat: if ok, err := awsClient.IsS3PathFile(api.Model); err != nil || !ok { return errors.Wrap(ErrorExternalNotFound(api.Model), Identify(api), ModelKey) } - case TensorFlowModelFormat: - if !IsValidTensorFlowS3Directory(api.Model, awsClient) { + case api.ModelFormat == TensorFlowModelFormat: + path, err := GetTFServingExportFromS3Path(api.Model, awsClient) + if path == "" || err != nil { return errors.Wrap(ErrorInvalidTensorflowDir(api.Model), Identify(api), ModelKey) } + case strings.HasSuffix(api.Model, ".onnx"): + api.ModelFormat = ONNXModelFormat + if ok, err := awsClient.IsS3PathFile(api.Model); err != nil || !ok { + return errors.Wrap(ErrorExternalNotFound(api.Model), Identify(api), ModelKey) + } default: - switch { - case strings.HasSuffix(api.Model, ".onnx"): - api.ModelFormat = ONNXModelFormat - if ok, err := awsClient.IsS3PathFile(api.Model); err != nil || !ok { - return errors.Wrap(ErrorExternalNotFound(api.Model), Identify(api), ModelKey) - } - case IsValidTensorFlowS3Directory(api.Model, awsClient): - api.ModelFormat = TensorFlowModelFormat - default: + path, err := GetTFServingExportFromS3Path(api.Model, awsClient) 
+ if err != nil { + return errors.Wrap(err, Identify(api), ModelKey) + } + if path == "" { return errors.Wrap(ErrorUnableToInferModelFormat(api.Model), Identify(api)) } + api.ModelFormat = TensorFlowModelFormat + api.Model = path } if api.ModelFormat == TensorFlowModelFormat && api.TFServing == nil { diff --git a/pkg/operator/api/userconfig/errors.go b/pkg/operator/api/userconfig/errors.go index 1e106a6828..16493348a2 100644 --- a/pkg/operator/api/userconfig/errors.go +++ b/pkg/operator/api/userconfig/errors.go @@ -277,8 +277,8 @@ func ErrorExternalNotFound(path string) error { var onnxExpectedStructMessage = `For ONNX models, the path should end in .onnx` -var tfExpectedStructMessage = `For TensorFlow models, the path should be a directory with the following structure: - 1523423423/ (version prefix, usually a timestamp) +var tfExpectedStructMessage = `For TensorFlow models, the path must contain a directory with the following structure: + 1523423423/ (Version prefix, usually a timestamp) ├── saved_model.pb └── variables/ ├── variables.index diff --git a/pkg/operator/config/config.go b/pkg/operator/config/config.go index 8209406df0..96778b4d9f 100644 --- a/pkg/operator/config/config.go +++ b/pkg/operator/config/config.go @@ -46,6 +46,7 @@ type CortexConfig struct { OperatorImage string `json:"operator_image"` TFServeImage string `json:"tf_serve_image"` TFAPIImage string `json:"tf_api_image"` + DownloaderImage string `json:"downloader_image"` PythonPackagerImage string `json:"python_packager_image"` TFServeImageGPU string `json:"tf_serve_image_gpu"` ONNXServeImage string `json:"onnx_serve_image"` @@ -66,6 +67,7 @@ func Init() error { OperatorImage: getStr("IMAGE_OPERATOR"), TFServeImage: getStr("IMAGE_TF_SERVE"), TFAPIImage: getStr("IMAGE_TF_API"), + DownloaderImage: getStr("IMAGE_DOWNLOADER"), PythonPackagerImage: getStr("IMAGE_PYTHON_PACKAGER"), TFServeImageGPU: getStr("IMAGE_TF_SERVE_GPU"), ONNXServeImage: getStr("IMAGE_ONNX_SERVE"), diff --git 
a/pkg/operator/workloads/api_workload.go b/pkg/operator/workloads/api_workload.go index f0ee8c0d3f..1ee0af9a6c 100644 --- a/pkg/operator/workloads/api_workload.go +++ b/pkg/operator/workloads/api_workload.go @@ -34,9 +34,9 @@ import ( ) const ( - apiContainerName = "api" - tfServingContainerName = "serve" - modelDownloadInitContainerName = "model-download" + apiContainerName = "api" + tfServingContainerName = "serve" + downloaderInitContainerName = "downloader" defaultPortInt32, defaultPortStr = int32(8888), "8888" tfServingPortInt32, tfServingPortStr = int32(9000), "9000" @@ -281,18 +281,12 @@ func tfAPISpec( RestartPolicy: "Always", InitContainers: []kcore.Container{ { - Name: modelDownloadInitContainerName, - Image: config.Cortex.TFAPIImage, + Name: downloaderInitContainerName, + Image: config.Cortex.DownloaderImage, ImagePullPolicy: "Always", Args: []string{ - "--workload-id=" + workloadID, - "--port=" + defaultPortStr, - "--tf-serve-port=" + tfServingPortStr, - "--context=" + config.AWS.S3Path(ctx.Key), - "--api=" + ctx.APIs[api.Name].ID, - "--model-dir=" + path.Join(consts.EmptyDirMountPath, "model"), - "--cache-dir=" + consts.ContextCacheDir, - "--only-download=true", + "--download_from=" + ctx.APIs[api.Name].Model, + "--download_to=" + path.Join(consts.EmptyDirMountPath, "model"), }, Env: k8s.AWSCredentials(), VolumeMounts: k8s.DefaultVolumeMounts(), @@ -432,17 +426,12 @@ func onnxAPISpec( K8sPodSpec: kcore.PodSpec{ InitContainers: []kcore.Container{ { - Name: modelDownloadInitContainerName, - Image: servingImage, + Name: downloaderInitContainerName, + Image: config.Cortex.DownloaderImage, ImagePullPolicy: "Always", Args: []string{ - "--workload-id=" + workloadID, - "--port=" + defaultPortStr, - "--context=" + config.AWS.S3Path(ctx.Key), - "--api=" + ctx.APIs[api.Name].ID, - "--model-dir=" + path.Join(consts.EmptyDirMountPath, "model"), - "--cache-dir=" + consts.ContextCacheDir, - "--only-download=true", + "--download_from=" + ctx.APIs[api.Name].Model, + 
"--download_to=" + path.Join(consts.EmptyDirMountPath, "model"), }, Env: k8s.AWSCredentials(), VolumeMounts: k8s.DefaultVolumeMounts(), diff --git a/pkg/workloads/cortex/downloader/download.py b/pkg/workloads/cortex/downloader/download.py new file mode 100644 index 0000000000..5bb843cb21 --- /dev/null +++ b/pkg/workloads/cortex/downloader/download.py @@ -0,0 +1,41 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from cortex.lib.storage import S3 +from cortex.lib.log import get_logger + +logger = get_logger() + + +def start(args): + bucket_name, prefix = S3.deconstruct_s3_path(args.download_from) + s3_client = S3(bucket_name, client_config={}) + s3_client.download(prefix, args.download_to) + + +def main(): + parser = argparse.ArgumentParser() + na = parser.add_argument_group("required named arguments") + na.add_argument("--download_from", required=True, help="Storage Path to download the file from") + na.add_argument("--download_to", required=True, help="Directory to download the file to") + parser.set_defaults(func=start) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/pkg/workloads/cortex/lib/storage/s3.py b/pkg/workloads/cortex/lib/storage/s3.py index 95e930c999..83d81d7d4d 100644 --- a/pkg/workloads/cortex/lib/storage/s3.py +++ b/pkg/workloads/cortex/lib/storage/s3.py @@ -225,3 +225,9 @@ def download_and_unzip(self, key, local_dir): local_zip = 
os.path.join(local_dir, "zip.zip") self.download_file(key, local_zip) util.extract_zip(local_zip, delete_zip_file=True) + + def download(self, prefix, local_dir): + if self._is_s3_dir(prefix): + self.download_dir(prefix, local_dir) + else: + self.download_file_to_dir(prefix, local_dir) diff --git a/pkg/workloads/cortex/onnx_serve/api.py b/pkg/workloads/cortex/onnx_serve/api.py index 315a88c08b..2cf95c7afe 100644 --- a/pkg/workloads/cortex/onnx_serve/api.py +++ b/pkg/workloads/cortex/onnx_serve/api.py @@ -260,15 +260,8 @@ def start(args): local_cache["api"] = api local_cache["ctx"] = ctx - bucket_name, prefix = ctx.storage.deconstruct_s3_path(api["model"]) + _, prefix = ctx.storage.deconstruct_s3_path(api["model"]) model_path = os.path.join(args.model_dir, os.path.basename(prefix)) - if not os.path.exists(model_path): - s3_client = S3(bucket_name, client_config={}) - s3_client.download_file(prefix, model_path) - - if args.only_download: - return - if api.get("request_handler") is not None: package.install_packages(ctx.python_packages, ctx.storage) local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"]) @@ -310,12 +303,6 @@ def main(): na.add_argument("--api", required=True, help="Resource id of api to serve") na.add_argument("--model-dir", required=True, help="Directory to download the model to") na.add_argument("--cache-dir", required=True, help="Local path for the context cache") - na.add_argument( - "--only-download", - required=False, - help="Only download model (for init-containers)", - default=False, - ) parser.set_defaults(func=start) diff --git a/pkg/workloads/cortex/tf_api/api.py b/pkg/workloads/cortex/tf_api/api.py index 7556b5da17..3caaf2889c 100644 --- a/pkg/workloads/cortex/tf_api/api.py +++ b/pkg/workloads/cortex/tf_api/api.py @@ -348,14 +348,6 @@ def start(args): local_cache["api"] = api local_cache["ctx"] = ctx - if not os.path.isdir(args.model_dir): - bucket_name, prefix = ctx.storage.deconstruct_s3_path(api["model"]) - s3_client 
= S3(bucket_name, client_config={}) - s3_client.download_dir(prefix, args.model_dir) - - if args.only_download: - return - if api.get("request_handler") is not None: package.install_packages(ctx.python_packages, ctx.storage) local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"]) @@ -423,12 +415,6 @@ def main(): na.add_argument("--api", required=True, help="Resource id of api to serve") na.add_argument("--model-dir", required=True, help="Directory to download the model to") na.add_argument("--cache-dir", required=True, help="Local path for the context cache") - na.add_argument( - "--only-download", - required=False, - help="Only download model (for init-containers)", - default=False, - ) parser.set_defaults(func=start) args = parser.parse_args()