@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
267267 if invalid_columns :
268268 raise ValueError ('Invalid columns found: {}' .format (invalid_columns ))
269269
270- def fit (self , train_data , discrete_columns = tuple (), epochs = None ):
270+ def fit (self , train_data , discrete_columns = tuple (), epochs = None ,
271+ data_transformer_params = {}):
271272 """Fit the CTGAN Synthesizer models to the training data.
272273
273274 Args:
@@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
278279 Vector. If ``train_data`` is a Numpy array, this list should
279280 contain the integer indices of the columns. Otherwise, if it is
280281 a ``pandas.DataFrame``, this list should contain the column names.
282+ data_transformer_params (dict):
283+ Dictionary of parameters for ``DataTransformer`` initialization.
281284 """
282285 self ._validate_discrete_columns (train_data , discrete_columns )
283286
@@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
290293 DeprecationWarning
291294 )
292295
293- self ._transformer = DataTransformer ()
296+ self ._transformer = DataTransformer (** data_transformer_params )
294297 self ._transformer .fit (train_data , discrete_columns )
295298
296299 train_data = self ._transformer .transform (train_data )
0 commit comments