Skip to content

Commit dd2b944

Browse files
authored
Merge pull request #14 from jrhyness/jr_tf_mnist_datahandler
Update tensorflow data handler
2 parents 9af220d + 5f7cc6e commit dd2b944

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

Files/mnist_keras_data_handler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pickle
44
import numpy as np
55

6-
from ibmfl.data.data_handler import DataHandler
6+
from ibm_watson_machine_learning.federated_learning.data_handler import DataHandler
77

88
logger = logging.getLogger(__name__)
99

@@ -38,14 +38,14 @@ def get_data(self, nb_points=500):
3838
logger.info(
3939
'Loaded training data from ' + str(self.train_file_name))
4040
with open(self.train_file_name, 'rb') as f:
41-
(x_train, y_train)= pickle.load(f)
41+
(self.x_train, self.y_train)= pickle.load(f)
4242
logger.info(
4343
'Loaded test data from ' + str(self.test_file_name))
4444
with open(self.test_file_name, 'rb') as f:
45-
(x_test, y_test)= pickle.load(f)
45+
(self.x_test, self.y_test)= pickle.load(f)
4646

47-
x_train = x_train / 255.0
48-
x_test = x_test / 255.0
47+
self.x_train = self.x_train / 255.0
48+
self.x_test = self.x_test / 255.0
4949

5050

5151
except Exception:
@@ -55,11 +55,11 @@ def get_data(self, nb_points=500):
5555

5656
# Add a channels dimension
5757
import tensorflow as tf
58-
x_train = x_train[..., tf.newaxis]
59-
x_test = x_test[..., tf.newaxis]
58+
self.x_train = self.x_train[..., tf.newaxis]
59+
self.x_test = self.x_test[..., tf.newaxis]
6060

61-
print('x_train shape:', x_train.shape)
62-
print(x_train.shape[0], 'train samples')
63-
print(x_test.shape[0], 'test samples')
61+
print('self.x_train shape:', self.x_train.shape)
62+
print(self.x_train.shape[0], 'train samples')
63+
print(self.x_test.shape[0], 'test samples')
6464

65-
return (x_train, y_train), (x_test, y_test)
65+
return (self.x_train, self.y_train), (self.x_test, self.y_test)

0 commit comments

Comments
 (0)