diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index abceac30..5f20c3fa 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -554,7 +554,7 @@ def write_record(self, record): def write_records(self, records): self._ensure_validity() - records = list(records) + records = [] if records is NotImplemented else list(records) write_record = self._write_record for record in records: write_record(record) diff --git a/splunklib/searchcommands/reporting_command.py b/splunklib/searchcommands/reporting_command.py index 5df3dc7e..e455a159 100644 --- a/splunklib/searchcommands/reporting_command.py +++ b/splunklib/searchcommands/reporting_command.py @@ -77,21 +77,26 @@ def map(self, records): """ return NotImplemented - def prepare(self): - - phase = self.phase + def _has_custom_method(self, method_name): + method = getattr(self.__class__, method_name, None) + base_method = getattr(ReportingCommand, method_name, None) + return callable(method) and (method is not base_method) - if phase == 'map': - # noinspection PyUnresolvedReferences - self._configuration = self.map.ConfigurationSettings(self) + def prepare(self): + if self.phase == 'map': + if self._has_custom_method('map'): + phase_method = getattr(self.__class__, 'map') + self._configuration = phase_method.ConfigurationSettings(self) + else: + self._configuration = self.ConfigurationSettings(self) return - if phase == 'reduce': + if self.phase == 'reduce': streaming_preop = chain((self.name, 'phase="map"', str(self._options)), self.fieldnames) self._configuration.streaming_preop = ' '.join(streaming_preop) return - raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(phase))}') + raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(self.phase))}') def reduce(self, records): """ Override this method to produce a reporting data structure. diff --git a/tests/searchcommands/test_reporting_command.py b/tests/searchcommands/test_reporting_command.py index 2111447d..dbda9cd8 100644 --- a/tests/searchcommands/test_reporting_command.py +++ b/tests/searchcommands/test_reporting_command.py @@ -32,3 +32,42 @@ def reduce(self, records): data = list(data_chunk.data) assert len(data) == 1 assert int(data[0]['sum']) == sum(range(0, 10)) + + +def test_simple_reporting_command_with_map(): + @searchcommands.Configuration() + class MapAndReduceReportingCommand(searchcommands.ReportingCommand): + def map(self, records): + for record in records: + record["value"] = str(int(record["value"]) * 2) + yield record + + def reduce(self, records): + total = 0 + for record in records: + total += int(record["value"]) + yield {"sum": total} + + cmd = MapAndReduceReportingCommand() + ifile = io.BytesIO() + + input_data = [{"value": str(i)} for i in range(5)] + + mapped_data = list(cmd.map(input_data)) + + ifile.write(chunky.build_getinfo_chunk()) + ifile.write(chunky.build_data_chunk(mapped_data)) + ifile.seek(0) + + ofile = io.BytesIO() + cmd._process_protocol_v2([], ifile, ofile) + + ofile.seek(0) + chunk_stream = chunky.ChunkedDataStream(ofile) + chunk_stream.read_chunk() + data_chunk = chunk_stream.read_chunk() + assert data_chunk.meta['finished'] is True + + result = list(data_chunk.data) + expected_sum = sum(i * 2 for i in range(5)) + assert int(result[0]["sum"]) == expected_sum