diff --git a/README.md b/README.md index cce862b6..672ad69b 100644 --- a/README.md +++ b/README.md @@ -14,13 +14,6 @@ OR pass a json_dictionary directly into the module with the parameters defined AND/OR pass parameters via the command line, in a way that will override the input_json or the json_dictionary given. -## Upgrading to version 2.0 -The major change in argschema 2.0 is becoming -compatible with marshmallow 3, which changes -many of the ways your schemas and schema modifications work. Some noteable differences are that schemas are strict now by default, so tossing keys in your outputs or inputs that were ignored and stripped before now throw errors unless -def -Please read this document for more guidance -https://marshmallow.readthedocs.io/en/stable/upgrading.html ## Level of Support We are planning on occasional updating this tool with no fixed schedule. Community involvement is encouraged through both issues and pull requests. Please make pull requests against the dev branch, as we will test changes there before merging into master. @@ -75,7 +68,7 @@ You start building some code in an ipython notebook to play around with a new id It's a mess, and you know you should migrate your code over to a module that you can call from other programs or notebooks. You start collecting your input variables to the top of the notebook and make yourself a wrapper function that you can call. However, now your mistake in filename typing is a disaster because the file doesn't exist, and your code doesn't check for the existence of the file until quite late. You start implementing some input validation checks to avoid this problem. -Now you start wanting to integrate this code with other things, including elements that aren't in python. You decide that you need to have a command line module that executes the code, because then you can use other tools to stitch together your processing, like maybe some shell scripts or docker run commands. 
You implement an argparse set of inputs and default values that make your python program a self-contained program, with some help documentation. Along the way, you have to refactor the parsed argparse variables into your function and strip out your old hacky validation code to avoid maintaining two versions of validation in the future. +Now you start wanting to integrate this code with other things, including elements that aren't in python. You decide that you need to have a command line module that executes the code, because then you can use other tools to stitch together your processing, like maybe some shell scripts or docker run commands. You implement an argparse set of inputs and default values that make your python program a self-contained program, with some helpful documentation. Along the way, you have to refactor the parsed argparse variables into your function and strip out your old hacky validation code to avoid maintaining two versions of validation in the future. This module starts becoming useful enough that you want to integrate it into more complex modules. You end up copying and pasting various argparse lines over to other modules, and then 5 other modules. Later you decide to change your original module a little bit, and you have a nightmare of code replace to fix up the other modules to mirror this phenomenon.. you kick yourself for not having thought this through more clearly. @@ -85,5 +78,25 @@ If you had only designed things from the beginning to allow for each of these us This is what argschema is designed to do. + +## Upgrading to version 3.0 +The major change in argschema 3.0 is introducing a more generalized interface for reading and writing dictionaries, referred to as ArgSource and ArgSink. One can define customized classes that read dictionaries from any source you can code, such as making a database call, reading from a web service, reading a yaml file, etc. Argschema isn't just for json anymore. 
Similarly you can now dynamically tell your ArgSchemaParser to write output to an ArgSink, which might write to a database, a webservice, or a messaging service. This gives those integrating modules into larger workflow management solutions more flexibility in wiring up your python modules to those systems. + +It also removes features that were marked previously as deprecated. + +Notably parsing List arguments with --listarg a b c, which instead should be called as --listarg a,b,c. In other words cli_as_single_argument = False is no longer an option. + +It also removes the old names JsonModule, ModuleParameters, which are now ArgSchemaParser and ArgSchema respectively. + +The field OptionList has been removed. The same functionality can be accomplished with the keyword, validate=mm.validate.OneOf([a,b,c...]) in the field definition. + +## Upgrading to version 2.0 +The major change in argschema 2.0 is becoming +compatible with marshmallow 3, which changes +many of the ways your schemas and schema modifications work. Some notable differences are that schemas are strict now by default, so tossing keys in your outputs or inputs that were ignored and stripped before now throw errors. 
+ +Please read this document for more guidance +https://marshmallow.readthedocs.io/en/stable/upgrading.html + Copyright 2017 Allen Institute diff --git a/argschema/__init__.py b/argschema/__init__.py index f28aafde..64e09c8e 100644 --- a/argschema/__init__.py +++ b/argschema/__init__.py @@ -1,8 +1,7 @@ '''argschema: flexible definition, validation and setting of parameters''' -from .fields import InputFile, InputDir, OutputFile, OptionList # noQA:F401 -from .schemas import ArgSchema # noQA:F401 -from .argschema_parser import ArgSchemaParser # noQA:F401 -from .deprecated import JsonModule, ModuleParameters # noQA:F401 +from argschema.fields import InputFile, InputDir, OutputFile # noQA:F401 +from argschema.schemas import ArgSchema # noQA:F401 +from argschema.argschema_parser import ArgSchemaParser # noQA:F401 __version__ = "3.0.1" diff --git a/argschema/argschema_parser.py b/argschema/argschema_parser.py index f29e081e..03b1ffce 100644 --- a/argschema/argschema_parser.py +++ b/argschema/argschema_parser.py @@ -1,110 +1,33 @@ '''Module that contains the base class ArgSchemaParser which should be subclassed when using this library ''' -import json +from typing import List, Sequence, Dict, Optional, Union, Tuple, Type, TypeVar import logging -import copy from . import schemas from . import utils -from . 
import fields import marshmallow as mm - - -def contains_non_default_schemas(schema, schema_list=[]): - """returns True if this schema contains a schema which was not an instance of DefaultSchema - - Parameters - ---------- - schema : marshmallow.Schema - schema to check - schema_list : - (Default value = []) - - Returns - ------- - bool - does this schema only contain schemas which are subclassed from schemas.DefaultSchema - - """ - if not isinstance(schema, schemas.DefaultSchema): - return True - for k, v in schema.declared_fields.items(): - if isinstance(v, mm.fields.Nested): - if type(v.schema) in schema_list: - return False - else: - schema_list.append(type(v.schema)) - if contains_non_default_schemas(v.schema, schema_list): - return True - return False - - -def is_recursive_schema(schema, schema_list=[]): - """returns true if this schema contains recursive elements - - Parameters - ---------- - schema : marshmallow.Schema - schema to check - schema_list : - (Default value = []) - - Returns - ------- - bool - does this schema contain any recursively defined schemas - - """ - for k, v in schema.declared_fields.items(): - if isinstance(v, mm.fields.Nested): - if type(v.schema) in schema_list: - return True - else: - schema_list.append(type(v.schema)) - if is_recursive_schema(v.schema, schema_list): - return True - return False - - -def fill_defaults(schema, args): - """DEPRECATED, function to fill in default values from schema into args - bug: goes into an infinite loop when there is a recursively defined schema - - Parameters - ---------- - schema : marshmallow.Schema - schema to get defaults from - args : - - - Returns - ------- - dict - dictionary with missing default values filled in - - """ - - defaults = [] - - # find all of the schema entries with default values - schemata = [(schema, [])] - while schemata: - subschema, path = schemata.pop() - for k, v in subschema.declared_fields.items(): - if isinstance(v, mm.fields.Nested): - schemata.append((v.schema, 
path + [k])) - elif v.default != mm.missing: - defaults.append((path + [k], v.default)) - - # put the default entries into the args dictionary - args = copy.deepcopy(args) - for path, val in defaults: - d = args - for path_item in path[:-1]: - d = d.setdefault(path_item, {}) - if path[-1] not in d: - d[path[-1]] = val - return args +from .sources.json_source import JsonSource, JsonSink +from .sources.yaml_source import YamlSource, YamlSink +from .sources.source import ( + ConfigurableSource, + ConfigurableSink, + NonconfigurationError, + MultipleConfigurationError, +) + + +SourceType = Union[ConfigurableSource, Type[ConfigurableSource]] +RegistrableSources = Union[ + None, + SourceType, + Sequence[SourceType], +] +SinkType = Union[ConfigurableSink, Type[ConfigurableSink]] +RegistrableSinks = Union[ + None, + SinkType, + Sequence[SinkType], +] class ArgSchemaParser(object): @@ -118,15 +41,20 @@ class ArgSchemaParser(object): Parameters ---------- input_data : dict or None - dictionary parameters instead of --input_json + dictionary parameters to fall back on if not source is given or configured via command line schema_type : schemas.ArgSchema the schema to use to validate the parameters output_schema_type : marshmallow.Schema - the schema to use to validate the output_json, used by self.output + the schema to use to validate the output, used by self.output + input_sources : Sequence[argschema.sources.source.ConfigurableSource] + each of these will be considered as a potential source of input data + output_sinks : Sequence[argschema.sources.source.ConfigurableSource] + each of these will be considered as a potential sink for output data args : list or None - command line arguments passed to the module, if None use argparse to parse the command line, set to [] if you want to bypass command line parsing + command line arguments passed to the module, if None use argparse to parse the command line, + set to [] if you want to bypass command line parsing logger_name : str 
- name of logger from the logging module you want to instantiate 'argschema' + name of logger from the logging module you want to instantiate default ('argschema') Raises ------- @@ -137,12 +65,38 @@ class ArgSchemaParser(object): """ default_schema = schemas.ArgSchema default_output_schema = None + default_sources: Tuple[SourceType] = (JsonSource,) + default_sinks: Tuple[SinkType] = (JsonSink,) + + @property + def input_sources(self) -> List[ConfigurableSource]: + if not hasattr(self, "_input_sources"): + self._input_sources: List[ConfigurableSource] = [] + return self._input_sources + + @property + def output_sinks(self) -> List[ConfigurableSink]: + if not hasattr(self, "_output_sinks"): + self._output_sinks: List[ConfigurableSink] = [] + return self._output_sinks + + @property + def io_schemas(self) -> List[mm.Schema]: + if not hasattr(self, "_io_schemas"): + self._io_schemas: List[mm.Schema] = [] + return self._io_schemas + + @io_schemas.setter + def io_schemas(self, schemas: List[mm.Schema]): + self._io_schemas = schemas def __init__(self, input_data=None, # dictionary input as option instead of --input_json schema_type=None, # schema for parsing arguments output_schema_type=None, # schema for parsing output_json args=None, + input_sources=None, + output_sinks=None, logger_name=__name__): if schema_type is None: @@ -154,30 +108,201 @@ def __init__(self, self.logger = self.initialize_logger(logger_name, 'WARNING') self.logger.debug('input_data is {}'.format(input_data)) - # convert schema to argparse object - p = utils.schema_argparser(self.schema) - argsobj = p.parse_args(args) - argsdict = utils.args_to_dict(argsobj, self.schema) - self.logger.debug('argsdict is {}'.format(argsdict)) + self.register_sources(input_sources) + self.register_sinks(output_sinks) + + argsdict = self.parse_command_line(args) + resolved_args = self.resolve_inputs(input_data, argsdict) + + self.output_sink = self.__get_output_sink_from_config(resolved_args) + self.args = 
self.load_schema_with_defaults(self.schema, resolved_args) + + self.output_schema_type = output_schema_type + self.logger = self.initialize_logger( + logger_name, self.args.get('log_level')) + + def register_sources( + self, + sources: RegistrableSources + ): + """consolidate a list of the input source configuration schemas + + Parameters + ---------- + sources : (sequence of) ConfigurableSource or None + Each source will be registered (and may then be configured by data + passed to this parser). If None is argued, the default_sources + associated with this class will be registered. + + """ + + if isinstance(sources, (ConfigurableSource, type)): + coerced_sources: Sequence[SourceType] = [sources] + elif sources is None: + coerced_sources = self.default_sources + else: + coerced_sources = sources + + for source in coerced_sources: + if isinstance(source, type): + source = source() + self.io_schemas.append(source.schema) + self.input_sources.append(source) + + def register_sinks( + self, + sinks: RegistrableSinks + ): + """Consolidate a list of the output sink configuration schemas + + Parameters + ---------- + sinks : (sequence of) ConfigurableSink or None + Each sink will be registered (and may then be configured by data + passed to this parser). If None is argued, the default_sinks + associated with this class will be registered. 
+ + """ - if argsobj.input_json is not None: - fields.files.validate_input_path(argsobj.input_json) - with open(argsobj.input_json, 'r') as j: - jsonargs = json.load(j) + if isinstance(sinks, (ConfigurableSink, type)): + coerced_sinks: Sequence[SinkType] = [sinks] + elif sinks is None: + coerced_sinks = self.default_sinks else: - jsonargs = input_data if input_data else {} + coerced_sinks = sinks - # merge the command line dictionary into the input json - args = utils.smart_merge(jsonargs, argsdict) + for sink in coerced_sinks: + if isinstance(sink, type): + sink = sink() + self.io_schemas.append(sink.schema) + self.output_sinks.append(sink) + + def parse_command_line(self, args: Optional[List[str]]) -> Dict: + """Build a command line parser from the input schemas and + configurations. Parse command line arguments using this parser + + Parameters + ---------- + args : list of str or None + Will be passed directly to argparse's parse_args. If None, sys.argv + will be used. If provided, should be formatted like: + ["positional_arg", "--optional_arg", "optional_value"] + + Returns + ------- + argsdict : dict + a (potentially nested) dictionary of parsed command line arguments + + """ + parser = utils.schema_argparser(self.schema, self.io_schemas) + argsobj = parser.parse_args(args) + argsdict = utils.args_to_dict(argsobj, [self.schema] + self.io_schemas) + self.logger.debug('argsdict is {}'.format(argsdict)) + return argsdict + + def resolve_inputs(self, input_data: Dict, argsdict: Dict) -> Dict: + """ Resolve input source by checking candidate sources against + constructor and command line arguments + + Parameters + ---------- + input_data : dict + Manually (on ArgschemaParser construction) specified parameters. + Will be overridden if values are successfully extracted from + argsdict. + argsdict : dict + Command line parameters, parsed into a nested dictionary. + + Returns + ------- + args : dict + A fully merged (possibly nested) collection of inputs. 
May draw from + 1. input data + 2. the argsdict + 3. any configurable sources whose config schemas are satisfied + by values in the above + + """ + + config_data = self.__get_input_data_from_config(input_data) + if config_data is not None: + input_data = config_data + + config_data = self.__get_input_data_from_config( + utils.smart_merge({}, argsdict)) + if config_data is not None: + input_data = config_data + + args = utils.smart_merge(input_data, argsdict) self.logger.debug('args after merge {}'.format(args)) - # validate with load! - result = self.load_schema_with_defaults(self.schema, args) + return args - self.args = result - self.output_schema_type = output_schema_type - self.logger = self.initialize_logger( - logger_name, self.args.get('log_level')) + def __get_output_sink_from_config(self, d): + """private function to check for ConfigurableSink configuration in a dictionary and return a configured ConfigurableSink + + Parameters + ---------- + d : dict + dictionary to look for ConfigurableSink Configuration parameters in + + Returns + ------- + ConfigurableSink + A configured ConfigurableSink + + Raises + ------ + MultipleConfigurationError + If more than one Sink is configured + """ + output_set = False + output_sink = None + for sink in self.output_sinks: + try: + sink.load_config(d) + + if output_set: + raise MultipleConfigurationError( + "more then one OutputSink configuration present in {}".format(d)) + output_sink = sink + output_set = True + except NonconfigurationError: + pass + return output_sink + + def __get_input_data_from_config(self, d): + """private function to check for ConfigurableSource configurations in a dictionary + and return the data if it exists + + Parameters + ---------- + d : dict + dictionary to look for InputSource configuration parameters in + + Returns + ------- + dict or None + dictionary of InputData if it found a valid configuration, None otherwise + + Raises + ------ + MultipleConfigurationError + if more than one 
InputSource is configured + """ + input_set = False + input_data = None + for source in self.input_sources: + try: + source.load_config(d) + input_data = source.get_dict() + if input_set: + raise MultipleConfigurationError( + "more then one InputSource configuration present in {}".format(d)) + input_set = True + except NonconfigurationError as e: + pass + return input_data def get_output_json(self, d): """method for getting the output_json pushed through validation @@ -210,7 +335,7 @@ def get_output_json(self, d): return output_json - def output(self, d, output_path=None, **json_dump_options): + def output(self, d, sink=None): """method for outputing dictionary to the output_json file path after validating it through the output_schema_type @@ -218,22 +343,20 @@ def output(self, d, output_path=None, **json_dump_options): ---------- d:dict output dictionary to output - output_path: str - path to save to output file, optional (with default to self.mod['output_json'] location) - **json_dump_options : - will be passed through to json.dump + sink: argschema.sources.source.ConfigurableSink + output_sink to output to (optional default to self.output_source) Raises ------ marshmallow.ValidationError If any of the output dictionary doesn't meet the output schema """ - if output_path is None: - output_path = self.args['output_json'] - output_json = self.get_output_json(d) - with open(output_path, 'w') as fp: - json.dump(output_json, fp, **json_dump_options) + output_d = self.get_output_json(d) + if sink is not None: + sink.put_dict(output_d) + else: + self.output_sink.put_dict(output_d) def load_schema_with_defaults(self, schema, args): """method for deserializing the arguments dictionary (args) @@ -258,20 +381,6 @@ def load_schema_with_defaults(self, schema, args): because these won't work with loading defaults. 
""" - is_recursive = is_recursive_schema(schema) - is_non_default = contains_non_default_schemas(schema) - if (not is_recursive) and is_non_default: - # throw a warning - self.logger.warning("""DEPRECATED:You are using a Schema which contains - a Schema which is not subclassed from argschema.DefaultSchema, - default values will not work correctly in this case, - this use is deprecated, and future versions will not fill in default - values when you use non-DefaultSchema subclasses""") - args = fill_defaults(schema, args) - if is_recursive and is_non_default: - raise mm.ValidationError( - 'Recursive schemas need to subclass argschema.DefaultSchema else defaults will not work') - # load the dictionary via the schema result = utils.load(schema, args) diff --git a/argschema/autodoc.py b/argschema/autodoc.py index ee11ab91..58db57a8 100644 --- a/argschema/autodoc.py +++ b/argschema/autodoc.py @@ -3,7 +3,10 @@ from argschema.utils import get_description_from_field from argschema.argschema_parser import ArgSchemaParser import inspect - +try: + from inspect import getfullargspec +except ImportError: + from inspect import getargspec as getfullargspec FIELD_TYPE_MAP = {v: k for k, v in mm.Schema.TYPE_MAPPING.items()} @@ -120,6 +123,7 @@ def setup(app): except Exception as e: # in case this fails for some reason, note it as unknown # TODO handle this more elegantly, identify and patch up such cases + print(e) field_line += "unknown,unknown" lines.append(field_line) # lines.append(table_line) diff --git a/argschema/deprecated.py b/argschema/deprecated.py deleted file mode 100644 index 6d8ce615..00000000 --- a/argschema/deprecated.py +++ /dev/null @@ -1,12 +0,0 @@ -from .argschema_parser import ArgSchemaParser -from .schemas import ArgSchema - - -class JsonModule(ArgSchemaParser): - """deprecated name of ArgSchemaParser""" - pass - - -class ModuleParameters(ArgSchema): - """deprecated name of ArgSchema""" - pass diff --git a/argschema/fields/__init__.py 
b/argschema/fields/__init__.py index c752233c..94857d9b 100644 --- a/argschema/fields/__init__.py +++ b/argschema/fields/__init__.py @@ -3,7 +3,6 @@ from marshmallow.fields import __all__ as __mmall__ # noQA:F401 from .files import OutputFile, InputDir, InputFile, OutputDir # noQA:F401 from .numpyarrays import NumpyArray # noQA:F401 -from .deprecated import OptionList # noQA:F401 from .loglevel import LogLevel # noQA:F401 from .slice import Slice # noQA:F401 diff --git a/argschema/fields/deprecated.py b/argschema/fields/deprecated.py deleted file mode 100644 index 717b876b..00000000 --- a/argschema/fields/deprecated.py +++ /dev/null @@ -1,34 +0,0 @@ -'''marshmallow fields related to choosing amongst a set of options''' -import marshmallow as mm -import logging -logger = logging.getLogger('argschema') - - -class OptionList(mm.fields.Field): - """OptionList is a marshmallow field which enforces that this field - is one of a finite set of options. - OptionList(options,*args,**kwargs) where options is a list of - json compatible options which this option will be enforced to belong - - Parameters - ---------- - options : list - A list of python objects of which this field must be one of - kwargs : dict - the same as any :class:`Field` receives - """ - - def __init__(self, options, **kwargs): - self.options = options - logger.warning( - 'DEPRECATED: use validate=mm.validate.OneOf([a,b,c...]) in field definition instead') - super(OptionList, self).__init__(**kwargs) - - def _serialize(self, value, attr, obj): - return value - - def _validate(self, value): - if value not in self.options: - raise mm.ValidationError("%s is not a valid option" % value) - - return value diff --git a/argschema/fields/files.py b/argschema/fields/files.py index 1172d1b8..a27d5851 100644 --- a/argschema/fields/files.py +++ b/argschema/fields/files.py @@ -87,7 +87,7 @@ def _validate(self, value): path = os.path.dirname(value) except Exception as e: # pragma: no cover raise mm.ValidationError( - "%s 
cannot be os.path.dirname-ed" % value) # pragma: no cover + "{} cannot be os.path.dirname-ed: {}".format(value, e)) # pragma: no cover validate_outpath(path) class OutputDirModeException(Exception): @@ -157,6 +157,7 @@ def validate_input_path(value): except Exception as value: raise mm.ValidationError("%s is not readable" % value) + class InputDir(mm.fields.Str): """InputDir is :class:`marshmallow.fields.Str` subclass which is a path to a a directory that exists and that the user can access diff --git a/argschema/fields/numpyarrays.py b/argschema/fields/numpyarrays.py index a369ad9e..98e38cc2 100644 --- a/argschema/fields/numpyarrays.py +++ b/argschema/fields/numpyarrays.py @@ -20,8 +20,6 @@ class NumpyArray(mm.fields.List): def __init__(self, dtype=None, *args, **kwargs): self.dtype = dtype - if "cli_as_single_argument" not in kwargs: - kwargs["cli_as_single_argument"] = True super(NumpyArray, self).__init__(mm.fields.Field, *args, **kwargs) def _deserialize(self, value, attr, obj, **kwargs): @@ -29,8 +27,8 @@ def _deserialize(self, value, attr, obj, **kwargs): return np.array(value, dtype=self.dtype) except ValueError as e: raise mm.ValidationError( - 'Cannot create numpy array with type {} from data.'.format( - self.dtype)) + 'Cannot create numpy array with type {} from data: {}.'.format( + self.dtype, e)) def _serialize(self, value, attr, obj, **kwargs): if value is None: diff --git a/argschema/fields/slice.py b/argschema/fields/slice.py index 1a5d0fa8..232622cc 100644 --- a/argschema/fields/slice.py +++ b/argschema/fields/slice.py @@ -17,7 +17,7 @@ def __init__(self, **kwargs): kwargs['metadata'] = kwargs.get( 'metadata', {'description': 'slice the dataset'}) kwargs['default'] = kwargs.get('default', slice(None)) - super(Slice, self).__init__( **kwargs) + super(Slice, self).__init__(**kwargs) def _deserialize(self, value, attr, obj, **kwargs): try: diff --git a/argschema/schemas.py b/argschema/schemas.py index 52dac07c..9ad7831f 100644 --- 
a/argschema/schemas.py +++ b/argschema/schemas.py @@ -1,5 +1,5 @@ import marshmallow as mm -from .fields import LogLevel, InputFile, OutputFile +from .fields import LogLevel class DefaultSchema(mm.Schema): @@ -34,11 +34,10 @@ class ArgSchema(DefaultSchema): input_json and output_json files and the log_level """ - input_json = InputFile( - description="file path of input json file") - - output_json = OutputFile( - description="file path to output json file") + # input_json = InputFile( + # description= "file path of input json file") + # output_json = OutputFile( + # description= "file path to output json file") log_level = LogLevel( default='ERROR', description="set the logging level of the module") diff --git a/argschema/sources/__init__.py b/argschema/sources/__init__.py new file mode 100644 index 00000000..76bfa557 --- /dev/null +++ b/argschema/sources/__init__.py @@ -0,0 +1,2 @@ +from argschema.sources.source import ConfigurableSource, ConfigurableSink +from argschema.sources.json_source import JsonSource, JsonSink diff --git a/argschema/sources/json_source.py b/argschema/sources/json_source.py new file mode 100644 index 00000000..22ec191f --- /dev/null +++ b/argschema/sources/json_source.py @@ -0,0 +1,42 @@ +from argschema.sources.source import ConfigurableSource, ConfigurableSink +import json +import marshmallow as mm +import argschema + + +class JsonInputConfigSchema(mm.Schema): + input_json = argschema.fields.InputFile(required=True, + description='filepath to input_json') + + +class JsonOutputConfigSchema(mm.Schema): + output_json = argschema.fields.OutputFile(required=True, + description='filepath to save output_json') + output_json_indent = argschema.fields.Int(required=False, + description='whether to indent options or not') + + +class JsonSource(ConfigurableSource): + """ A configurable source which reads values from a json. Expects + --input_json + to be specified. 
+ """ + + ConfigSchema = JsonInputConfigSchema + + def get_dict(self): + with open(self.config["input_json"], 'r') as fp: + return json.load(fp,) + + +class JsonSink(ConfigurableSink): + """ A configurable sink which writes values to a json. Expects + --output_json + to be specified. + """ + ConfigSchema = JsonOutputConfigSchema + + def put_dict(self, data): + with open(self.config["output_json"], 'w') as fp: + json.dump( + data, fp, indent=self.config.get("output_json_indent", None)) diff --git a/argschema/sources/source.py b/argschema/sources/source.py new file mode 100644 index 00000000..0d4e49c3 --- /dev/null +++ b/argschema/sources/source.py @@ -0,0 +1,143 @@ +import abc +from typing import Dict, Type + +import marshmallow as mm + + +class ConfigurationError(mm.ValidationError): + """Base Exception class for configurations""" + pass + + +class MisconfigurationError(ConfigurationError): + """Exception when a configuration was present in part but failed + validation""" + pass + + +class NonconfigurationError(ConfigurationError): + """Exception when a configuration is simply completely missing""" + pass + + +class MultipleConfigurationError(ConfigurationError): + """Exception when there is more than one valid configuration""" + pass + + +def d_contains_any_fields(schema: mm.Schema, data: Dict) -> bool: + """function to test if a dictionary contains any elements of a schema + + Parameters + ---------- + schema: marshmallow.Schema + a marshmallow schema to test d with + data: dict + the dictionary to test whether it contains any elements of a schema + + Returns + ------- + bool: + True/False whether d contains any elements of a schema. 
If a schema + contains no elements, returns True + """ + + if len(schema.declared_fields) == 0: + return True + + for field_name, field in schema.declared_fields.items(): + if field_name in data.keys(): + if data[field_name] is not None: + return True + + return False + + + +class Configurable(object): + """Base class for sources and sinks of marshmallow-validatable + parameters. + + Parameters + ---------- + **default_config : dict + Optionally, attempt to load a config immediately upon construction + + Attributes + ---------- + ConfigSchema : type(mm.Schema), class attribute + Defines a schema for this Configurable's config. + config : dict + Stores values loaded according to this instance's schema + schema : mm.Schema + An instance of this class's ConfigSchema. Used to validate potential + configurations. + + """ + + ConfigSchema: Type[mm.Schema] = mm.Schema + + def __init__(self, **default_config: Dict): + + self.schema: mm.Schema = self.ConfigSchema() + self.config: Dict = {} + + if default_config: + self.load_config(default_config) + + def load_config(self, candidate: Dict): + """Attempt to configure this object in place using values in a candidate + dictionary. + + Parameters + ---------- + candidate : dict + Might satisfy (and will be loaded using) this object's schema. + + Raises + ------ + NonconfigurationError : Indicates that the candidate was completely + inapplicable. + MisconfigurationError : Indicates that the candidate did not adequately + satisfy this configurable's schema. 
+ + """ + + if candidate is None: + candidate = {} + + if not d_contains_any_fields(self.schema, candidate): + raise NonconfigurationError( + "This source is not present in \n {}".format(candidate)) + + try: + self.config = self.schema.load(candidate, unknown=mm.EXCLUDE) + except mm.ValidationError as e: + raise MisconfigurationError( + "Source incorrectly configured\n {}".format(e)) + + +class ConfigurableSource(Configurable): + def get_dict(self) -> Dict: + """Produces a dictionary, potentially using information from this + source's config. + + Returns + ------- + dict : Suitable for validatation by some external marshmallow schema. + + """ + raise NotImplementedError() + + +class ConfigurableSink(Configurable): + def put_dict(self, data: Dict): + """Writes a dictionary, potentially using information from this + sink's config. + + Parameters + ---------- + dict : Will be written to some external sink. + + """ + raise NotImplementedError() diff --git a/argschema/sources/url_source.py b/argschema/sources/url_source.py new file mode 100644 index 00000000..e7ee6fd1 --- /dev/null +++ b/argschema/sources/url_source.py @@ -0,0 +1,37 @@ +from argschema.sources import ConfigurableSource +from argschema.schemas import DefaultSchema +from argschema.fields import Str,Int +from argschema import ArgSchemaParser +import requests +try: + from urllib.parse import urlunparse +except: + from urlparse import urlunparse + +class UrlSourceConfig(DefaultSchema): + input_host = Str(required=True, description="host of url") + input_port = Int(required=False, default=None, description="port of url") + input_url = Str(required=True, description="location on host of input") + input_protocol = Str(required=False, default='http', description="url protocol to use") + +class UrlSource(ConfigurableSource): + """ A configurable source which obtains values by making a GET request, + expecting a JSON response. 
+ """ + ConfigSchema = UrlSourceConfig + + def get_dict(self): + netloc = self.config["input_host"] + if self.config["input_port"] is not None: + netloc = "{}:{}".format(netloc, self.config["input_port"]) + + url = urlunparse(( + self.config["input_protocol"], + netloc, + self.config["input_url"], + None, None, None + )) + + response = requests.get(url) + response.raise_for_status() + return response.json() diff --git a/argschema/sources/yaml_source.py b/argschema/sources/yaml_source.py new file mode 100644 index 00000000..3b6a90db --- /dev/null +++ b/argschema/sources/yaml_source.py @@ -0,0 +1,38 @@ +import yaml +from argschema.sources.source import ConfigurableSource, ConfigurableSink +import argschema +import marshmallow as mm + + +class YamlInputConfigSchema(mm.Schema): + input_yaml = argschema.fields.InputFile(required=True, + description='filepath to input yaml') + + +class YamlOutputConfigSchema(mm.Schema): + output_yaml = argschema.fields.OutputFile(required=True, + description='filepath to save output yaml') + + +class YamlSource(ConfigurableSource): + """ A configurable source which reads values from a yaml. Expects + --input_yaml + to be specified. + """ + ConfigSchema = YamlInputConfigSchema + + def get_dict(self): + with open(self.config["input_yaml"], 'r') as fp: + return yaml.load(fp, Loader=yaml.FullLoader) + + +class YamlSink(ConfigurableSink): + """ A configurable sink which writes values to a yaml. Expects + --output_yaml + to be specified. 
+ """ + ConfigSchema = YamlOutputConfigSchema + + def put_dict(self, data): + with open(self.config["output_yaml"], 'w') as fp: + yaml.dump(data, fp, default_flow_style=False) diff --git a/argschema/utils.py b/argschema/utils.py index 9e9c640d..f5905300 100644 --- a/argschema/utils.py +++ b/argschema/utils.py @@ -2,7 +2,6 @@ marshmallow schemas to argparse and merging dictionaries from both systems ''' import logging -import warnings import ast import argparse from operator import add @@ -55,11 +54,7 @@ def get_type_from_field(field): callable Function to call to cast argument to """ - if (isinstance(field, fields.List) and - not field.metadata.get("cli_as_single_argument", False)): - return list - else: - return FIELD_TYPE_MAP.get(type(field), str) + return FIELD_TYPE_MAP.get(type(field), str) def cli_error_dict(arg_path, field_type, index=0): @@ -88,15 +83,44 @@ def cli_error_dict(arg_path, field_type, index=0): return {arg_path[index]: cli_error_dict(arg_path, field_type, index + 1)} -def args_to_dict(argsobj, schema=None): +def get_field_def_from_schema(parts, schema): + """function to get a field_definition from a particular key, specified by it's parts list + + Parameters + ---------- + parts : list[str] + the list of keys to get this schema + schema: marshmallow.Schema + the marshmallow schema to look up this key + + Returns + ------- + marshmallow.Field or None + returns the field in the schema if it exists, otherwise returns None + """ + current_schema = schema + for part in parts: + if part not in current_schema.fields.keys(): + return None + else: + if current_schema.only and part not in current_schema.only: + field_def = None + else: + field_def = current_schema.fields[part] + if isinstance(field_def, fields.Nested): + current_schema = field_def.schema + return field_def + + +def args_to_dict(argsobj, schemas=None): """function to convert namespace returned by argsparse into a nested dictionary Parameters ---------- argsobj : argparse.Namespace 
Namespace object returned by standard argparse.parse function - schema : marshmallow.Schema - Optional schema which will be used to cast fields via `FIELD_TYPE_MAP` + schemas : list[marshmallow.Schema] + Optional list of schemas which will be used to cast fields via `FIELD_TYPE_MAP` Returns @@ -110,18 +134,19 @@ def args_to_dict(argsobj, schema=None): errors = {} field_def = None for field in argsdict.keys(): - current_schema = schema parts = field.split('.') root = d for i in range(len(parts)): - if current_schema is not None: - if current_schema.only and parts[i] not in current_schema.only: - field_def = None - else: - field_def = current_schema.fields[parts[i]] - if isinstance(field_def, fields.Nested): - current_schema = field_def.schema + if i == (len(parts) - 1): + field_def = None + for schema in schemas: + field_def = get_field_def_from_schema(parts, schema) + if field_def is not None: + break + + # field_def = next(get_field_def(parts,schema) for schema in schemas if field_in_schema(parts,schema)) + value = argsdict.get(field) if value is not None: try: @@ -335,17 +360,6 @@ def build_schema_arguments(schema, arguments=None, path=None, description=None): if isinstance(validator, mm.validate.OneOf): arg['help'] += " (valid options are {})".format(validator.choices) - if (isinstance(field, mm.fields.List) and - not field.metadata.get("cli_as_single_argument", False)): - warn_msg = ("'{}' is using old-style command-line syntax with " - "each element as a separate argument. This will " - "not be supported in argschema after " - "2.0. 
See http://argschema.readthedocs.io/en/" - "master/user/intro.html#command-line-specification" - " for details.").format(arg_name) - warnings.warn(warn_msg, FutureWarning) - arg['nargs'] = '*' - # do type mapping after parsing so we can raise validation errors arg['type'] = str @@ -361,35 +375,42 @@ def build_schema_arguments(schema, arguments=None, path=None, description=None): return arguments -def schema_argparser(schema): +def schema_argparser(schema, additional_schemas=None): """given a jsonschema, build an argparse.ArgumentParser Parameters ---------- schema : argschema.schemas.ArgSchema schema to build an argparser from - + additional_schemas : list[marshmallow.schema] + list of additional schemas to add to the command line arguments Returns ------- argparse.ArgumentParser - the represents the schema + that represents the schemas """ - # build up a list of argument groups using recursive function - # to traverse the tree, root node gets the description given by doc string - # of the schema - arguments = build_schema_arguments(schema, description=schema.__doc__) - # make the root schema appeear first rather than last - arguments = [arguments[-1]] + arguments[0:-1] + if additional_schemas is not None: + schema_list = [schema] + additional_schemas + else: + schema_list = [schema] parser = argparse.ArgumentParser() - - for arg_group in arguments: - group = parser.add_argument_group( - arg_group['title'], arg_group['description']) - for arg_name, arg in arg_group['args'].items(): - group.add_argument(arg_name, **arg) + for s in schema_list: + # build up a list of argument groups using recursive function + # to traverse the tree, root node gets the description given by doc string + # of the schema + arguments = build_schema_arguments(s, description=schema.__doc__) + + # make the root schema appeear first rather than last + arguments = [arguments[-1]] + arguments[0:-1] + + for arg_group in arguments: + group = parser.add_argument_group( + arg_group['title'], 
arg_group['description']) + for arg_name, arg in arg_group['args'].items(): + group.add_argument(arg_name, **arg) return parser @@ -403,16 +424,10 @@ def load(schema, d): schema that you want to use to validate d: dict dictionary to validate and load - Returns ------- dict deserialized and validated dictionary - - Raises - ------ - marshmallow.ValidationError - if the dictionary does not conform to the schema """ results = schema.load(d) @@ -429,16 +444,10 @@ def dump(schema, d): schema that you want to use to validate and dump d: dict dictionary to validate and dump - Returns ------- dict serialized and validated dictionary - - Raises - ------ - marshmallow.ValidationError - if the dictionary does not conform to the schema """ errors=schema.validate(d) if len(errors)>0: diff --git a/doc_requirements.txt b/doc_requirements.txt index 8e00ab27..36ddf95f 100644 --- a/doc_requirements.txt +++ b/doc_requirements.txt @@ -2,6 +2,6 @@ sphinxcontrib-napoleon sphinxcontrib-programoutput sphinxcontrib-inlinesyntaxhighlight numpy -marshmallow==3.0.0rc6 +marshmallow==3.6.1 pytest rstcheck diff --git a/docs/api/argschema.rst b/docs/api/argschema.rst index a8660473..a17ec924 100644 --- a/docs/api/argschema.rst +++ b/docs/api/argschema.rst @@ -19,14 +19,6 @@ argschema\.argschema\_parser module :undoc-members: :show-inheritance: -argschema\.deprecated module ----------------------------- - -.. automodule:: argschema.deprecated - :members: - :undoc-members: - :show-inheritance: - argschema\.schemas module ------------------------- diff --git a/docs/tests/fields.rst b/docs/tests/fields.rst index 933bb8b2..d5c65a9d 100644 --- a/docs/tests/fields.rst +++ b/docs/tests/fields.rst @@ -4,14 +4,6 @@ fields package Submodules ---------- -fields\.test\_deprecated module -------------------------------- - -.. 
automodule:: fields.test_deprecated - :members: - :undoc-members: - :show-inheritance: - fields\.test\_files module -------------------------- diff --git a/docs/user/intro.rst b/docs/user/intro.rst index c8235309..87a342e9 100644 --- a/docs/user/intro.rst +++ b/docs/user/intro.rst @@ -1,5 +1,18 @@ User Guide ===================================== +Installation +------------ +install via source code + +:: + + $ python setup.py install + +or pip + +:: + + $ pip install argschema Your First Module ------------------ @@ -80,19 +93,19 @@ argschema uses marshmallow (http://marshmallow.readthedocs.io/) under the hood to define the parameters schemas. It comes with a basic set of fields that you can use to define your schemas. One powerful feature of Marshmallow is that you can define custom fields that do arbitrary validation. -:class:`~argschema.fields` contains all the built-in marshmallow fields, +:class:`argschema.fields` contains all the built-in marshmallow fields, but also some useful custom ones, -such as :class:`~argschema.fields.InputFile`, -:class:`~argschema.fields.OutputFile`, -:class:`~argschema.fields.InputDir` that validate that the paths exist and have the proper +such as :class:`argschema.fields.InputFile`, +:class:`argschema.fields.OutputFile`, +:class:`argschema.fields.InputDir` that validate that the paths exist and have the proper permissions to allow files to be read or written. -Other fields, such as :class:`~argschema.fields.NumpyArray` will deserialize ordered lists of lists +Other fields, such as :class:`argschema.fields.NumpyArray` will deserialize ordered lists of lists directly into a numpy array of your choosing. -Finally, an important Field to know is :class:`~argschema.fields.Nested`, which allows you to define +Finally, an important Field to know is :class:`argschema.fields.Nested`, which allows you to define heirarchical nested structures. 
Note, that if you use Nested schemas, your Nested schemas should -subclass :class:`~argschema.schemas.DefaultSchema` in order that they properly fill in default values, +subclass :class:`argschema.schemas.DefaultSchema` in order that they properly fill in default values, as :class:`marshmallow.Schema` does not do that by itself. Another common question about :class:`~argschema.fields.Nested` is how you specify that @@ -158,25 +171,6 @@ passed by the shell. If there are spaces in the value, it will need to be wrapped in quotes, and any special characters will need to be escaped with \. Booleans are set with True or 1 for true and False or 0 for false. -An exception to this rule is list formatting. If a schema contains a -:class:`~marshmallow.fields.List` and does not set the -`cli_as_single_argument` keyword argument to True, lists will be parsed -as `--list_name ...`. In argschema 2.0 lists will be -parsed in the same way as other arguments, as it allows more flexibility -in list types and more clearly represents the intended data structure. - -An example script showing old and new list settings: - -.. literalinclude:: ../../examples/deprecated_example.py - :caption: deprecated_example.py - -Running this code can demonstrate the differences in command-line usage: - -.. command-output:: python deprecated_example.py --help - :cwd: /../examples - -.. command-output:: python deprecated_example.py --list_old 9.1 8.2 7.3 --list_new [6.4,5.5,4.6] - :cwd: /../examples We can explore some typical examples of command line usage with the following script: @@ -200,13 +194,56 @@ example, having an invalid literal) we will see a casting validation error: argschema does not support setting :class:`~marshmallow.fields.Dict` at the command line. +Alternate Sources/Sinks +----------------------- +Json files are just one way that you might decide to serialize module parameters or outputs. 
+Argschema by default provides json support because that is what we use most frequently at the Allen Institute, +however we have generalized the concept to allow :class:`argschema.ArgSchemaParser` to plugin alternative +"sources" and "sinks" of dictionary inputs and outputs. + +For example, yaml is another reasonable choice for storing nested key-value stores. +:class:`argschema.argschema_parser.ArgSchemaYamlParser` demonstrates just that functionality. So now +input_yaml and output_yaml can be specified instead. + +Furthermore, you can pass an ArgSchemaParser an :class:`argschema.sources.ArgSource` object which +implements a get_dict method, and any :class:`argschema.ArgSchemaParser` will get its input parameters +from that dictionary. Importantly, this is true even when the original module author didn't +explicitly support passing parameters from that mechanism, and the parameters will still be +deserialized and validated in a uniform manner. + +Similarly you can pass an :class:`argschema.sources.ArgSink` object which implements a put_dict method, +and :class:`argschema.ArgSchemaParser.output` will output the dictionary however that +:class:`argschema.sources.ArgSink` specifies it should. + +Finally, both :class:`argschema.sources.ArgSource` and :class:`argschema.sources.ArgSink` +have a property called ConfigSchema, which is a :class:`marshmallow.Schema` for how to deserialize +the kwargs to it's init class. + +For example, the default :class:`argschema.sources.json_source.JsonSource` has one string +field of 'input_json'. This is how :class:`argschema.ArgSchemaParser` is told what keys and values +should be read to initialize a :class:`argschema.sources.ArgSource` or + :class:`argschema.sources.ArgSink` instance. 
+ +So for example, if you wanted to define a :class:`argschema.sources.ArgSource` which loaded a dictionary +from a particular host, port and url, and a module which had a command line interface for setting that +host port and url you could do so like this. + +.. literalinclude:: ../../test/sources/url_source.py + +so now a UrlArgSchemaParser would expect command line flags of '--input_host' and '--input_url', and +optionally '--input_port','--input_protocol' (or look for them in input_data) and will look to download +the json from that http location via requests. In addition, an existing :class:`argschema.ArgSchemaParser` +module could be simply passed a configured UrlSource via input_source, +and it would get its parameters from there. + Sphinx Documentation -------------------- argschema comes with a autodocumentation feature for Sphnix which will help you automatically -add documentation of your Schemas and ArgSchemaParser classes in your project. This is how the -documentation of the :doc:`../tests/modules` suite included here was generated. +add documentation of your Schemas and :class:`argschema.ArgSchemaParser` classes in your project. +This is how the documentation of the :doc:`../tests/modules` suite included here was generated. -To configure sphinx to use this function, you must be using the sphnix autodoc module and add the following to your conf.py file +To configure sphnix to use this function, you must be using the sphnix autodoc module +and add the following to your conf.py file .. code-block:: python @@ -215,19 +252,7 @@ To configure sphinx to use this function, you must be using the sphnix autodoc m def setup(app): app.connect('autodoc-process-docstring',process_schemas) -Installation ------------- -install via source code - -:: - - $ python setup.py install - -or pip -:: - - $ pip install argschema .. 
toctree:: diff --git a/examples/cli_example.py b/examples/cli_example.py index 5a71323e..b7cb6968 100644 --- a/examples/cli_example.py +++ b/examples/cli_example.py @@ -13,10 +13,8 @@ class MySchema(ArgSchema): description="my example array") string_list = List(List(Str), default=[["hello", "world"], ["lists!"]], - cli_as_single_argument=True, description="list of lists of strings") int_list = List(Int, default=[1, 2, 3], - cli_as_single_argument=True, description="list of ints") nested = Nested(MyNestedSchema, required=True) diff --git a/examples/deprecated_example.py b/examples/deprecated_example.py deleted file mode 100644 index bea1e12c..00000000 --- a/examples/deprecated_example.py +++ /dev/null @@ -1,15 +0,0 @@ -from argschema import ArgSchema, ArgSchemaParser -from argschema.fields import List, Float - - -class MySchema(ArgSchema): - list_old = List(Float, default=[1.1, 2.2, 3.3], - description="float list with deprecated cli") - list_new = List(Float, default=[4.4, 5.5, 6.6], - cli_as_single_argument=True, - description="float list with supported cli") - - -if __name__ == '__main__': - mod = ArgSchemaParser(schema_type=MySchema) - print(mod.args) diff --git a/examples/multisource_example.json b/examples/multisource_example.json new file mode 100644 index 00000000..49ad4497 --- /dev/null +++ b/examples/multisource_example.json @@ -0,0 +1,6 @@ +{ + "a_subschema": { + "an_int": 12 + }, + "a_float": 15.5 +} \ No newline at end of file diff --git a/examples/multisource_example.py b/examples/multisource_example.py new file mode 100644 index 00000000..7ac062cb --- /dev/null +++ b/examples/multisource_example.py @@ -0,0 +1,48 @@ +"""This example shows you how to register multiple input sources for your executable, which users can then select from dynamically when running it. This feature makes your code a bit more flexible about the format of the input parameters. + +There is a similar feature (not shown here) for specifying output sinks. 
It follows the same pattern. + +Usage +----- +# you can load parameters from a yaml ... +$ python examples/multisource_example.py --input_yaml examples/multisource_example.yaml +{'a_subschema': {'an_int': 13}, 'log_level': 'ERROR', 'a_float': 16.7} + +# ... or from an input json ... +$ python examples/multisource_example.py --input_json examples/multisource_example.json +{'a_float': 15.5, 'a_subschema': {'an_int': 12}, 'log_level': 'ERROR'} + +# ... but not both +$ python examples/multisource_example.py --input_json examples/multisource_example.json --input_yaml examples/multisource_example.yaml +argschema.sources.source.MultipleConfigurationError: more then one InputSource configuration present in {'input_json': 'examples/multisource_example.json', 'input_yaml': 'examples/multisource_example.yaml'} + +# command line parameters still override sourced ones +$ python examples/multisource_example.py --input_json examples/multisource_example.json --a_float 13.1 +{'a_float': 13.1, 'a_subschema': {'an_int': 12}, 'log_level': 'ERROR'} + +""" + +import argschema + +class SubSchema(argschema.schemas.DefaultSchema): + an_int = argschema.fields.Int() + +class MySchema(argschema.ArgSchema): + a_subschema = argschema.fields.Nested(SubSchema) + a_float = argschema.fields.Float() + + +def main(): + + parser = argschema.ArgSchemaParser( + schema_type=MySchema, + input_sources=[ # each source provided here will be checked against command-line arguments + argschema.sources.json_source.JsonSource, # ArgschemaParser includes this source by default + argschema.sources.yaml_source.YamlSource + ] + ) + + print(parser.args) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/multisource_example.yaml b/examples/multisource_example.yaml new file mode 100644 index 00000000..7323832d --- /dev/null +++ b/examples/multisource_example.yaml @@ -0,0 +1,3 @@ +a_subschema : + an_int: 13 +a_float: 16.7 \ No newline at end of file diff --git 
a/test/fields/test_deprecated.py b/test/fields/test_deprecated.py deleted file mode 100644 index a678074b..00000000 --- a/test/fields/test_deprecated.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -from argschema import ArgSchemaParser, ArgSchema -from argschema.fields import OptionList -import marshmallow as mm - - -class OptionSchema(ArgSchema): - a = OptionList([1, 2, 3], required=True, description='one of 1,2,3') - - -def test_option_list(): - input_data = { - 'a': 1 - } - ArgSchemaParser( - input_data=input_data, schema_type=OptionSchema, args=[]) - - -def test_bad_option(): - input_data = { - 'a': 4 - } - with pytest.raises(mm.ValidationError): - ArgSchemaParser( - input_data=input_data, schema_type=OptionSchema, args=[]) diff --git a/test/sources/test_json.py b/test/sources/test_json.py new file mode 100644 index 00000000..08bbb8d0 --- /dev/null +++ b/test/sources/test_json.py @@ -0,0 +1,27 @@ +import json + +import pytest + +from argschema.sources import json_source + + +def test_json_source_get_dict(tmpdir_factory): + path = str(tmpdir_factory.mktemp("test_json_source").join("inp.json")) + + with open(path, "w") as jf: + json.dump({"a": 12}, jf) + + source = json_source.JsonSource() + source.load_config({"input_json": path}) + + assert source.get_dict()["a"] == 12 + +def test_json_sink_put_dict(tmpdir_factory): + path = str(tmpdir_factory.mktemp("test_json_source").join("out.json")) + + sink = json_source.JsonSink() + sink.load_config({"output_json": path}) + sink.put_dict({"a": 13}) + + with open(path, "r") as jf: + assert json.load(jf)["a"] == 13 \ No newline at end of file diff --git a/test/sources/test_parser_integration.py b/test/sources/test_parser_integration.py new file mode 100644 index 00000000..c607d28e --- /dev/null +++ b/test/sources/test_parser_integration.py @@ -0,0 +1,125 @@ +import json +import os + +import pytest +import yaml + +import argschema +from argschema.sources.json_source import JsonSource, JsonSink +from 
argschema.sources.yaml_source import YamlSource, YamlSink +from argschema.sources.source import MultipleConfigurationError + + +class MyNestedSchema(argschema.schemas.DefaultSchema): + one = argschema.fields.Int(required=True,description="nested integer") + two = argschema.fields.Boolean(required=True,description="a nested boolean") + +class MySchema(argschema.ArgSchema): + a = argschema.fields.Int(required=True,description="parameter a") + b = argschema.fields.Str(required=False,default="my value",description="optional b string parameter") + nest = argschema.fields.Nested(MyNestedSchema,description="a nested schema") + +class MyOutputSchema(argschema.schemas.DefaultSchema): + a = argschema.fields.Int(required=True,description="parameter a") + b = argschema.fields.Str(required=False,default="my value",description="optional b string parameter") + +class MyParser(argschema.ArgSchemaParser): + default_schema = MySchema + +@pytest.fixture(scope='module') +def json_inp(tmpdir_factory): + file_in = tmpdir_factory.mktemp('test').join('test_input_json.json') + input_data = { + 'a':5, + 'nest':{ + 'one':7, + 'two':False + } + } + + with open(str(file_in),'w') as fp: + json.dump(input_data, fp) + + return str(file_in) + +@pytest.fixture(scope='module') +def yaml_inp(tmpdir_factory): + file_in = tmpdir_factory.mktemp('test').join('test_input_yaml.yaml') + input_data = { + 'a':6, + 'nest':{ + 'one':8, + 'two':False + } + } + + with open(str(file_in),'w') as fp: + yaml.dump(input_data, fp) + + return str(file_in) + + +@pytest.mark.parametrize("inp_sources", [ + JsonSource(), [JsonSource()], JsonSource, [JsonSource] +]) +def test_json_input_args(json_inp, inp_sources): + parser = MyParser( + input_sources=inp_sources, + args=["--input_json", + json_inp] + ) + + assert parser.args["a"] == 5 + +@pytest.mark.parametrize("inp_sources", [ + JsonSource(), [JsonSource()], JsonSource, [JsonSource] +]) +def test_json_input_data(json_inp, inp_sources): + parser = MyParser( + 
input_sources=inp_sources, + input_data={"input_json":json_inp}, + args=[] + ) + + assert parser.args["a"] == 5 + +def test_multisource_arg(yaml_inp): + parser = MyParser( + input_sources=[JsonSource, YamlSource], + args=["--input_yaml", yaml_inp] + ) + assert parser.args["a"] == 6 + +def test_multisource_arg_conflict(json_inp, yaml_inp): + with pytest.raises(MultipleConfigurationError): + parser = MyParser( + input_sources=[JsonSource, YamlSource], + args=["--input_yaml", yaml_inp, "--input_json", json_inp] + ) + +def test_multisink(yaml_inp): + out_path = os.path.join(os.path.dirname(yaml_inp), "out.json") + + parser = MyParser( + output_schema_type=MyOutputSchema, + input_sources=YamlSource, + output_sinks=[YamlSink, JsonSink], + args=["--input_yaml", yaml_inp, "--output_json", out_path] + ) + + parser.output({"a": 12, "b": "16"}) + with open(out_path, "r") as out_file: + obt = json.load(out_file) + assert obt["a"] == 12 + +def test_multisink_conflicting(yaml_inp, json_inp): + yaml_out = os.path.join(os.path.dirname(yaml_inp), "out.yaml") + json_out = os.path.join(os.path.dirname(json_inp), "out.json") + + with pytest.raises(MultipleConfigurationError): + parser = MyParser( + output_schema_type=MyOutputSchema, + input_sources=[YamlSource], + output_sinks=[JsonSink, YamlSink], + args=["--output_yaml", yaml_out, "--output_json", json_out] + ) diff --git a/test/sources/test_url.py b/test/sources/test_url.py new file mode 100644 index 00000000..5215476e --- /dev/null +++ b/test/sources/test_url.py @@ -0,0 +1,41 @@ +import requests +import mock +from argschema.sources.url_source import UrlSource +from argschema import ArgSchemaParser + + +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def raise_for_status(self): + if self.status_code >= 400: + raise requests.exceptions.HTTPError() + + def json(self): + return self.json_data + + if args[0] 
== 'http://localhost:88/test.json': + return MockResponse({ + 'a': 7, + 'nest': { + 'one': 7, + 'two': False + } + }, 200) + return MockResponse(None, 404) + + +@mock.patch('requests.get', side_effect=mocked_requests_get) +def test_url_parser_get_dict(mock_get): + source = UrlSource() + source.load_config({ + "input_host": "localhost", + "input_port": 88, + "input_url": "test.json", + }) + + obtained = source.get_dict() + assert obtained["a"] == 7 \ No newline at end of file diff --git a/test/sources/test_yaml.py b/test/sources/test_yaml.py new file mode 100644 index 00000000..bad20703 --- /dev/null +++ b/test/sources/test_yaml.py @@ -0,0 +1,25 @@ +import pytest +import yaml + +from argschema.sources import yaml_source + +def test_json_source_get_dict(tmpdir_factory): + path = str(tmpdir_factory.mktemp("test_yaml_source").join("inp.yaml")) + + with open(path, "w") as jf: + yaml.dump({"a": 12}, jf) + + source = yaml_source.YamlSource() + source.load_config({"input_yaml": path}) + + assert source.get_dict()["a"] == 12 + +def test_json_sink_put_dict(tmpdir_factory): + path = str(tmpdir_factory.mktemp("test_yaml_source").join("out.yaml")) + + sink = yaml_source.YamlSink() + sink.load_config({"output_yaml": path}) + sink.put_dict({"a": 13}) + + with open(path, "r") as jf: + assert yaml.load(jf)["a"] == 13 \ No newline at end of file diff --git a/test/test_argschema_parser.py b/test/test_argschema_parser.py index c63a85b2..8a9e95d5 100644 --- a/test/test_argschema_parser.py +++ b/test/test_argschema_parser.py @@ -85,11 +85,12 @@ def test_parser_output(tmpdir_factory): 'nest': { 'one': 7, 'two': False - } + }, + 'output_json': str(json_path), + 'output_json_indent': 2 } mod = MyParser(input_data=input_data, args=[]) - - mod.output(mod.args, output_path=str(json_path), indent=2) + mod.output(mod.args) with open(str(json_path), 'r') as jf: obt = json.load(jf) assert(obt['nest']['one'] == mod.args['nest']['one']) diff --git a/test/test_cli_overrides.py 
b/test/test_cli_overrides.py index 146fd287..8f380b9e 100644 --- a/test/test_cli_overrides.py +++ b/test/test_cli_overrides.py @@ -65,14 +65,6 @@ def test_data(inputdir, inputfile, outputdir, outputfile): return data -@pytest.fixture -def deprecated_data(): - data = { - "list_deprecated": [300, 200, 800, 1000], - } - return data - - class MyNestedSchema(DefaultSchema): a = fields.Int(required=True) b = fields.Boolean(required=True) @@ -104,10 +96,6 @@ class MySchema(ArgSchema): uuid = fields.UUID(required=True) -class MyDeprecatedSchema(ArgSchema): - list_deprecated = fields.List(fields.Int, required=True) - - def test_unexpected_input(test_data): with pytest.raises(SystemExit): ArgSchemaParser(test_data, schema_type=MySchema, @@ -222,15 +210,28 @@ def test_override_list(test_data): args=["--list", "invalid"]) -def test_override_list_deprecated(deprecated_data): - with pytest.warns(FutureWarning): - mod = ArgSchemaParser(deprecated_data, schema_type=MyDeprecatedSchema, - args=["--list_deprecated", "1000", "3000"]) - assert(mod.args["list_deprecated"] == [1000, 3000]) - with pytest.raises(mm.ValidationError): - mod = ArgSchemaParser(deprecated_data, - schema_type=MyDeprecatedSchema, - args=["--list_deprecated", "[1000,3000]"]) +# @pytest.fixture +# def deprecated_data(): +# data = { +# "list_deprecated": [300, 200, 800, 1000], +# } +# return data +# +# +# class MyDeprecatedSchema(ArgSchema): +# list_deprecated = fields.List(fields.Int, required=True) +# +# +# def test_override_list_deprecated(deprecated_data): +# with pytest.warns(FutureWarning): +# mod = ArgSchemaParser(input_data=deprecated_data, +# schema_type=MyDeprecatedSchema, +# args=["--list_deprecated", "1000", "3000"]) +# assert(mod.args["list_deprecated"] == [1000, 3000]) +# with pytest.raises(mm.ValidationError): +# mod = ArgSchemaParser(deprecated_data, +# schema_type=MyDeprecatedSchema, +# args=["--list_deprecated", "[1000,3000]"]) # def test_override_localdatetime(test_data): diff --git 
a/test/test_first_test.py b/test/test_first_test.py index 561be6c5..9fd178f7 100644 --- a/test/test_first_test.py +++ b/test/test_first_test.py @@ -16,15 +16,8 @@ def test_bad_path(): ArgSchemaParser(input_data=example, args=[]) -def test_simple_example(tmpdir): - file_in = tmpdir.join('test_input_json.json') - file_in.write('nonesense') - - file_out = tmpdir.join('test_output.json') - +def test_simple_example(): example = { - "input_json": str(file_in), - "output_json": str(file_out), "log_level": "CRITICAL"} jm = ArgSchemaParser(input_data=example, args=[]) @@ -126,7 +119,7 @@ def test_simple_extension_write_overwrite(simple_extension_file): def test_simple_extension_write_overwrite_list(simple_extension_file): args = ['--input_json', str(simple_extension_file), - '--test.d', '6', '7', '8', '9'] + '--test.d', "[6,7,8,9]"] mod = ArgSchemaParser(schema_type=SimpleExtension, args=args) assert len(mod.args['test']['d']) == 4 @@ -139,27 +132,6 @@ def test_bad_input_json_argparse(): # TESTS DEMONSTRATING BAD BEHAVIOR OF DEFAULT LOADING -class MyExtensionOld(mm.Schema): - a = mm.fields.Str(description='a string') - b = mm.fields.Int(description='an integer') - c = mm.fields.Int(description='an integer', default=10) - d = mm.fields.List(mm.fields.Int, - description='a list of integers') - - -class SimpleExtensionOld(ArgSchema): - test = mm.fields.Nested(MyExtensionOld, default=None, required=True) - - -def test_simple_extension_old_pass(): - mod = ArgSchemaParser( - input_data=SimpleExtension_example_valid, - schema_type=SimpleExtensionOld, args=[]) - assert mod.args['test']['a'] == 'hello' - assert mod.args['test']['b'] == 1 - assert mod.args['test']['c'] == 10 - assert len(mod.args['test']['d']) == 3 - class RecursiveSchema(argschema.schemas.DefaultSchema): children = mm.fields.Nested("self", many=True, diff --git a/test/test_output.py b/test/test_output.py index 0678f3a7..4f24de73 100644 --- a/test/test_output.py +++ b/test/test_output.py @@ -1,6 +1,7 @@ from 
argschema import ArgSchemaParser from argschema.schemas import DefaultSchema from argschema.fields import Str, Int, NumpyArray +from argschema.sources import JsonSink import json import numpy as np import pytest @@ -96,7 +97,8 @@ def test_alt_output(tmpdir): "b": 5, "M": M } - mod.output(output, str(file_out_2)) + sink = JsonSink(output_json=str(file_out_2)) + mod.output(output, sink=sink) with open(str(file_out_2), 'r') as fp: actual_output = json.load(fp) assert actual_output == expected_output diff --git a/test/test_utils.py b/test/test_utils.py index 5f440a08..134603ca 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -133,6 +133,7 @@ def test_schema_argparser_with_baseball(): parser = utils.schema_argparser(schema) help = parser.format_help() help = help.replace('\n', '').replace(' ', '') + print(help) assert( '--strikesSTRIKEShowmanystrikes(0-2)(REQUIRED)(validoptionsare[0,1,2])' in help) # in python3.9, the format changed slightly such that