pyspark_datasources/fake.py (15 additions, 5 deletions)
@@ -1,12 +1,14 @@
 from typing import List
+from datetime import datetime, timedelta
+import random
 
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceReader,
     DataSourceStreamReader,
     InputPartition,
 )
-from pyspark.sql.types import StringType, StructType
+from pyspark.sql.types import StringType, StructType, TimestampType
 
 
 def _validate_faker_schema(schema):
@@ -19,18 +21,34 @@ def _validate_faker_schema(schema):
     fake = Faker()
     for field in schema.fields:
         try:
-            getattr(fake, field.name)()
+            if field.dataType == StringType():
+                getattr(fake, field.name)()
+            elif field.dataType == TimestampType():
+                continue
         except AttributeError:
             raise Exception(
                 f"Unable to find a method called `{field.name}` in faker. "
                 f"Please check Faker's documentation to see supported methods."
             )
-        if field.dataType != StringType():
+        if field.dataType not in (StringType(), TimestampType()):
             raise Exception(
-                f"Field `{field.name}` is not a StringType. "
-                f"Only StringType is supported in the fake datasource."
+                f"Field `{field.name}` is not a StringType or TimestampType. "
+                f"Only StringType and TimestampType are supported in the fake datasource."
             )
 
 
+class GenerateDateTime:
+    @classmethod
+    def randomDate(cls):
+        # Return a random UTC timestamp from within the past year.
+        return datetime.utcnow() + timedelta(
+            days=random.randint(-365, 0),
+            hours=random.randint(-23, 0),
+            minutes=random.randint(-59, 0),
+            seconds=random.randint(-59, 0),
+            milliseconds=random.randint(-999, 0),
+        )
+
+
 class FakeDataSource(DataSource):
     """
@@ -140,7 +150,7 @@ def read(self, partition):
         for _ in range(num_rows):
             row = []
             for field in self.schema.fields:
-                value = getattr(fake, field.name)()
+                value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.randomDate()
                 row.append(value)
             yield tuple(row)
 
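For context, here is a minimal usage sketch of the data source after this change. It assumes the PySpark 4.0+ Python data source API, that the `faker` package is installed, and that `FakeDataSource` registers under the short name "fake"; the `event_time` column name is illustrative, since any TimestampType field is filled by `GenerateDateTime.randomDate()` rather than looked up as a Faker method.

    from pyspark.sql import SparkSession

    from pyspark_datasources.fake import FakeDataSource

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(FakeDataSource)

    # A TimestampType field does not need to match a Faker method name:
    # _validate_faker_schema skips the Faker lookup for timestamp fields.
    df = (
        spark.read.format("fake")
        .schema("name string, event_time timestamp")
        .load()
    )
    df.show()

Because `randomDate` subtracts a random offset of up to one year from `datetime.utcnow()`, every generated timestamp falls within the past year.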