diff --git a/dbldatagen/__init__.py b/dbldatagen/__init__.py index 3c41a5d7..69f17493 100644 --- a/dbldatagen/__init__.py +++ b/dbldatagen/__init__.py @@ -42,12 +42,13 @@ from .spark_singleton import SparkSingleton from .text_generators import TemplateGenerator, ILText, TextGenerator from .text_generator_plugins import PyfuncText, PyfuncTextFactory, FakerTextFactory, fakerText +from .text_generatestring import GenerateString from .html_utils import HtmlUtils __all__ = ["data_generator", "data_analyzer", "schema_parser", "daterange", "nrange", "column_generation_spec", "utils", "function_builder", "spark_singleton", "text_generators", "datarange", "datagen_constants", - "text_generator_plugins", "html_utils" + "text_generator_plugins", "html_utils", "text_generatestring" ] diff --git a/dbldatagen/column_generation_spec.py b/dbldatagen/column_generation_spec.py index db535ce3..7cc2c819 100644 --- a/dbldatagen/column_generation_spec.py +++ b/dbldatagen/column_generation_spec.py @@ -1142,7 +1142,7 @@ def _applyPrefixSuffixExpressions(self, cprefix, csuffix, new_def): new_def = concat(new_def.astype(IntegerType()), lit(text_separator), lit(csuffix)) return new_def - def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations): + def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations=True): """Apply text generation expression to column expression :param new_def : column definition being created @@ -1153,6 +1153,9 @@ def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations): # while it seems like this could use a shared instance, this does not work if initialized # in a class method tg = self.textGenerator + + new_def = tg.prepareBaseValue(new_def) + if use_pandas_optimizations: self.executionHistory.append(f".. 
text generation via pandas scalar udf `{tg}`") u_value_from_generator = pandas_udf(tg.pandasGenerateText, diff --git a/dbldatagen/data_analyzer.py b/dbldatagen/data_analyzer.py index 5aec5245..ecacd95c 100644 --- a/dbldatagen/data_analyzer.py +++ b/dbldatagen/data_analyzer.py @@ -14,6 +14,7 @@ import pyspark.sql.functions as F from .utils import strip_margins +from .html_utils import HtmlUtils from .spark_singleton import SparkSingleton @@ -359,7 +360,7 @@ def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppr return "\n".join(stmts) @classmethod - def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None): + def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None, asHtml=False): """ Generate outline data generator code from an existing dataframe @@ -373,16 +374,24 @@ def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None): The dataframe to be analyzed is the dataframe passed to the constructor of the DataAnalyzer object. :param schema: Pyspark schema - i.e manually constructed StructType or return value from `dataframe.schema` - :param suppressOutput: Suppress printing of generated code if True + :param suppressOutput: Suppress printing of generated code if True. If asHtml is True, output is suppressed :param name: Optional name for data generator - :return: String containing skeleton code + :param asHtml: If True, will generate Html suitable for notebook ``displayHtml``. 
+ :return: String containing skeleton code (in Html form if `asHtml` is True) """ - return cls._scriptDataGeneratorCode(schema, - suppressOutput=suppressOutput, - name=name) + omit_output_printing = suppressOutput or asHtml + + generated_code = cls._scriptDataGeneratorCode(schema, + suppressOutput=omit_output_printing, + name=name) + + if asHtml: + generated_code = HtmlUtils.formatCodeAsHtml(generated_code) + + return generated_code - def scriptDataGeneratorFromData(self, suppressOutput=False, name=None): + def scriptDataGeneratorFromData(self, suppressOutput=False, name=None, asHtml=False): """ Generate outline data generator code from an existing dataframe @@ -395,14 +404,17 @@ def scriptDataGeneratorFromData(self, suppressOutput=False, name=None): The dataframe to be analyzed is the Spark dataframe passed to the constructor of the DataAnalyzer object - :param suppressOutput: Suppress printing of generated code if True + :param suppressOutput: Suppress printing of generated code if True. If asHtml is True, output is suppressed :param name: Optional name for data generator - :return: String containing skeleton code + :param asHtml: If True, will generate Html suitable for notebook ``displayHtml``. 
+ :return: String containing skeleton code (in Html form if `asHtml` is True) """ assert self._df is not None assert type(self._df) is ssql.DataFrame, "sourceDf must be a valid Pyspark dataframe" + omit_output_printing = suppressOutput or asHtml + if self._dataSummary is None: df_summary = self.summarizeToDF() @@ -411,8 +423,13 @@ def scriptDataGeneratorFromData(self, suppressOutput=False, name=None): row_key_pairs = row.asDict() self._dataSummary[row['measure_']] = row_key_pairs - return self._scriptDataGeneratorCode(self._df.schema, - suppressOutput=suppressOutput, - name=name, - dataSummary=self._dataSummary, - sourceDf=self._df) + generated_code = self._scriptDataGeneratorCode(self._df.schema, + suppressOutput=omit_output_printing, + name=name, + dataSummary=self._dataSummary, + sourceDf=self._df) + + if asHtml: + generated_code = HtmlUtils.formatCodeAsHtml(generated_code) + + return generated_code diff --git a/dbldatagen/text_generatestring.py b/dbldatagen/text_generatestring.py new file mode 100644 index 00000000..fdb561bd --- /dev/null +++ b/dbldatagen/text_generatestring.py @@ -0,0 +1,181 @@ +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This file defines the RandomStr text generator +""" + +import math +import random + +import numpy as np +import pandas as pd + +import pyspark.sql.functions as F + +from .text_generators import TextGenerator +from .text_generators import _DIGITS_ZERO, _LETTERS_UPPER, _LETTERS_LOWER, _LETTERS_ALL + + +class GenerateString(TextGenerator): # lgtm [py/missing-equals] + """This class handles the generation of string text of specified length drawn from alphanumeric characters. + + The set of chars to be used can be modified based on the parameters + + This will generate deterministic strings chosen from the pool of characters `0-9`, `a-z`, `A-Z`, or from a + custom character range if specified. + + :param length: length of string. 
Can be integer, or tuple (min, max) + :param leadingAlpha: If True, leading character will be in range a-zA-Z + :param allUpper: If True, any alpha chars will be uppercase + :param allLower: If True, any alpha chars will be lowercase + :param allAlpha: If True, all chars will be non-numeric + :param customChars: If supplied, specifies a list of chars to use, or string of chars to use. + + This method will generate deterministic strings varying in size from `minLength` to `maxLength`. + The characters chosen will be in the range `0-9`, `a-z`, `A-Z` unless modified using the `leadingAlpha`, + `allUpper`, `allLower`, `allAlpha` or `customChars` parameters. + + The modifiers can be combined - for example GenerateString(1, 5, leadingAlpha=True, allUpper=True) + + When the length is specified to be a tuple, it will generate variable length strings of lengths from the lower bound + to the upper bound inclusive. + + The strings are generated deterministically so that they can be used for predictable primary and foreign keys. + + If the column definition that includes this specifies `random` then the string generation will be determined by a + seeded random number according to the rules for random numbers and random seeds used in other columns + + If random is false, then the string will be generated from a pseudo random sequence generated purely from the + SQL hash of the `baseColumns` + + .. note:: + If customChars are specified, then the flag `allAlpha` will only remove digits. 
+ + """ + + def __init__(self, length, leadingAlpha=True, allUpper=False, allLower=False, allAlpha=False, customChars=None): + super().__init__() + + assert not customChars or isinstance(customChars, (list, str)), \ + "`customChars` should be list of characters or string containing custom chars" + + assert not allUpper or not allLower, "allUpper and allLower cannot both be True" + + if isinstance(customChars, str): + assert len(customChars) > 0, "string of customChars must be non-empty" + elif isinstance(customChars, list): + assert all(isinstance(c, str) for c in customChars) + assert len(customChars) > 0, "list of customChars must be non-empty" + + self.leadingAlpha = leadingAlpha + self.allUpper = allUpper + self.allLower = allLower + self.allAlpha = allAlpha + + # determine base alphabet + if isinstance(customChars, list): + charAlphabet = set("".join(customChars)) + elif isinstance(customChars, str): + charAlphabet = set(customChars) + else: + charAlphabet = set(_LETTERS_ALL).union(set(_DIGITS_ZERO)) + + if allLower: + charAlphabet = charAlphabet.difference(set(_LETTERS_UPPER)) + elif allUpper: + charAlphabet = charAlphabet.difference(set(_LETTERS_LOWER)) + + if allAlpha: + charAlphabet = charAlphabet.difference(set(_DIGITS_ZERO)) + + self._charAlphabet = np.array(list(charAlphabet)) + + if leadingAlpha: + self._firstCharAlphabet = np.array(list(charAlphabet.difference(set(_DIGITS_ZERO)))) + else: + self._firstCharAlphabet = self._charAlphabet + + # compute string lengths + if isinstance(length, int): + self._minLength = length + self._maxLength = length + elif isinstance(length, tuple): + assert len(length) == 2, "only 2 elements can be specified if length is a tuple" + assert all(isinstance(el, int) for el in length) + self._minLength, self._maxLength = length + else: + raise ValueError("`length` must be an integer or a tuple of two integers") + + # compute bounds for generated strings + bounds = [len(self._firstCharAlphabet)] + for ix in range(1, 
self._maxLength): + bounds.append(len(self._charAlphabet)) + + self._bounds = bounds + + def __repr__(self): + return f"GenerateString(length={(self._minLength, self._maxLength)}, leadingAlpha={self.leadingAlpha})" + + def make_variable_length_mask(self, v, lengths): + """ given 2-d array of dimensions[r, c] and lengths of dimensions[r] + + generate mask for each row where col_index[r,c] < lengths[r] + """ + print(v.shape, lengths.shape) + assert v.shape[0] == lengths.shape[0], "values and lengths must agree on dimension 0]" + _, c_ix = np.indices(v.shape) + + return (c_ix.T < lengths.T).T + + def mk_bounds(self, v, minLength, maxLength): + rng = np.random.default_rng(42) + v_bounds = np.full(v.shape[0], (maxLength - minLength) + 1) + return rng.integers(v_bounds) + minLength + + def prepareBaseValue(self, baseDef): + """ Prepare the base value for processing + + :param baseDef: base value expression + :return: modified base value expression + + For generate string processing, we'll use the SQL function abs(hash(baseDef)) + + This will ensure that even if there are multiple base values, only a single value is passed to the UDF + """ + return F.abs(F.hash(baseDef)) + + def pandasGenerateText(self, v): + """ entry point to use for pandas udfs + + Implementation uses vectorized implementation of process + + :param v: Pandas series of values passed as base values + :return: Pandas series of expanded templates + + """ + # placeholders is numpy array used to hold results + + rnds = np.full((v.shape[0], self._maxLength), len(self._charAlphabet), dtype=np.object_) + + rng = self.getNPRandomGenerator() + rnds2 = rng.integers(rnds) + + placeholders = np.full((v.shape[0], self._maxLength), '', dtype=np.object_) + + lengths = v.to_numpy() % (self._maxLength - self._minLength) + self._minLength + + v1 = np.full((v.shape[0], self._maxLength), -1) + + placeholder_mask = self.make_variable_length_mask(placeholders, lengths) + masked_placeholders = np.ma.MaskedArray(placeholders, 
mask=placeholder_mask) + + masked_placeholders[~placeholder_mask] = self._charAlphabet[rnds2[~placeholder_mask]] + + output = pd.Series(list(placeholders)) + + # join strings in placeholders + results = output.apply(lambda placeholder_items: "".join([str(elem) for elem in placeholder_items])) + + return results diff --git a/dbldatagen/text_generators.py b/dbldatagen/text_generators.py index 965350be..403e06d8 100644 --- a/dbldatagen/text_generators.py +++ b/dbldatagen/text_generators.py @@ -161,6 +161,16 @@ def getAsTupleOrElse(v, defaultValue, valueName): return defaultValue + def prepareBaseValue(self, baseDef): + """ Prepare the base value for processing + + :param baseDef: base value expression + :return: base value expression unchanged + + Derived classes are expected to override this if needed + """ + return baseDef + class TemplateGenerator(TextGenerator): # lgtm [py/missing-equals] """This class handles the generation of text from templates diff --git a/docs/utils/mk_quick_index.py b/docs/utils/mk_quick_index.py index c3d08953..524aaba9 100644 --- a/docs/utils/mk_quick_index.py +++ b/docs/utils/mk_quick_index.py @@ -33,6 +33,8 @@ "grouping": "main classes"}, "text_generator_plugins.py": {"briefDesc": "Text data generation", "grouping": "main classes"}, + "text_generatestring.py": {"briefDesc": "Text data generation", + "grouping": "main classes"}, "data_analyzer.py": {"briefDesc": "Analysis of existing data", "grouping": "main classes"}, "function_builder.py": {"briefDesc": "Internal utilities to create functions related to weights", diff --git a/tests/test_generation_from_data.py b/tests/test_generation_from_data.py index fab15809..25c36aa1 100644 --- a/tests/test_generation_from_data.py +++ b/tests/test_generation_from_data.py @@ -71,6 +71,9 @@ def test_code_generation1(self, generation_spec, setupLogging): ast_tree = ast.parse(generatedCode) assert ast_tree is not None + generatedCode2 = analyzer.scriptDataGeneratorFromData(asHtml=True) + assert 
generatedCode in generatedCode2 + def test_code_generation_from_schema(self, generation_spec, setupLogging): df_source_data = generation_spec.build() generatedCode = dg.DataAnalyzer.scriptDataGeneratorFromSchema(df_source_data.schema) @@ -82,6 +85,10 @@ def test_code_generation_from_schema(self, generation_spec, setupLogging): ast_tree = ast.parse(generatedCode) assert ast_tree is not None + generatedCode2 = dg.DataAnalyzer.scriptDataGeneratorFromSchema(df_source_data.schema, asHtml=True) + + assert generatedCode in generatedCode2 + def test_summarize(self, testLogger, generation_spec): testLogger.info("Building test data") diff --git a/tests/test_text_generatestring.py b/tests/test_text_generatestring.py new file mode 100644 index 00000000..9d6fc599 --- /dev/null +++ b/tests/test_text_generatestring.py @@ -0,0 +1,99 @@ +import pytest +import pyspark.sql.functions as F +from pyspark.sql.types import BooleanType, DateType +from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType + +import dbldatagen as dg + +spark = dg.SparkSingleton.getLocalInstance("unit tests") + +spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000") +spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") + +#: list of digits for template generation +_DIGITS_ZERO = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + +#: list of uppercase letters for template generation +_LETTERS_UPPER = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'T', 'S', 'U', 'V', 'W', 'X', 'Y', 'Z'] + +#: list of lowercase letters for template generation +_LETTERS_LOWER = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] + +#: list of all letters uppercase and lowercase +_LETTERS_ALL = _LETTERS_LOWER + _LETTERS_UPPER + +#: list of alphanumeric chars in lowercase +_ALNUM_LOWER = _LETTERS_LOWER + _DIGITS_ZERO + +#: list of 
alphanumeric chars in uppercase +_ALNUM_UPPER = _LETTERS_UPPER + _DIGITS_ZERO + + +# Test manipulation and generation of test data for a large schema +class TestTextGenerateString: + + @pytest.mark.parametrize("length, leadingAlpha, allUpper, allLower, allAlpha, customChars", + [ + (5, True, True, False, False, None), + (5, True, False, True, False, None), + (5, True, False, False, True, None), + (5, False, False, False, False, None), + (5, False, True, False, True, None), + (5, False, False, True, True, None), + (5, False, False, False, False, "01234567890ABCDEF"), + ]) + def test_basics(self, length, leadingAlpha, allUpper, allLower, allAlpha, customChars): + + tg1 = dg.GenerateString(length, leadingAlpha=leadingAlpha, allUpper=allUpper, allLower=allLower, + allAlpha=allAlpha, customChars=customChars) + + assert tg1._charAlphabet is not None + assert tg1._firstCharAlphabet is not None + + if allUpper and allAlpha: + alphabet = _LETTERS_UPPER + elif allLower and allAlpha: + alphabet = _LETTERS_LOWER + elif allLower: + alphabet = _LETTERS_LOWER + _DIGITS_ZERO + elif allUpper: + alphabet = _LETTERS_UPPER + _DIGITS_ZERO + elif allAlpha: + alphabet = _LETTERS_UPPER + _LETTERS_LOWER + else: + alphabet = _LETTERS_UPPER + _LETTERS_LOWER + _DIGITS_ZERO + + if customChars is not None: + alphabet = set(alphabet).intersection(set(customChars)) + + assert set(tg1._charAlphabet) == set(alphabet) + + @pytest.mark.parametrize("genstr", + [ + dg.GenerateString((1, 10)), + dg.GenerateString((1, 10), leadingAlpha=True), + dg.GenerateString((4, 64), allUpper=True), + dg.GenerateString((10, 20), allLower=True), + dg.GenerateString((1, 10)), + dg.GenerateString((3, 15)), + dg.GenerateString((17, 22)), + dg.GenerateString((1, 10)), + ]) + def test_simple_data(self, genstr): + dgspec = (dg.DataGenerator(sparkSession=spark, name="alt_data_set", rows=10000, + partitions=4, seedMethod='hash_fieldname', verbose=True, + seedColumnName="_id") + .withIdOutput() + .withColumn("code2", 
IntegerType(), min=0, max=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), text=dg.GenerateString((1, 10))) + ) + + fieldsFromGenerator = set(dgspec.getOutputColumnNames()) + + df_testdata = dgspec.build() + + df_testdata.show() diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index fb23d9d3..374449cc 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -272,6 +272,8 @@ def test_small_ILText_driven_data_generation(self): df_iltext_data = testDataSpec.build() + df_iltext_data.show() + counts = df_iltext_data.agg( F.countDistinct("paras").alias("paragraphs_count") ).collect()[0]