Skip to content

Commit 15e9f98

Browse files
committed
made changes according to pr
1 parent 71ca403 commit 15e9f98

File tree

4 files changed

+239
-194
lines changed

4 files changed

+239
-194
lines changed

src/mplfinance/_arg_validators.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _check_and_prepare_data(data, config):
5252
columns = ('Open', 'High', 'Low', 'Close', 'Volume')
5353
if all([c.lower() in data for c in columns[0:4]]):
5454
columns = ('open', 'high', 'low', 'close', 'volume')
55-
55+
5656
o, h, l, c, v = columns
5757
cols = [o, h, l, c]
5858

@@ -100,7 +100,7 @@ def _get_valid_plot_types(plottype=None):
100100
return _alias_types[plottype]
101101
else:
102102
return plottype
103-
103+
104104

105105
def _mav_validator(mav_value):
106106
'''
@@ -143,12 +143,12 @@ def _valid_mav(value, is_period=True):
143143
return False
144144

145145
def _colors_validator(value):
146-
if not isinstance(value, list):
146+
if not isinstance(value, (list, tuple, np.ndarray)):
147147
return False
148148

149149
for v in value:
150150
if v:
151-
if not (isinstance(v, dict) or isinstance(v, str)):
151+
if v is not None and not isinstance(v, (dict, str)):
152152
return False
153153

154154
return True
@@ -204,11 +204,11 @@ def _alines_validator(value, returnStandardizedValue=False):
204204
A sequence of (line0, line1, line2), where:
205205
206206
linen = (x0, y0), (x1, y1), ... (xm, ym)
207-
207+
208208
or the equivalent numpy array with two columns. Each line can be a different length.
209209
210210
The above is from the matplotlib LineCollection documentation.
211-
It basically says that the "segments" passed into the LineCollection constructor
211+
It basically says that the "segments" passed into the LineCollection constructor
212212
must be a Sequence of Sequences of 2 or more xy Pairs. However here in `mplfinance`
213213
we want to allow that (seq of seq of xy pairs) _as well as_ just a sequence of pairs.
214214
Therefore here in the validator we will allow both:
@@ -270,8 +270,8 @@ def _tlines_subvalidator(value):
270270
def _bypass_kwarg_validation(value):
271271
''' For some kwargs, we either don't know enough, or
272272
the validation is too complex to make it worth while,
273-
so we bypass kwarg validation. If the kwarg is
274-
invalid, then eventually an exception will be
273+
so we bypass kwarg validation. If the kwarg is
274+
invalid, then eventually an exception will be
275275
raised at the time the kwarg value is actually used.
276276
'''
277277
return True
@@ -300,7 +300,7 @@ def _process_kwargs(kwargs, vkwargs):
300300
Given a "valid kwargs table" and some kwargs, verify that each key-word
301301
is valid per the kwargs table, and that the value of the kwarg is the
302302
correct type. Fill a configuration dictionary with the default value
303-
for each kwarg, and then substitute in any values that were provided
303+
for each kwarg, and then substitute in any values that were provided
304304
as kwargs and return the configuration dictionary.
305305
'''
306306
# initialize configuration from valid_kwargs_table:
@@ -327,7 +327,7 @@ def _process_kwargs(kwargs, vkwargs):
327327

328328
# ---------------------------------------------------------------
329329
# At this point in the loop, if we have not raised an exception,
330-
# then kwarg is valid as far as we can tell, therefore,
330+
# then kwarg is valid as far as we can tell, therefore,
331331
# go ahead and replace the appropriate value in config:
332332

333333
config[key] = value
@@ -346,7 +346,7 @@ def _scale_padding_validator(value):
346346
if key not in valid_keys:
347347
raise ValueError('Invalid key "'+str(key)+'" found in `scale_padding` dict.')
348348
if not isinstance(value[key],(int,float)):
349-
raise ValueError('`scale_padding` dict contains non-number at key "'+str(key)+'"')
349+
raise ValueError('`scale_padding` dict contains non-number at key "'+str(key)+'"')
350350
return True
351351
else:
352352
raise ValueError('`scale_padding` kwarg must be a number, or dict of (left,right,top,bottom) numbers.')
@@ -372,9 +372,9 @@ def _yscale_validator(value):
372372

373373
def _check_for_external_axes(config):
374374
'''
375-
Check that all `fig` and `ax` kwargs are either ALL None,
375+
Check that all `fig` and `ax` kwargs are either ALL None,
376376
or ALL are valid instances of Figures/Axes:
377-
377+
378378
An external Axes object can be passed in three places:
379379
- mpf.plot() `ax=` kwarg
380380
- mpf.plot() `volume=` kwarg
@@ -391,7 +391,7 @@ def _check_for_external_axes(config):
391391
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))
392392
for apd in addplot:
393393
ap_axlist.append(apd['ax'])
394-
394+
395395
if len(ap_axlist) > 0:
396396
if config['ax'] is None:
397397
if not all([ax is None for ax in ap_axlist]):
@@ -416,6 +416,6 @@ def _check_for_external_axes(config):
416416
raise ValueError('`volume` must be of type `matplotlib.axis.Axes`')
417417
#if not isinstance(config['fig'],mpl.figure.Figure):
418418
# raise ValueError('`fig` kwarg must be of type `matplotlib.figure.Figure`')
419-
419+
420420
external_axes_mode = True if isinstance(config['ax'],mpl.axes.Axes) else False
421421
return external_axes_mode

0 commit comments

Comments
 (0)