144 changes: 144 additions & 0 deletions test/integration/sagemaker/recordio_utils.py
@@ -0,0 +1,144 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import argparse
from random import randint
import struct
import sys

import numpy as np
import tensorflow as tf

# Utility functions for generating a RecordIO-encoded file of labeled numpy data
# for testing. Each file contains one or more records. Each record is a TensorFlow
# protobuf Example object. Each object contains an integer label and a numpy array
# encoded as a byte list.

# This file can be run directly as a script to generate a single file, or used
# as a module to generate files via build_record_file.

# Magic number that marks the start of every RecordIO record.
_kmagic = 0xced7230a

# Pre-computed zero padding used to align records to four-byte boundaries.
padding = {}
for amount in range(4):
    if sys.version_info >= (3,):
        padding[amount] = bytes([0x00 for _ in range(amount)])
    else:
        padding[amount] = bytearray([0x00 for _ in range(amount)])


def write_recordio(f, data, header_flag=0):
    """Writes a single data point as a RecordIO record to the given file."""
    length = len(data)
    f.write(struct.pack('I', _kmagic))
    # The upper three bits of the header carry the multipart flag; the
    # remaining 29 bits carry the record length.
    header = (header_flag << 29) | length
    f.write(struct.pack('I', header))
    # Pad the record out to a four-byte boundary.
    pad = (((length + 3) >> 2) << 2) - length
    f.write(data)
    f.write(padding[pad])
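

# Illustrative only (not part of the original change): a minimal reader that
# inverts write_recordio, assuming the same magic number, header layout, and
# four-byte alignment used above.
def read_recordio(f):
    """Reads one record from f, returning (header_flag, data)."""
    magic, = struct.unpack('I', f.read(4))
    assert magic == _kmagic
    header, = struct.unpack('I', f.read(4))
    header_flag = header >> 29
    length = header & ((1 << 29) - 1)
    data = f.read(length)
    f.read((((length + 3) >> 2) << 2) - length)  # skip alignment padding
    return header_flag, data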


def write_recordio_multipart(f, data):
    """Writes a single data point into three multipart records."""
    length = len(data)
    stride = int(length / 3)

    data_start = data[0:stride]
    data_middle = data[stride:2 * stride]
    data_end = data[2 * stride:]

    # Flag the three parts as start (1), middle (2), and end (3) of a
    # multipart record.
    write_recordio(f, data_start, 1)
    write_recordio(f, data_middle, 2)
    write_recordio(f, data_end, 3)


def string_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.tostring()]))


def label_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def write_numpy_array(f, feature_name, label, arr, multipart=False):
    feature = {'labels': label_feature(label), feature_name: string_feature(arr)}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    if multipart:
        write_recordio_multipart(f, example.SerializeToString())
    else:
        write_recordio(f, example.SerializeToString())


def build_record_file(filename, num_records, dimension, classes=2, data_feature_name='data', multipart=False):
    """Builds a RecordIO-encoded file of TF protobuf Example objects. Each object
    is a labeled numpy array. Each Example has two fields - a single int64 'labels'
    field and a single bytes list field containing a serialized numpy array.

    Each generated numpy array is drawn from a multidimensional normal distribution
    with the specified dimension. The normal distribution is class-specific: each
    class has a different mean, so it should be possible to learn a multiclass
    classifier on this data. Class means are deterministic, so multiple calls to
    this function with the same number of classes will produce samples drawn from
    the same distribution for each class.

    Args:
        filename - the file to write to
        num_records - how many labeled numpy arrays to generate
        dimension - the size of each numpy array
        classes - the cardinality of labels
        data_feature_name - the name to give the numpy array in the Example object
        multipart - whether to write each Example as three multipart records
    """
    with open(filename, 'wb') as f:
        for i in range(num_records):
            cur_class = i % classes
            loc = int(cur_class - (classes / 2))
            write_numpy_array(f, data_feature_name, cur_class, np.random.normal(loc=loc, size=(dimension,)), multipart)
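
# Example (illustrative): build_record_file('train.rec', num_records=100,
# dimension=16, classes=4) writes 100 records whose four class means are
# -2, -1, 0, and 1.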


def build_single_record_file(filename, dimension, classes=2, data_feature_name='data'):
    """Writes a single serialized Example to a file, without RecordIO framing."""
    cur_class = randint(0, classes - 1)
    loc = int(cur_class - (classes / 2))

    arr = np.random.normal(loc=loc, size=(dimension,))
    feature = {'labels': label_feature(cur_class), data_feature_name: string_feature(arr)}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    with open(filename, 'wb') as f:
        f.write(example.SerializeToString())


def validate_record_file(filename, dimension):
    """Parses the first record in the file and checks the array's dimension."""
    with open(filename, 'rb') as f:
        data = f.read()
    magic_number, length = struct.unpack('II', data[0:8])
    encoded = data[8:8 + length]

    features = {
        'data': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(encoded, features)
    array = tf.io.decode_raw(parsed['data'], tf.float64)

    assert array.shape[0] == dimension


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generate synthetic multi-class training data")
    parser.add_argument('--dimension', default=65536, type=int)
    parser.add_argument('--classes', default=2, type=int)
    parser.add_argument('--num-records', default=4, type=int)
    parser.add_argument('--data-feature-name', default='data')
    parser.add_argument('filename', type=str)
    args = parser.parse_args()
    build_record_file(args.filename, args.num_records, args.dimension, args.classes, args.data_feature_name)
    validate_record_file(args.filename, args.dimension)
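
# Example invocation (illustrative):
#   python recordio_utils.py --dimension 16 --num-records 10 train.rec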
114 changes: 114 additions & 0 deletions test/integration/sagemaker/test_pipemode.py
@@ -0,0 +1,114 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import shutil
import uuid

import pytest
from recordio_utils import build_record_file, build_single_record_file
from sagemaker import s3_input
from sagemaker.tensorflow import TensorFlow

from test.integration.utils import processor, py_version, unique_name_from_base # noqa: F401
from timeout import timeout

DIMENSION = 5


def make_test_data(directory, name, num_files, num_records, dimension, sagemaker_session):
    if not os.path.exists(directory):
        os.makedirs(directory)
    for i in range(num_files):
        if num_records > 1:
            build_record_file(os.path.join(directory, name + str(i)),
                              num_records=num_records, dimension=dimension)
        else:
            build_single_record_file(os.path.join(directory, name + str(i)),
                                     dimension=dimension)

    return sagemaker_session.upload_data(path=directory,
                                         key_prefix='pipemode-{}-files'.format(name))


@pytest.fixture(scope='session')
def multi_records_test_data(sagemaker_session):
    test_data = 'test-data-' + str(uuid.uuid4())
    os.makedirs(test_data)
    s3_url = make_test_data(
        directory=test_data,
        name='multi',
        num_files=1,
        num_records=1000,
        dimension=DIMENSION,
        sagemaker_session=sagemaker_session)
    shutil.rmtree(test_data)
    return s3_url


@pytest.fixture(scope='session')
def single_record_test_data(sagemaker_session):
    test_data = 'test-data-' + str(uuid.uuid4())
    os.makedirs(test_data)
    s3_url = make_test_data(
        directory=test_data,
        name='single',
        num_files=100,
        num_records=1,
        dimension=DIMENSION,
        sagemaker_session=sagemaker_session)
    shutil.rmtree(test_data)
    return s3_url


def run_test(sagemaker_session, ecr_image, instance_type, framework_version, test_data,
             record_wrapper_type=None):
    source_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources', 'pipemode')
    script = os.path.join(source_path, 'pipemode.py')
    estimator = TensorFlow(entry_point=script,
                           role='SageMakerRole',
                           train_instance_type=instance_type,
                           train_instance_count=1,
                           sagemaker_session=sagemaker_session,
                           image_name=ecr_image,
                           framework_version=framework_version,
                           script_mode=True,
                           input_mode='Pipe',
                           hyperparameters={'dimension': DIMENSION})
    train_input = s3_input(s3_data=test_data,
                           distribution='FullyReplicated',
                           record_wrapping=record_wrapper_type,
                           input_mode='Pipe')
    with timeout(minutes=20):
        estimator.fit({'elizabeth': train_input},
                      job_name=unique_name_from_base('test-sagemaker-pipemode'))
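
# The two tests below differ in record wrapping: single-record files contain a
# bare serialized Example, so the single-record test asks SageMaker to apply
# 'RecordIO' wrapping in the pipe; multi-record files are already
# RecordIO-encoded, so no record_wrapper_type is passed.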


def test_single_record(sagemaker_session, ecr_image, instance_type, framework_version,
                       single_record_test_data):
    run_test(sagemaker_session,
             ecr_image,
             instance_type,
             framework_version,
             single_record_test_data,
             'RecordIO')


def test_multi_records(sagemaker_session, ecr_image, instance_type, framework_version,
                       multi_records_test_data):
    run_test(sagemaker_session,
             ecr_image,
             instance_type,
             framework_version,
             multi_records_test_data)
116 changes: 116 additions & 0 deletions test/resources/pipemode/pipemode.py
@@ -0,0 +1,116 @@
import json
import multiprocessing
import os
import tempfile

import tensorflow as tf
from sagemaker_tensorflow import PipeModeDataset

print("Starting estimator script")

ds = PipeModeDataset("elizabeth",benchmark=True)


class BenchmarkConfig(object):

    def __init__(self):
        with open('/opt/ml/input/config/hyperparameters.json') as f:
            self.hp = json.load(f)

    @property
    def batch_size(self):
        return int(self.hp.get('batch_size', 5))

    @property
    def prefetch_size(self):
        return int(self.hp.get('prefetch_size', 1000))

    @property
    def channel(self):
        return self.hp.get('channel', 'elizabeth')

    @property
    def dimension(self):
        return int(self.hp['dimension'])

    @property
    def epochs(self):
        return int(self.hp.get('epochs', 3))

    @property
    def parallel_transform_calls(self):
        return int(self.hp.get('parallel_transform_calls', max(1, multiprocessing.cpu_count() - 2)))

    def __repr__(self):
        """Return all properties."""
        return str(vars(self))
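
# Illustrative example of the hyperparameters.json this config reads (SageMaker
# passes hyperparameter values as strings, hence the int() casts above; only
# 'dimension' is required):
#   {"dimension": "5", "batch_size": "5", "epochs": "3"}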


config = BenchmarkConfig()


def input_fn():
    features = {
        'data': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        parsed = tf.io.parse_single_example(serialized=record, features=features)
        return ({
            'data': tf.io.decode_raw(parsed['data'], tf.float64)
        }, parsed['labels'])

    ds = PipeModeDataset(config.channel)

    if config.epochs > 1:
        ds = ds.repeat(config.epochs)
    if config.prefetch_size > 0:
        ds = ds.prefetch(config.prefetch_size)
    ds = ds.map(parse, num_parallel_calls=config.parallel_transform_calls)
    ds = ds.batch(config.batch_size)
    return ds


# Perform Estimator training
column = tf.feature_column.numeric_column('data', shape=(config.dimension, ))
model_dir = tempfile.mkdtemp()
estimator = tf.estimator.LinearClassifier(feature_columns=[column], model_dir=model_dir)

print("About to call train")
estimator.train(input_fn=input_fn)

# Confirm that the expected number of pipes was consumed: after training, the
# pipe at index epochs + 1 should already exist on disk.
assert os.path.exists('/opt/ml/input/data/{}_{}'.format(config.channel, config.epochs + 1))

print("About to call evaluate")
result = estimator.evaluate(input_fn=input_fn)
for key, value in sorted(result.items()):
    print('%s: %s' % (key, value))


# Test that we can create a new PipeModeDataset after training has run
print("Validate that new PipeModeDataset on existing channel can be created")
tf.compat.v1.disable_eager_execution()

ds = PipeModeDataset(config.channel, benchmark=True)
with tf.compat.v1.Session() as sess:
    it = tf.compat.v1.data.make_one_shot_iterator(ds)
    next_item = it.get_next()
    sess.run(next_item)

print("Validate create, read, discard, recreate")

# Test that we can create a PipeModeDataset, discard it, and read from a new one
ds = PipeModeDataset(config.channel, benchmark=True)
with tf.compat.v1.Session() as sess:
    it = tf.compat.v1.data.make_one_shot_iterator(ds)
    next_item = it.get_next()


with tf.compat.v1.Session() as sess:
    it = tf.compat.v1.data.make_one_shot_iterator(ds)
    next_item = it.get_next()
    sess.run(next_item)

print("Validate recreate")
ds = PipeModeDataset(config.channel, benchmark=True)