24 changes: 17 additions & 7 deletions pyspark_datasources/fake.py
@@ -1,12 +1,14 @@
 from typing import List
+from datetime import datetime, timedelta
+import random
 
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceReader,
     DataSourceStreamReader,
     InputPartition,
 )
-from pyspark.sql.types import StringType, StructType
+from pyspark.sql.types import StringType, StructType, TimestampType
 
 
 def _validate_faker_schema(schema):
@@ -19,18 +21,26 @@ def _validate_faker_schema(schema):
     fake = Faker()
     for field in schema.fields:
         try:
-            getattr(fake, field.name)()
+            if field.dataType == StringType():
+                getattr(fake, field.name)()
+            elif field.dataType == TimestampType():
+                continue
         except AttributeError:
             raise Exception(
                 f"Unable to find a method called `{field.name}` in faker. "
                 f"Please check Faker's documentation to see supported methods."
             )
-        if field.dataType != StringType():
+        if field.dataType not in (StringType(), TimestampType()):
             raise Exception(
-                f"Field `{field.name}` is not a StringType. "
-                f"Only StringType is supported in the fake datasource."
+                f"Field `{field.name}` is not a StringType or TimestampType. "
+                f"Only StringType and TimestampType are supported in the fake datasource."
             )
 
+class GenerateDateTime:
+
+    @classmethod
+    def random_datetime(cls) -> datetime:
+        return datetime.utcnow() + timedelta(days=random.randint(-365, 0), hours=random.randint(-23, 0), minutes=random.randint(-59, 0), seconds=random.randint(-59, 0), microseconds=random.randint(-999000, 0))
 
 class FakeDataSource(DataSource):
     """
@@ -140,7 +150,7 @@ def read(self, partition):
         for _ in range(num_rows):
             row = []
             for field in self.schema.fields:
-                value = getattr(fake, field.name)()
+                value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.random_datetime()
                 row.append(value)
             yield tuple(row)
 
@@ -169,6 +179,6 @@ def read(self, partition):
         for _ in range(partition.value):
             row = []
             for field in self.schema.fields:
-                value = getattr(fake, field.name)()
+                value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.random_datetime()
                 row.append(value)
             yield tuple(row)
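
With these changes, a schema may mix Faker-backed StringType fields with TimestampType fields, which are filled by GenerateDateTime.random_datetime instead of a Faker method. A minimal usage sketch (assumes an active SparkSession named `spark`; the column name `created_at` is arbitrary, since TimestampType fields skip the Faker method lookup entirely):

from pyspark.sql.types import StringType, StructField, StructType, TimestampType
from pyspark_datasources import FakeDataSource

spark.dataSource.register(FakeDataSource)
schema = StructType([
    StructField("name", StringType()),
    StructField("created_at", TimestampType()),  # hypothetical name; filled with a random recent datetime
])
spark.read.format("fake").schema(schema).load().show()
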
17 changes: 16 additions & 1 deletion tests/test_data_sources.py
@@ -2,7 +2,7 @@
 
 from pyspark.sql import SparkSession
 from pyspark_datasources import *
-
+from pyspark.sql.types import TimestampType, StringType, StructType, StructField
 
 @pytest.fixture
 def spark():
@@ -30,6 +30,8 @@ def test_fake_datasource_stream(spark):
     )
     spark.sql("SELECT * FROM result").show()
     assert spark.sql("SELECT * FROM result").count() == 3
+    df = spark.table("result")
+    assert len(df.columns) == 4
 
 
 def test_fake_datasource(spark):
@@ -40,6 +42,19 @@
     assert len(df.columns) == 4
 
 
+def test_fake_timestamp_column(spark):
+    spark.dataSource.register(FakeDataSource)
+    schema = StructType([StructField("name", StringType(), True), StructField("zipcode", StringType(), True), StructField("state", StringType(), True), StructField("date", TimestampType(), True)])
+    df = spark.read.format("fake").schema(schema).load()
+    df_columns = [d.name for d in df.schema.fields]
+    df_datatypes = [d.dataType for d in df.schema.fields]
+    df.show()
+    assert df.count() == 3
+    assert len(df.columns) == 4
+    assert df_columns == ["name", "zipcode", "state", "date"]
+    assert df_datatypes[-1] == TimestampType()
+
+
 def test_kaggle_datasource(spark):
     spark.dataSource.register(KaggleDataSource)
     df = spark.read.format("kaggle").options(handle="yasserh/titanic-dataset").load("Titanic-Dataset.csv")
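
The bound on random_datetime follows from its offsets: every component is drawn from a non-positive range, and the largest total offset (365 days, 23 hours, 59 minutes, 59 seconds, 999 milliseconds) is just under 366 days, so generated values always land in roughly the past year. A quick standalone sanity check, assuming GenerateDateTime is importable from pyspark_datasources.fake as the diff defines it:

from datetime import datetime, timedelta
from pyspark_datasources.fake import GenerateDateTime

for _ in range(1_000):
    ts = GenerateDateTime.random_datetime()
    assert ts <= datetime.utcnow()                        # offsets are never positive
    assert ts >= datetime.utcnow() - timedelta(days=367)  # max offset is just under 366 days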