@@ -159,20 +159,6 @@ def _add_margins(table, data, values, rows, cols, aggfunc):
159
159
160
160
grand_margin = _compute_grand_margin (data , values , aggfunc )
161
161
162
- # categorical index or columns will fail below when 'All' is added
163
- # here we'll convert all categorical indices to object
164
- def convert_categorical (ind ):
165
- _convert = lambda ind : (ind .astype ('object' )
166
- if ind .dtype .name == 'category' else ind )
167
- if isinstance (ind , MultiIndex ):
168
- return ind .set_levels ([_convert (lev ) for lev in ind .levels ])
169
- else :
170
- return _convert (ind )
171
-
172
- table .index = convert_categorical (table .index )
173
- if hasattr (table , 'columns' ):
174
- table .columns = convert_categorical (table .columns )
175
-
176
162
if not values and isinstance (table , Series ):
177
163
# If there are no values and the table is a series, then there is only
178
164
# one column in the data. Compute grand margin and return it.
@@ -203,7 +189,13 @@ def convert_categorical(ind):
203
189
margin_dummy = DataFrame (row_margin , columns = [key ]).T
204
190
205
191
row_names = result .index .names
206
- result = result .append (margin_dummy )
192
+ try :
193
+ result = result .append (margin_dummy )
194
+ except TypeError :
195
+
196
+ # we cannot reshape, so coerce the axis
197
+ result .index = result .index ._to_safe_for_reshape ()
198
+ result = result .append (margin_dummy )
207
199
result .index .names = row_names
208
200
209
201
return result
@@ -232,6 +224,7 @@ def _compute_grand_margin(data, values, aggfunc):
232
224
233
225
234
226
def _generate_marginal_results (table , data , values , rows , cols , aggfunc , grand_margin ):
227
+
235
228
if len (cols ) > 0 :
236
229
# need to "interleave" the margins
237
230
table_pieces = []
@@ -249,7 +242,13 @@ def _all_key(key):
249
242
250
243
# we are going to mutate this, so need to copy!
251
244
piece = piece .copy ()
252
- piece [all_key ] = margin [key ]
245
+ try :
246
+ piece [all_key ] = margin [key ]
247
+ except TypeError :
248
+
249
+ # we cannot reshape, so coerce the axis
250
+ piece .set_axis (cat_axis , piece ._get_axis (cat_axis )._to_safe_for_reshape ())
251
+ piece [all_key ] = margin [key ]
253
252
254
253
table_pieces .append (piece )
255
254
margin_keys .append (all_key )
0 commit comments