diff --git a/src/pyspark_training/output_dataset_1/remove_extra_spaces.py b/src/pyspark_training/output_dataset_1/clean_output_dataset_1.py similarity index 51% rename from src/pyspark_training/output_dataset_1/remove_extra_spaces.py rename to src/pyspark_training/output_dataset_1/clean_output_dataset_1.py index 7edca2f..e0addb0 100644 --- a/src/pyspark_training/output_dataset_1/remove_extra_spaces.py +++ b/src/pyspark_training/output_dataset_1/clean_output_dataset_1.py @@ -1,4 +1,16 @@ import pyspark.sql.functions as F +from pyspark.sql import DataFrame + + +def clean_output_dataset_1(df: DataFrame) -> DataFrame: + """ + + :param df: + :return: + """ + df = remove_extra_spaces(df, 'name') + + return df def remove_extra_spaces(df, column_name): diff --git a/src/pyspark_training/output_dataset_1/compute_output_dataset_1.py b/src/pyspark_training/output_dataset_1/compute_output_dataset_1.py new file mode 100644 index 0000000..11002ca --- /dev/null +++ b/src/pyspark_training/output_dataset_1/compute_output_dataset_1.py @@ -0,0 +1,31 @@ +import pyspark.sql.functions as F +from pyspark.sql import DataFrame +from src.pyspark_training.output_dataset_1.clean_output_dataset_1 import clean_output_dataset_1 + + +def compute_output_dataset_1(df: DataFrame) -> DataFrame: + + df = clean_output_dataset_1(df) + + df = add_life_stage(df) + + return df + + +def add_life_stage(df: DataFrame) -> DataFrame: + """ + Add life stage + child if age < 13 + teenager if age >= 13 and <= 19 + adult for age>20 + :param df: + :return: + """ + df = df.withColumn( + 'life_stage', + F.when(F.col('age') < 13, F.lit('child')) + .when(F.col('age').between(13, 19), F.lit('teenager')) + .otherwise(F.lit('adult')) + ) + + return df diff --git a/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/__init__.py b/src/test_pyspark_training/test_output_dataset_1/test_clean_output_dataset_1/__init__.py similarity index 100% rename from src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/__init__.py rename to src/test_pyspark_training/test_output_dataset_1/test_clean_output_dataset_1/__init__.py diff --git a/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/test_remove_extra_spaces.py b/src/test_pyspark_training/test_output_dataset_1/test_clean_output_dataset_1/test_remove_extra_spaces.py similarity index 91% rename from src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/test_remove_extra_spaces.py rename to src/test_pyspark_training/test_output_dataset_1/test_clean_output_dataset_1/test_remove_extra_spaces.py index 717dd33..302cfb5 100644 --- a/src/test_pyspark_training/test_output_dataset_1/test_remove_extra_spaces/test_remove_extra_spaces.py +++ b/src/test_pyspark_training/test_output_dataset_1/test_clean_output_dataset_1/test_remove_extra_spaces.py @@ -1,6 +1,6 @@ from pyspark.sql import types as T from src.test_pyspark_training.lib_test_utils import assert_df_equal -from src.pyspark_training.output_dataset_1.remove_extra_spaces import remove_extra_spaces +from src.pyspark_training.output_dataset_1.clean_output_dataset_1 import remove_extra_spaces def test_remove_extra_spaces(spark_session): diff --git a/src/test_pyspark_training/test_output_dataset_1/test_compute_output_dataset_1/__init__.py b/src/test_pyspark_training/test_output_dataset_1/test_compute_output_dataset_1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/test_pyspark_training/test_output_dataset_1/test_compute_output_dataset_1/test_add_life_stage.py b/src/test_pyspark_training/test_output_dataset_1/test_compute_output_dataset_1/test_add_life_stage.py new file mode 100644 index 0000000..de1be0d --- /dev/null +++ b/src/test_pyspark_training/test_output_dataset_1/test_compute_output_dataset_1/test_add_life_stage.py @@ -0,0 +1,43 @@ +from pyspark.sql import types as T +from src.test_pyspark_training.lib_test_utils import assert_df_equal +from src.pyspark_training.output_dataset_1.compute_output_dataset_1 import add_life_stage + + +def test_add_life_stage(spark_session): + + input_schema = T.StructType( + [ + T.StructField('name', T.StringType(), False), + T.StructField('age', T.IntegerType(), False), + ] + ) + input_data = [ + ('Alice G.', 13), + ('John B.', 20), + ('Jack W.', 19), + ('Bob T.', 35), + ('John D.', 9), + ('Eve A.', 12), + ] + input_df = spark_session.createDataFrame(input_data, input_schema) + + expected_schema = T.StructType( + [ + T.StructField('name', T.StringType(), False), + T.StructField('age', T.IntegerType(), False), + T.StructField('life_stage', T.StringType(), False), + ] + ) + expected_data = [ + ('Alice G.', 13, 'teenager'), + ('John B.', 20, 'adult'), + ('Jack W.', 19, 'teenager'), + ('Bob T.', 35, 'adult'), + ('John D.', 9, 'child'), + ('Eve A.', 12, 'child'), + ] + expected_df = spark_session.createDataFrame(expected_data, expected_schema) + + df = add_life_stage(input_df) + + assert_df_equal(df, expected_df)