diff --git a/src/test_pyspark_training/example_test.py b/example_test.py similarity index 69% rename from src/test_pyspark_training/example_test.py rename to example_test.py index 0255f62..19b41b4 100644 --- a/src/test_pyspark_training/example_test.py +++ b/example_test.py @@ -1,3 +1,4 @@ - - -def test_example_test(): + + +def test_example_test(): + return 0 \ No newline at end of file diff --git a/init.py b/init.py index e03871a..6413406 100644 --- a/init.py +++ b/init.py @@ -1,34 +1,6 @@ import os import findspark from pyspark.sql import SparkSession -import pyspark.sql.functions as F - -spark = SparkSession.builder.master("local[*]").getOrCreate() - - -sample_data = [ - {"name": "John D.", "age": 30}, - {"name": "Alice G.", "age": 25}, - {"name": "Bob T.", "age": 35}, - {"name": "Eve A.", "age": 28} -] -df = spark.createDataFrame(sample_data) - - - - - -transformed_df = remove_extra_spaces(df, "name") -transformed_df.show() - - -def main(): - init_env() - print("hey there") - - -if __name__ == "__main__": - main() def init_env(): @@ -36,4 +8,22 @@ def init_env(): os.environ["SPARK_HOME"] = "C:\\SPARK\\spark-3.1.1-bin-hadoop3.2" os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop" - findspark.init() \ No newline at end of file + findspark.init() + + +def init_spark(): + spark = SparkSession.builder.master("local[*]").getOrCreate() + df = spark.createDataFrame([ + {'name': 'OUI OUI', 'age': 30}, + ]) + df.show() + + +def main(): + print("hey there") + init_env() + init_spark() + + +if __name__ == "__main__": + main() diff --git a/src/pyspark_training/output_dataset_1/remove_extra_space.py b/src/pyspark_training/output_dataset_1/remove_extra_spaces.py similarity index 69% rename from src/pyspark_training/output_dataset_1/remove_extra_space.py rename to src/pyspark_training/output_dataset_1/remove_extra_spaces.py index 0952e86..7edca2f 100644 --- a/src/pyspark_training/output_dataset_1/remove_extra_space.py +++ 
b/src/pyspark_training/output_dataset_1/remove_extra_spaces.py @@ -1,3 +1,6 @@ -def remove_extra_spaces(df, column_name): - df_transformed = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "\\s+", " ")) - return df_transformed \ No newline at end of file +import pyspark.sql.functions as F + + +def remove_extra_spaces(df, column_name): + df_transformed = df.withColumn(column_name, F.regexp_replace(F.col(column_name), "\\s+", " ")) + return df_transformed diff --git a/src/test_pyspark_training/conftest.py b/src/test_pyspark_training/conftest.py new file mode 100644 index 0000000..a21fbac --- /dev/null +++ b/src/test_pyspark_training/conftest.py @@ -0,0 +1,24 @@ +import os +import findspark +import logging +import pytest +from pyspark.sql import SparkSession + + +@pytest.fixture +def spark_session(request): + os.environ["JAVA_HOME"] = "C:\\Program Files\\Java\\jdk-11" + os.environ["SPARK_HOME"] = "C:\\SPARK\\spark-3.1.1-bin-hadoop3.2" + os.environ["HADOOP_HOME"] = "C:\\SPARK\\hadoop" + findspark.init() + + spark = SparkSession.builder.master("local[*]").getOrCreate() + request.addfinalizer(lambda: spark.stop()) + quiet_py4j() + return spark + + +def quiet_py4j(): + """Suppress spark logging for the test context.""" + logger = logging.getLogger('py4j') + logger.setLevel(logging.WARN) diff --git a/src/test_pyspark_training/lib_test_utils.py b/src/test_pyspark_training/lib_test_utils.py index 06a01c5..c22b46a 100644 --- a/src/test_pyspark_training/lib_test_utils.py +++ b/src/test_pyspark_training/lib_test_utils.py @@ -1,16 +1,15 @@ +from pyspark_test import assert_pyspark_df_equal + def assert_df_equal(df1, df2): try: - assert df1.schema() == df2.schema() + assert df1.schema == df2.schema except AssertionError: print('Error Schema') - print(df1.schema()) - print(df1.schema()) + print('df1\n') + df1.printSchema() + print('df2\n') + df2.printSchema() - try: - assert df1.equals(df2) - except AssertionError: - print('Error Schema') - df1.show() - df2.show() + 
assert_pyspark_df_equal(df1, df2) diff --git a/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_space/test_remove_extra_space.py b/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_space/test_remove_extra_space.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_space/__init__.py b/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/__init__.py similarity index 100% rename from src/test_pyspark_training/test_output_dataset_1/test_remove_extra_space/__init__.py rename to src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/__init__.py diff --git a/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/test_remove_extra_spaces.py b/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/test_remove_extra_spaces.py new file mode 100644 index 0000000..883bbf3 --- /dev/null +++ b/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/test_remove_extra_spaces.py @@ -0,0 +1,39 @@ +from pyspark.sql import types as T +from src.test_pyspark_training.lib_test_utils import assert_df_equal +from src.pyspark_training.output_dataset_1.remove_extra_spaces import remove_extra_spaces + + +def test_remove_extra_spaces(spark_session): + + input_schema = T.StructType( + [ + T.StructField('name', T.StringType(), False), + T.StructField('age', T.IntegerType(), False), + ] + ) + input_data = [ + ('John  D.', 30), + ('Alice  G.', 25), + ('Bob  T.', 35), + ('Eve  A.', 28), + ] + input_df = spark_session.createDataFrame(input_data, input_schema) + + + expected_schema = T.StructType( + [ + T.StructField('name', T.StringType(), False), + T.StructField('age', T.IntegerType(), False), + ] + ) + expected_data = [ + ('John D.', 30), + ('Alice G.', 25), + ('Bob T.', 35), + ('Eve A.', 28), + ] + expected_df = spark_session.createDataFrame(expected_data, expected_schema) + + df = 
remove_extra_spaces(input_df, 'name') + + assert_df_equal(df, expected_df)