Init PyTest
This commit is contained in:
parent
c4fdb2860c
commit
25c2e6b7cb
8 changed files with 100 additions and 44 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
|
|
||||||
def test_example_test():
|
def test_example_test():
|
||||||
|
return 0
|
||||||
46
init.py
46
init.py
|
|
@ -1,34 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import findspark
|
import findspark
|
||||||
from pyspark.sql import SparkSession
|
from pyspark.sql import SparkSession
|
||||||
import pyspark.sql.functions as F
|
|
||||||
|
|
||||||
spark = SparkSession.builder.master("local[*]").getOrCreate()
|
|
||||||
|
|
||||||
|
|
||||||
sample_data = [
|
|
||||||
{"name": "John D.", "age": 30},
|
|
||||||
{"name": "Alice G.", "age": 25},
|
|
||||||
{"name": "Bob T.", "age": 35},
|
|
||||||
{"name": "Eve A.", "age": 28}
|
|
||||||
]
|
|
||||||
df = spark.createDataFrame(sample_data)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
transformed_df = remove_extra_spaces(df, "name")
|
|
||||||
transformed_df.show()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
init_env()
|
|
||||||
print("hey there")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
|
||||||
def init_env():
|
def init_env():
|
||||||
|
|
@ -37,3 +9,21 @@ def init_env():
|
||||||
os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop"
|
os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop"
|
||||||
|
|
||||||
findspark.init()
|
findspark.init()
|
||||||
|
|
||||||
|
|
||||||
|
def init_spark():
|
||||||
|
spark = SparkSession.builder.master("local[*]").getOrCreate()
|
||||||
|
df = spark.createDataFrame([
|
||||||
|
{'name': 'OUI OUI', 'age': 30},
|
||||||
|
])
|
||||||
|
df.show()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("hey there")
|
||||||
|
init_env()
|
||||||
|
init_spark()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
import pyspark.sql.functions as F
|
||||||
|
|
||||||
|
|
||||||
def remove_extra_spaces(df, column_name):
|
def remove_extra_spaces(df, column_name):
|
||||||
df_transformed = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "\\s+", " "))
|
df_transformed = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "\\s+", " "))
|
||||||
return df_transformed
|
return df_transformed
|
||||||
24
src/test_pyspark_training/conftest.py
Normal file
24
src/test_pyspark_training/conftest.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
import os
|
||||||
|
import findspark
|
||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def spark_session(request):
|
||||||
|
os.environ["JAVA_HOME"] = "C:\\Program Files\\Java\\jdk-11"
|
||||||
|
os.environ["SPARK_HOME"] = "C:\\SPARK\\spark-3.1.1-bin-hadoop3.2"
|
||||||
|
os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop"
|
||||||
|
findspark.init()
|
||||||
|
|
||||||
|
spark = SparkSession.builder.master("local[*]").getOrCreate()
|
||||||
|
request.addfinalizer(lambda: spark.stop())
|
||||||
|
quiet_py4j()
|
||||||
|
return spark
|
||||||
|
|
||||||
|
|
||||||
|
def quiet_py4j():
|
||||||
|
"""Suppress spark logging for the test context."""
|
||||||
|
logger = logging.getLogger('py4j')
|
||||||
|
logger.setLevel(logging.WARN)
|
||||||
|
|
@ -1,16 +1,15 @@
|
||||||
|
from pyspark_test import assert_pyspark_df_equal
|
||||||
|
|
||||||
|
|
||||||
def assert_df_equal(df1, df2):
|
def assert_df_equal(df1, df2):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert df1.schema() == df2.schema()
|
assert df1.schema == df2.schema
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
print('Error Schema')
|
print('Error Schema')
|
||||||
print(df1.schema())
|
print('df1\n')
|
||||||
print(df1.schema())
|
df1.printSchema()
|
||||||
|
print('df2\n')
|
||||||
|
df2.printSchema()
|
||||||
|
|
||||||
try:
|
assert_pyspark_df_equal(df1, df2)
|
||||||
assert df1.equals(df2)
|
|
||||||
except AssertionError:
|
|
||||||
print('Error Schema')
|
|
||||||
df1.show()
|
|
||||||
df2.show()
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
from pyspark.sql import types as T
|
||||||
|
from src.test_pyspark_training.lib_test_utils import assert_df_equal
|
||||||
|
from src.pyspark_training.output_dataset_1.remove_extra_spaces import remove_extra_spaces
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_extra_spaces(spark_session):
|
||||||
|
|
||||||
|
input_schema = T.StructType(
|
||||||
|
[
|
||||||
|
T.StructField('name', T.StringType(), False),
|
||||||
|
T.StructField('age', T.IntegerType(), False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
input_data = [
|
||||||
|
('John D.', 30),
|
||||||
|
('Alice G.', 25),
|
||||||
|
('Bob T.', 35),
|
||||||
|
('Eve A.', 28),
|
||||||
|
]
|
||||||
|
input_df = spark_session.createDataFrame(input_data, input_schema)
|
||||||
|
|
||||||
|
|
||||||
|
expected_schema = T.StructType(
|
||||||
|
[
|
||||||
|
T.StructField('name', T.StringType(), False),
|
||||||
|
T.StructField('age', T.IntegerType(), False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected_data = [
|
||||||
|
('John D.', 30),
|
||||||
|
('Alice G.', 25),
|
||||||
|
('Bob T.', 35),
|
||||||
|
('Eve A.', 28),
|
||||||
|
]
|
||||||
|
expected_df = spark_session.createDataFrame(expected_data, expected_schema)
|
||||||
|
|
||||||
|
df = remove_extra_spaces(input_df, 'name')
|
||||||
|
|
||||||
|
assert_df_equal(df, expected_df)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue