Skip to content
11 changes: 11 additions & 0 deletions src/mplfinance/_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ def _valid_mav(value, is_period=True):
return True
return False

def _colors_validator(value):
if not isinstance(value, list):
return False

for v in value:
if v:
if not (isinstance(v, dict) or isinstance(v, str)):
return False

return True


def _hlines_validator(value):
if isinstance(value,dict):
Expand Down
144 changes: 113 additions & 31 deletions src/mplfinance/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from six.moves import zip

def _check_input(opens, closes, highs, lows):
def _check_input(opens, closes, highs, lows, colors=None):
"""Checks that *opens*, *highs*, *lows* and *closes* have the same length.
NOTE: this code assumes if any value open, high, low, close is
missing (*-1*) they all are missing
Expand All @@ -46,6 +46,10 @@ def _check_input(opens, closes, highs, lows):
if not same_length:
raise ValueError('O,H,L,C must have the same length!')

if colors:
if len(opens) != len(colors):
raise ValueError('O,H,L,C and Colors must have the same length!')

o = np.where(np.isnan(opens))[0]
h = np.where(np.isnan(highs))[0]
l = np.where(np.isnan(lows))[0]
Expand Down Expand Up @@ -85,19 +89,19 @@ def _check_and_convert_xlim_configuration(data, config):
return xlim


def _construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style):
def _construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style,colors):
collections = None
if ptype == 'candle' or ptype == 'candlestick':
collections = _construct_candlestick_collections(xdates, opens, highs, lows, closes,
marketcolors=style['marketcolors'],config=config )
marketcolors=style['marketcolors'],config=config, colors=colors )

elif ptype =='hollow_and_filled':
collections = _construct_hollow_candlestick_collections(xdates, opens, highs, lows, closes,
marketcolors=style['marketcolors'],config=config )
marketcolors=style['marketcolors'],config=config, colors=colors )

elif ptype == 'ohlc' or ptype == 'bars' or ptype == 'ohlc_bars':
collections = _construct_ohlc_collections(xdates, opens, highs, lows, closes,
marketcolors=style['marketcolors'],config=config )
marketcolors=style['marketcolors'],config=config, colors=colors )
elif ptype == 'renko':
collections = _construct_renko_collections(
dates, highs, lows, volumes, config['renko_params'], closes, marketcolors=style['marketcolors'])
Expand Down Expand Up @@ -176,16 +180,45 @@ def coalesce_volume_dates(in_volumes, in_dates, indexes):
return volumes, dates


def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False):
if upcolor == downcolor:
return upcolor
cmap = {True : upcolor, False : downcolor}
if not use_prev_close:
return [ cmap[opn < cls] for opn,cls in zip(opens,closes) ]
def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False,colors=None):
if not colors:
if upcolor == downcolor:
return upcolor
cmap = {True : upcolor, False : downcolor}
if not use_prev_close:
return [ cmap[opn < cls] for opn,cls in zip(opens,closes) ]
else:
first = cmap[opens[0] < closes[0]]
_list = [ cmap[pre < cls] for cls,pre in zip(closes[1:], closes) ]
return [first] + _list
else:
first = cmap[opens[0] < closes[0]]
_list = [ cmap[pre < cls] for cls,pre in zip(closes[1:], closes) ]
return [first] + _list
cmap = {True: 'up', False: 'down'}
default = {'up': upcolor, 'down': downcolor}
custom = []
if not use_prev_close:
for i in range(len(opens)):
opn = opens[i]
cls = closes[i]
if colors[i]:
custom.append(colors[i][cmap[opn < cls]])
else:
custom.append(default[cmap[opn < cls]])
else:
if color[0]:
custom.append(colors[0][cmap[opens[0] < closes[0]]])
else:
custom.append(default[cmap[opens[0] < closes[0]]])

for i in range(len(closes) - 1):
pre = closes[1:][i]
cls = closes[i]
if colors[i]:
custom.append(colors[i][cmap[pre < cls]])
else:
custom.append(default[cmap[pre < cls]])

return custom



def _updownhollow_colors(upcolor,downcolor,hollowcolor,opens,closes):
Expand Down Expand Up @@ -447,7 +480,7 @@ def _valid_lines_kwargs():
return vkwargs


def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
"""Represent the time, open, high, low, close as a vertical line
ranging from low to high. The left tick is the open and the right
tick is the close.
Expand All @@ -472,8 +505,8 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
ret : list
a list or tuple of matplotlib collections to be added to the axes
"""

_check_input(opens, highs, lows, closes)
_check_input(opens, highs, lows, closes, colors)

if marketcolors is None:
mktcolors = _get_mpfstyle('classic')['marketcolors']['ohlc']
Expand All @@ -497,13 +530,25 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
# we'll translate these to the date, close location
closeSegments = [((dt, close), (dt+ticksize, close)) for dt, close in zip(dates, closes)]

if mktcolors['up'] == mktcolors['down']:
colors = mktcolors['up']
else:
colorup = mcolors.to_rgba(mktcolors['up'])
colordown = mcolors.to_rgba(mktcolors['down'])
colord = {True: colorup, False: colordown}
colors = [colord[open < close] for open, close in zip(opens, closes)]
bar_c = None
if colors:
bar_c = []
for color in colors:
if color:
bar_up = color['ohlc']['up']
bar_down = color['ohlc']['down']
if bar_up == 'k':
bar_up = mktcolors['up']
if bar_down == 'k':
bar_down = mktcolors['down']

bar_c.append({'up': mcolors.to_rgba(bar_up, 1), 'down': mcolors.to_rgba(bar_down, 1)})
else:
bar_c.append(None)

uc = mcolors.to_rgba(mktcolors['up'])
dc = mcolors.to_rgba(mktcolors['down'])
colors = _updown_colors(uc, dc, opens, closes, colors=bar_c)

lw = config['_width_config']['ohlc_linewidth']

Expand All @@ -525,7 +570,7 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
return [rangeCollection, openCollection, closeCollection]


def _construct_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
def _construct_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
"""Represent the open, close as a bar line and high low range as a
vertical line.

Expand All @@ -552,8 +597,8 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
ret : list
(lineCollection, barCollection)
"""
_check_input(opens, highs, lows, closes)

_check_input(opens, highs, lows, closes, colors)

if marketcolors is None:
marketcolors = _get_mpfstyle('classic')['marketcolors']
Expand Down Expand Up @@ -581,17 +626,54 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market

alpha = marketcolors['alpha']

candle_c = None
wick_c = None
edge_c = None
if colors:
candle_c = []
wick_c = []
edge_c = []
for color in colors:
if color:
candle_up = color['candle']['up']
candle_down = color['candle']['down']
edge_up = color['edge']['up']
edge_down = color['edge']['down']
wick_up = color['wick']['up']
wick_down = color['wick']['down']
if candle_up == 'w':
candle_up = marketcolors['candle']['up']
if candle_down == 'k':
candle_down = marketcolors['candle']['down']
if edge_up == 'k':
edge_up = candle_up
if edge_down == 'k':
edge_down = candle_down
if wick_up == 'k':
wick_up = candle_up
if wick_down == 'k':
wick_down = candle_down

candle_c.append({'up': mcolors.to_rgba(candle_up, alpha), 'down': mcolors.to_rgba(candle_down, alpha)})
edge_c.append({'up': mcolors.to_rgba(edge_up, 1), 'down': mcolors.to_rgba(edge_down, 1)})
wick_c.append({'up': mcolors.to_rgba(wick_up, 1), 'down': mcolors.to_rgba(wick_down, 1)})

else:
candle_c.append(None)
wick_c.append(None)
edge_c.append(None)

uc = mcolors.to_rgba(marketcolors['candle'][ 'up' ], alpha)
dc = mcolors.to_rgba(marketcolors['candle']['down'], alpha)
colors = _updown_colors(uc, dc, opens, closes)
colors = _updown_colors(uc, dc, opens, closes, colors=candle_c)

uc = mcolors.to_rgba(marketcolors['edge'][ 'up' ], 1.0)
dc = mcolors.to_rgba(marketcolors['edge']['down'], 1.0)
edgecolor = _updown_colors(uc, dc, opens, closes)
edgecolor = _updown_colors(uc, dc, opens, closes, colors=edge_c)

uc = mcolors.to_rgba(marketcolors['wick'][ 'up' ], 1.0)
dc = mcolors.to_rgba(marketcolors['wick']['down'], 1.0)
wickcolor = _updown_colors(uc, dc, opens, closes)
wickcolor = _updown_colors(uc, dc, opens, closes, colors=wick_c)

lw = config['_width_config']['candle_linewidth']

Expand All @@ -609,7 +691,7 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
return [rangeCollection, barCollection]


def _construct_hollow_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
def _construct_hollow_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
"""Represent today's open to close as a "bar" line (candle body)
and high low range as a vertical line (candle wick)

Expand Down
19 changes: 16 additions & 3 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from mplfinance._arg_validators import _scale_padding_validator, _yscale_validator
from mplfinance._arg_validators import _valid_panel_id, _check_for_external_axes
from mplfinance._arg_validators import _xlim_validator
from mplfinance._arg_validators import _colors_validator

from mplfinance._panels import _build_panels
from mplfinance._panels import _set_ticks_on_bottom_panel_only
Expand All @@ -49,6 +50,8 @@
from mplfinance._helpers import _num_or_seq_of_num
from mplfinance._helpers import _adjust_color_brightness

from mplfinance._styles import make_marketcolors

VALID_PMOVE_TYPES = ['renko', 'pnf']

DEFAULT_FIGRATIO = (8.00,5.75)
Expand Down Expand Up @@ -125,6 +128,9 @@ def _valid_plot_kwargs():

'marketcolors' : { 'Default' : None, # use 'style' for default, instead.
'Validator' : lambda value: isinstance(value,dict) },

'colors' : { 'Default' : None, # use default style instead.
'Validator' : lambda value: _colors_validator(value) },

'no_xgaps' : { 'Default' : True, # None means follow default logic below:
'Validator' : lambda value: _warn_no_xgaps_deprecated(value) },
Expand Down Expand Up @@ -391,14 +397,21 @@ def plot( data, **kwargs ):
rwc = config['return_width_config']
if isinstance(rwc,dict) and len(rwc)==0:
config['return_width_config'].update(config['_width_config'])


if config['colors']:
colors = config['colors']
for c in range(len(colors)):
if isinstance(colors[c], str):
config['colors'][c] = make_marketcolors(up=colors[c], down=colors[c], edge=colors[c], wick=colors[c], ohlc=colors[c], volume=colors[c])
else:
config['colors'] = None

collections = None
if ptype == 'line':
lw = config['_width_config']['line_width']
axA1.plot(xdates, closes, color=config['linecolor'], linewidth=lw)
else:
collections =_construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style)
collections =_construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style,config['colors'])

if ptype in VALID_PMOVE_TYPES:
collections, calculated_values = collections
Expand Down Expand Up @@ -858,7 +871,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
if not isinstance(apdata,pd.DataFrame):
raise TypeError('addplot type "'+aptype+'" MUST be accompanied by addplot data of type `pd.DataFrame`')
d,o,h,l,c,v = _check_and_prepare_data(apdata,config)
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'])
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'],config['colors'])

if not external_axes_mode:
lo = math.log(max(math.fabs(np.nanmin(l)),1e-7),10) - 0.5
Expand Down