Skip to content

Commit f06db4b

Browse files
simplify marketcolor overrides
1 parent d653acd commit f06db4b

File tree

3 files changed

+258
-327
lines changed

3 files changed

+258
-327
lines changed

src/mplfinance/_arg_validators.py

+32-24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import matplotlib.dates as mdates
2+
import matplotlib.colors as mcolors
23
import pandas as pd
34
import numpy as np
45
import datetime
@@ -52,7 +53,7 @@ def _check_and_prepare_data(data, config):
5253
columns = ('Open', 'High', 'Low', 'Close', 'Volume')
5354
if all([c.lower() in data for c in columns[0:4]]):
5455
columns = ('open', 'high', 'low', 'close', 'volume')
55-
56+
5657
o, h, l, c, v = columns
5758
cols = [o, h, l, c]
5859

@@ -100,7 +101,7 @@ def _get_valid_plot_types(plottype=None):
100101
return _alias_types[plottype]
101102
else:
102103
return plottype
103-
104+
104105

105106
def _mav_validator(mav_value):
106107
'''
@@ -142,17 +143,6 @@ def _valid_mav(value, is_period=True):
142143
return True
143144
return False
144145

145-
def _colors_validator(value):
146-
if not isinstance(value, (list, tuple, np.ndarray)):
147-
return False
148-
149-
for v in value:
150-
if v:
151-
if v is not None and not isinstance(v, (dict, str)):
152-
return False
153-
154-
return True
155-
156146

157147
def _hlines_validator(value):
158148
if isinstance(value,dict):
@@ -204,11 +194,11 @@ def _alines_validator(value, returnStandardizedValue=False):
204194
A sequence of (line0, line1, line2), where:
205195
206196
linen = (x0, y0), (x1, y1), ... (xm, ym)
207-
197+
208198
or the equivalent numpy array with two columns. Each line can be a different length.
209199
210200
The above is from the matplotlib LineCollection documentation.
211-
It basically says that the "segments" passed into the LineCollection constructor
201+
It basically says that the "segments" passed into the LineCollection constructor
212202
must be a Sequence of Sequences of 2 or more xy Pairs. However here in `mplfinance`
213203
we want to allow that (seq of seq of xy pairs) _as well as_ just a sequence of pairs.
214204
Therefore here in the validator we will allow both:
@@ -270,8 +260,8 @@ def _tlines_subvalidator(value):
270260
def _bypass_kwarg_validation(value):
271261
''' For some kwargs, we either don't know enough, or
272262
the validation is too complex to make it worth while,
273-
so we bypass kwarg validation. If the kwarg is
274-
invalid, then eventually an exception will be
263+
so we bypass kwarg validation. If the kwarg is
264+
invalid, then eventually an exception will be
275265
raised at the time the kwarg value is actually used.
276266
'''
277267
return True
@@ -300,7 +290,7 @@ def _process_kwargs(kwargs, vkwargs):
300290
Given a "valid kwargs table" and some kwargs, verify that each key-word
301291
is valid per the kwargs table, and that the value of the kwarg is the
302292
correct type. Fill a configuration dictionary with the default value
303-
for each kwarg, and then substitute in any values that were provided
293+
for each kwarg, and then substitute in any values that were provided
304294
as kwargs and return the configuration dictionary.
305295
'''
306296
# initialize configuration from valid_kwargs_table:
@@ -327,7 +317,7 @@ def _process_kwargs(kwargs, vkwargs):
327317

328318
# ---------------------------------------------------------------
329319
# At this point in the loop, if we have not raised an exception,
330-
# then kwarg is valid as far as we can tell, therefore,
320+
# then kwarg is valid as far as we can tell, therefore,
331321
# go ahead and replace the appropriate value in config:
332322

333323
config[key] = value
@@ -346,7 +336,7 @@ def _scale_padding_validator(value):
346336
if key not in valid_keys:
347337
raise ValueError('Invalid key "'+str(key)+'" found in `scale_padding` dict.')
348338
if not isinstance(value[key],(int,float)):
349-
raise ValueError('`scale_padding` dict contains non-number at key "'+str(key)+'"')
339+
raise ValueError('`scale_padding` dict contains non-number at key "'+str(key)+'"')
350340
return True
351341
else:
352342
raise ValueError('`scale_padding` kwarg must be a number, or dict of (left,right,top,bottom) numbers.')
@@ -370,11 +360,29 @@ def _yscale_validator(value):
370360
return True
371361

372362

363+
def _is_marketcolor_object(obj):
364+
if not isinstance(obj,dict): return False
365+
market_colors_keys = ('candle','edge','wick','ohlc')
366+
return all([k in obj for k in market_colors_keys])
367+
368+
369+
def _mco_validator(value): # marketcolor overrides validator
370+
if isinstance(value,dict): # not yet supported, but maybe we will have other
371+
if 'colors' not in value: # kwargs related to mktcolor overrides (ex: `mco_faceonly`)
372+
raise ValueError('`marketcolor_overrides` as dict must contain `colors` key.')
373+
colors = value['colors']
374+
else:
375+
colors = value
376+
if not isinstance(colors,(list,tuple,np.ndarray)):
377+
return False
378+
return all([(c is None or mcolors.is_color_like(c) or _is_marketcolor_object(c)) for c in colors])
379+
380+
373381
def _check_for_external_axes(config):
374382
'''
375-
Check that all `fig` and `ax` kwargs are either ALL None,
383+
Check that all `fig` and `ax` kwargs are either ALL None,
376384
or ALL are valid instances of Figures/Axes:
377-
385+
378386
An external Axes object can be passed in three places:
379387
- mpf.plot() `ax=` kwarg
380388
- mpf.plot() `volume=` kwarg
@@ -391,7 +399,7 @@ def _check_for_external_axes(config):
391399
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))
392400
for apd in addplot:
393401
ap_axlist.append(apd['ax'])
394-
402+
395403
if len(ap_axlist) > 0:
396404
if config['ax'] is None:
397405
if not all([ax is None for ax in ap_axlist]):
@@ -416,6 +424,6 @@ def _check_for_external_axes(config):
416424
raise ValueError('`volume` must be of type `matplotlib.axis.Axes`')
417425
#if not isinstance(config['fig'],mpl.figure.Figure):
418426
# raise ValueError('`fig` kwarg must be of type `matplotlib.figure.Figure`')
419-
427+
420428
external_axes_mode = True if isinstance(config['ax'],mpl.axes.Axes) else False
421429
return external_axes_mode

0 commit comments

Comments
 (0)