Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# CI: run the test suite (with coverage) against all supported Python versions.
name: CI

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      # Keep running the other Python versions even if one fails,
      # so a version-specific breakage is visible in a single CI run.
      fail-fast: false
      matrix:
        python-version: ['3.9', '3.10', '3.11', '3.12']

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Poetry
        uses: snok/[email protected]
        with:
          version: latest
          virtualenvs-create: true
          virtualenvs-in-project: true

      # Cache the in-project virtualenv, keyed on OS + Python version +
      # the lock file, so dependency installs are skipped on cache hits.
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v4
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}

      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --no-root --extras "all"

      # Installs the project itself (fast; deps are already present or cached).
      - name: Install project
        run: poetry install --no-interaction --extras "all"

      # Single test run: previously the suite ran twice (once plain, once
      # under coverage), doubling CI time. Coverage collection does not
      # change test results, so one run with coverage is sufficient.
      - name: Run tests with coverage
        run: |
          poetry run pytest tests/ -v --cov=pyspark_datasources --cov-report=xml --cov-report=term-missing

26 changes: 18 additions & 8 deletions pyspark_datasources/fake.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import List
from datetime import datetime, timedelta
import random

from pyspark.sql.datasource import (
DataSource,
DataSourceReader,
DataSourceStreamReader,
InputPartition,
)
from pyspark.sql.types import StringType, StructType
from pyspark.sql.types import StringType, StructType, TimestampType


def _validate_faker_schema(schema):
Expand All @@ -19,18 +21,26 @@ def _validate_faker_schema(schema):
fake = Faker()
for field in schema.fields:
try:
getattr(fake, field.name)()
if field.dataType == StringType():
getattr(fake, field.name)()
elif field.dataType == TimestampType():
continue
except AttributeError:
raise Exception(
f"Unable to find a method called `{field.name}` in faker. "
f"Please check Faker's documentation to see supported methods."
)
if field.dataType != StringType():
if field.dataType not in (StringType(), TimestampType()):
raise Exception(
f"Field `{field.name}` is not a StringType. "
f"Only StringType is supported in the fake datasource."
f"Field `{field.name}` is not a StringType or TimestampType(). "
f"Only StringType and TimestampType are supported in the fake datasource."
)

class GenerateDateTime:
    """Helper producing random timestamps for the fake datasource.

    Fills ``TimestampType`` columns with plausible creation dates drawn
    uniformly-ish from roughly the past year.
    """

    @classmethod
    def random_datetime(cls) -> datetime:
        """Return a naive UTC datetime between ~366 days ago and now.

        Every component offset is drawn from a non-positive range, so the
        result is never in the future.
        """
        # datetime.utcnow() is deprecated since Python 3.12; construct the
        # same naive-UTC "now" from an aware datetime instead.
        from datetime import timezone

        now = datetime.now(timezone.utc).replace(tzinfo=None)
        offset = timedelta(
            days=random.randint(-365, 0),
            hours=random.randint(-23, 0),
            minutes=random.randint(-59, 0),
            seconds=random.randint(-59, 0),
            microseconds=random.randint(-999000, 0),
        )
        return now + offset

class FakeDataSource(DataSource):
"""
Expand Down Expand Up @@ -114,7 +124,7 @@ def name(cls):
return "fake"

def schema(self):
return "name string, date string, zipcode string, state string"
return "name string, date string, zipcode string, state string, creationDate timestamp"

def reader(self, schema: StructType) -> "FakeDataSourceReader":
_validate_faker_schema(schema)
Expand All @@ -140,7 +150,7 @@ def read(self, partition):
for _ in range(num_rows):
row = []
for field in self.schema.fields:
value = getattr(fake, field.name)()
value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.random_datetime()
row.append(value)
yield tuple(row)

Expand Down Expand Up @@ -169,6 +179,6 @@ def read(self, partition):
for _ in range(partition.value):
row = []
for field in self.schema.fields:
value = getattr(fake, field.name)()
value = getattr(fake, field.name)() if field.dataType == StringType() else GenerateDateTime.random_datetime()
row.append(value)
yield tuple(row)
10 changes: 8 additions & 2 deletions tests/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pyspark.sql import SparkSession
from pyspark_datasources import *

from pyspark.sql.types import TimestampType

@pytest.fixture
def spark():
Expand Down Expand Up @@ -30,14 +30,20 @@ def test_fake_datasource_stream(spark):
)
spark.sql("SELECT * FROM result").show()
assert spark.sql("SELECT * FROM result").count() == 3
df = spark.table("result")
df_datatypes = [d.dataType for d in df.schema.fields]
assert len(df.columns) == 5
assert TimestampType() in df_datatypes


def test_fake_datasource(spark):
spark.dataSource.register(FakeDataSource)
df = spark.read.format("fake").load()
df_datatypes = [d.dataType for d in df.schema.fields]
df.show()
assert df.count() == 3
assert len(df.columns) == 4
assert len(df.columns) == 5
assert TimestampType() in df_datatypes


def test_kaggle_datasource(spark):
Expand Down