diff --git a/pyspark_datasources/fake.py b/pyspark_datasources/fake.py
index 2b9447c..d58b71e 100644
--- a/pyspark_datasources/fake.py
+++ b/pyspark_datasources/fake.py
@@ -1,4 +1,6 @@
 from typing import List
+from datetime import datetime, timedelta
+import random
 
 from pyspark.sql.datasource import (
     DataSource,
@@ -6,7 +8,7 @@
     DataSourceStreamReader,
     InputPartition,
 )
-from pyspark.sql.types import StringType, StructType
+from pyspark.sql.types import StringType, StructType, TimestampType
 
 
 def _validate_faker_schema(schema):
@@ -19,18 +21,26 @@ def _validate_faker_schema(schema):
     fake = Faker()
     for field in schema.fields:
         try:
-            getattr(fake, field.name)()
+            if field.dataType == StringType():
+                getattr(fake, field.name)()
+            elif field.dataType == TimestampType():
+                continue
         except AttributeError:
             raise Exception(
                 f"Unable to find a method called `{field.name}` in faker. "
                 f"Please check Faker's documentation to see supported methods."
             )
-        if field.dataType != StringType():
+        if field.dataType not in (StringType(), TimestampType()):
             raise Exception(
-                f"Field `{field.name}` is not a StringType. "
-                f"Only StringType is supported in the fake datasource."
+                f"Field `{field.name}` is not a StringType or TimestampType(). "
+                f"Only StringType and TimestampType are supported in the fake datasource."
             )
 
 
+class GenerateDateTime:
+
+    @classmethod
+    def random_datetime(cls) -> datetime:
+        return datetime.utcnow() + timedelta(days=random.randint(-365, 0), hours=random.randint(-23, 0), minutes=random.randint(-59, 0), seconds=random.randint(-59, 0), microseconds=random.randint(-999000, 0))
 class FakeDataSource(DataSource):
     """
@@ -140,7 +150,7 @@ def read(self, partition):
         for _ in range(num_rows):
             row = []
             for field in self.schema.fields:
-                value = getattr(fake, field.name)()
+                value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.random_datetime()
                 row.append(value)
             yield tuple(row)
 
@@ -169,6 +179,6 @@ def read(self, partition):
         for _ in range(partition.value):
             row = []
             for field in self.schema.fields:
-                value = getattr(fake, field.name)()
+                value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.random_datetime()
                 row.append(value)
             yield tuple(row)
diff --git a/tests/test_data_sources.py b/tests/test_data_sources.py
index ad6a2b0..fdc4ab6 100644
--- a/tests/test_data_sources.py
+++ b/tests/test_data_sources.py
@@ -2,7 +2,7 @@
 from pyspark.sql import SparkSession
 
 from pyspark_datasources import *
-
+from pyspark.sql.types import TimestampType, StringType, StructType, StructField
 
 @pytest.fixture
 def spark():
@@ -30,6 +30,8 @@ def test_fake_datasource_stream(spark):
     )
     spark.sql("SELECT * FROM result").show()
     assert spark.sql("SELECT * FROM result").count() == 3
+    df = spark.table("result")
+    assert len(df.columns) == 4
 
 
 def test_fake_datasource(spark):
@@ -40,6 +42,19 @@ def test_fake_datasource(spark):
     assert len(df.columns) == 4
 
 
+def test_fake_timestamp_column(spark):
+    spark.dataSource.register(FakeDataSource)
+    schema = StructType([StructField("name", StringType(), True), StructField("zipcode", StringType(), True), StructField("state", StringType(), True), StructField("date", TimestampType(), True)])
+    df = spark.read.format("fake").schema(schema).load()
+    df_columns = [d.name for d in df.schema.fields]
+    df_datatypes = [d.dataType for d in df.schema.fields]
+    df.show()
+    assert df.count() == 3
+    assert len(df.columns) == 4
+    assert df_columns == ["name", "zipcode", "state", "date"]
+    assert df_datatypes[-1] == TimestampType()
+
+
 def test_kaggle_datasource(spark):
     spark.dataSource.register(KaggleDataSource)
     df = spark.read.format("kaggle").options(handle="yasserh/titanic-dataset").load("Titanic-Dataset.csv")
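
Usage illustration (a minimal sketch, not part of the diff; it assumes the package is importable and a Spark session that supports Python data sources): with this change, a TimestampType column can be requested from the fake data source alongside the usual Faker-backed string fields.

    from pyspark.sql import SparkSession
    from pyspark_datasources import FakeDataSource

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(FakeDataSource)

    # "date" is a timestamp field, so its values come from GenerateDateTime.random_datetime()
    # rather than a Faker method; the string fields are still generated by Faker.
    df = (
        spark.read.format("fake")
        .schema("name string, zipcode string, state string, date timestamp")
        .load()
    )
    df.show(truncate=False)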