Init PyTest

author Yûki VACHOT 2024-01-05 17:10:52 +01:00
parent c4fdb2860c
commit 25c2e6b7cb
8 changed files with 100 additions and 44 deletions


@@ -1,3 +1,4 @@
def test_example_test():
    return 0

init.py

@@ -1,34 +1,6 @@
import os
import findspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
spark = SparkSession.builder.master("local[*]").getOrCreate()
sample_data = [
    {"name": "John D.", "age": 30},
    {"name": "Alice G.", "age": 25},
    {"name": "Bob T.", "age": 35},
    {"name": "Eve A.", "age": 28}
]
df = spark.createDataFrame(sample_data)
transformed_df = remove_extra_spaces(df, "name")
transformed_df.show()
def main():
    init_env()
    print("hey there")
if __name__ == "__main__":
    main()
def init_env():
@@ -36,4 +8,22 @@ def init_env():
os.environ["SPARK_HOME"] = "C:\\SPARK\\spark-3.1.1-bin-hadoop3.2"
os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop"
findspark.init()
findspark.init()
def init_spark():
spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([
{'name': 'OUI OUI', 'age': 30},
])
df.show()
def main():
print("hey there")
init_env()
init_spark()
if __name__ == "__main__":
main()
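With this refactor, running init.py directly should print "hey there", have init_env() point SPARK_HOME and HADOOP_HOME at the hard-coded Windows paths and call findspark.init(), and then have init_spark() build and show the throwaway one-row DataFrame — assuming those local paths actually exist on the machine.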


@@ -1,3 +1,6 @@
def remove_extra_spaces(df, column_name):
    df_transformed = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "\\s+", " "))
    return df_transformed
import pyspark.sql.functions as F
def remove_extra_spaces(df, column_name):
    df_transformed = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "\\s+", " "))
    return df_transformed
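As a point of reference, a minimal, purely illustrative sketch of what this helper does (not part of the commit; assumes a local SparkSession is available): runs of whitespace in the chosen column are collapsed to a single space.

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([("John    D.", 30)], ["name", "age"])
# Same transformation as remove_extra_spaces(df, "name"): "John    D." -> "John D."
df.withColumn("name", F.regexp_replace(F.col("name"), "\\s+", " ")).show()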


@@ -0,0 +1,24 @@
import os
import findspark
import logging
import pytest
from pyspark.sql import SparkSession
@pytest.fixture
def spark_session(request):
os.environ["JAVA_HOME"] = "C:\\Program Files\\Java\\jdk-11"
os.environ["SPARK_HOME"] = "C:\\SPARK\\spark-3.1.1-bin-hadoop3.2"
os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop"
findspark.init()
spark = SparkSession.builder.master("local[*]").getOrCreate()
request.addfinalizer(lambda: spark.stop())
quiet_py4j()
return spark
def quiet_py4j():
"""Suppress spark logging for the test context."""
logger = logging.getLogger('py4j')
logger.setLevel(logging.WARN)
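Assuming this fixture lives in a conftest.py (as the pytest fixture pattern suggests), any test in the suite can request the shared SparkSession just by naming the fixture as a parameter; a minimal illustrative sketch:

def test_fixture_smoke(spark_session):
    # pytest builds the session via the fixture above and stops it afterwards
    df = spark_session.createDataFrame([("a", 1)], ["name", "age"])
    assert df.count() == 1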


@@ -1,16 +1,15 @@
from pyspark_test import assert_pyspark_df_equal
def assert_df_equal(df1, df2):
    try:
        assert df1.schema() == df2.schema()
        assert df1.schema == df2.schema
    except AssertionError:
        print('Error Schema')
        print(df1.schema())
        print(df1.schema())
        print('df1\n')
        df1.printSchema()
        print('df2\n')
        df2.printSchema()
    try:
        assert df1.equals(df2)
    except AssertionError:
        print('Error Schema')
        df1.show()
        df2.show()
    assert_pyspark_df_equal(df1, df2)
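A quick, illustrative usage sketch for the updated helper (assumes an active SparkSession named spark): equal DataFrames pass silently, while a schema or data mismatch prints the details before assert_pyspark_df_equal raises.

df_a = spark.createDataFrame([("John D.", 30)], ["name", "age"])
df_b = spark.createDataFrame([("John D.", 30)], ["name", "age"])
assert_df_equal(df_a, df_b)  # same schema and rows: passes without raising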


@@ -0,0 +1,39 @@
from pyspark.sql import types as T
from src.test_pyspark_training.lib_test_utils import assert_df_equal
from src.pyspark_training.output_dataset_1.remove_extra_spaces import remove_extra_spaces
def test_remove_extra_spaces(spark_session):
    input_schema = T.StructType(
        [
            T.StructField('name', T.StringType(), False),
            T.StructField('age', T.IntegerType(), False),
        ]
    )
    input_data = [
        ('John D.', 30),
        ('Alice G.', 25),
        ('Bob T.', 35),
        ('Eve A.', 28),
    ]
    input_df = spark_session.createDataFrame(input_data, input_schema)
    expected_schema = T.StructType(
        [
            T.StructField('name', T.StringType(), False),
            T.StructField('age', T.IntegerType(), False),
        ]
    )
    expected_data = [
        ('John D.', 30),
        ('Alice G.', 25),
        ('Bob T.', 35),
        ('Eve A.', 28),
    ]
    expected_df = spark_session.createDataFrame(expected_data, expected_schema)
    df = remove_extra_spaces(input_df, 'name')
    assert_df_equal(df, expected_df)
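Assuming the src/ layout implied by these import paths and the fixture file above, the suite should run from the repository root with a plain pytest invocation such as python -m pytest -q, with spark_session resolved automatically from the fixture.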