Pyspark Basics

This document provides a comprehensive overview of Apache Spark, focusing on its core concepts and data structures: RDDs, DataFrames, and Datasets. It explains the differences between these structures, their advantages, and how to create and manipulate them using SparkSession and SparkContext. Additionally, it covers the architecture of Spark, including the roles of the driver and executors, and the benefits of using DataFrames for data processing.

Data engineering is a fascinating and crucial field in today's data-driven world. As your
expert guide, I've prepared a comprehensive PDF covering the core concepts of Apache
Spark, an essential tool for any data engineer. This document will walk you through Spark's
fundamentals, from its core data structures to advanced transformations and optimizations.

Let's dive in!

Spark Basics
Apache Spark is a powerful open-source, distributed processing system used for big data
workloads. It provides an analytics engine that is often much faster than Hadoop
MapReduce, largely because it can keep intermediate data in memory instead of writing it
to disk between steps.

Difference between RDD, DataFrame, and Dataset


Spark offers three main data structures: RDDs (Resilient Distributed Datasets),
DataFrames, and Datasets. Understanding their differences is key to leveraging Spark
effectively.

RDD: Low-level, object-oriented, more control but less optimization


An RDD is the fundamental data structure of Spark. It's a fault-tolerant, immutable,
distributed collection of objects that can be operated on in parallel.
●​ Low-level: RDDs offer the most granular control over your data. You work directly
with Python/Scala/Java objects.
●​ Object-oriented: When you create an RDD, you're essentially creating a distributed
collection of objects of a specific type.
●​ Less Optimization: Spark's internal optimizers have limited visibility into the data's
structure within an RDD. This means RDD operations are often less optimized than
DataFrame or Dataset operations, as Spark doesn't know the schema of the data.
You, as the developer, are responsible for optimizing your RDD transformations.

Example (Python):

Python

from pyspark import SparkContext

# Initialize SparkContext
sc = SparkContext("local", "RDD_Example")

# Create an RDD from a list
data = [1, 2, 3, 4, 5]
rdd = sc.parallelize(data)

# Perform a transformation (e.g., multiply by 2)
rdd_transformed = rdd.map(lambda x: x * 2)

# Collect and print the results
print("RDD transformed data:", rdd_transformed.collect())

sc.stop()

DataFrame: Distributed collection of data with schema; optimized with Catalyst

A DataFrame is a distributed collection of data organized into named columns.
Conceptually, it's equivalent to a table in a relational database or a data frame in R/Python
(like Pandas).
●​ Schema-aware: Unlike RDDs, DataFrames have a well-defined schema, providing
structure to the data. This schema allows Spark to understand the data's
organization, which is crucial for optimization.
● Optimized with Catalyst Optimizer: This is the key advantage of DataFrames.
Spark's Catalyst Optimizer uses the schema information to optimize queries,
applying techniques such as predicate pushdown (also called filter pushdown) and
column pruning. This usually leads to significantly better performance than
equivalent RDD code.
●​ Language Agnostic: DataFrames are available in Scala, Java, Python, and R.

Example (Python):

Python

from pyspark.sql import SparkSession

# Initialize SparkSession
spark = SparkSession.builder.appName("DataFrame_Example").getOrCreate()

# Create a DataFrame from a list of tuples
data = [("Alice", 1), ("Bob", 2), ("Charlie", 3)]
columns = ["Name", "ID"]
df = spark.createDataFrame(data, columns)

# Show the DataFrame
print("DataFrame:")
df.show()

# Print the schema
print("DataFrame Schema:")
df.printSchema()

# Perform a transformation (e.g., filter)
df_filtered = df.filter(df.ID > 1)
print("Filtered DataFrame:")
df_filtered.show()

spark.stop()

Dataset (Scala/Java only): Strongly typed like RDD, with optimizations like DataFrame

A Dataset is a strongly-typed collection of JVM objects that combines the best of RDDs and
DataFrames. It's available only in Scala and Java.
●​ Strongly Typed: Like RDDs, Datasets provide compile-time type safety. This means
you get errors at compile time if you try to perform an operation on a column that
doesn't exist or has an incompatible type.
●​ Optimized with Catalyst: Like DataFrames, Datasets benefit from the Catalyst
Optimizer, providing excellent performance.
●​ Encoder: Datasets use encoders to serialize and deserialize JVM objects to and
from Spark's internal Tungsten binary format. This allows for efficient storage and
processing.

Conceptual Example (Scala - illustrating the idea):

Scala

// Assumes `spark` is an existing SparkSession.
case class Person(name: String, age: Long)

// Brings the implicit encoders needed by .as[Person] into scope
import spark.implicits._

val peopleDF = spark.read.json("examples/src/main/resources/people.json")

// DataFrames can be converted to a Dataset by providing a class.
val peopleDS = peopleDF.as[Person]

peopleDS.filter(_.age > 25).show()

Why DataFrame is preferred (ease of use, Catalyst optimizer, Tungsten execution)

DataFrames are generally preferred for most Spark workloads due to a combination of
factors:
●​ Ease of Use: DataFrames offer a high-level, SQL-like API that's intuitive and easy to
use, especially for those familiar with relational databases or tools like Pandas. This
abstracts away much of the complexity of distributed processing.
●​ Catalyst Optimizer: This is the primary reason for DataFrame's superior
performance. The Catalyst Optimizer builds a logical and physical plan for your
transformations, then optimizes them by:
○​ Rule-based optimization: Applying rules like predicate pushdown (moving
filters closer to the data source) and column pruning (selecting only
necessary columns).
○ Cost-based optimization: Choosing among candidate physical execution
plans using data statistics, where such statistics are available.
●​ Tungsten Execution Engine: DataFrames leverage Spark's Tungsten engine, which
performs operations directly on serialized binary data in memory. This reduces
memory overhead and improves CPU utilization, leading to significantly faster
execution. Tungsten includes:
○​ Memory Management: Optimized memory allocation and deallocation.
○​ Code Generation: Generates optimized bytecode for operations, avoiding
costly JVM object creation.
○​ Cache Locality: Improves data access patterns.

In essence, DataFrames strike a good balance between developer productivity (ease of
use) and execution performance (optimizations).
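
To make predicate pushdown and column pruning concrete, here is a minimal pure-Python sketch. It is a toy model, not how Catalyst is implemented: the `scan` function, the `ROWS` list, and the "values read" counter are all hypothetical names invented for this illustration. It compares reading everything and filtering afterwards with handing the filter and the column list to the data source.

```python
# Toy model only: real Catalyst rewrites query plans, not Python lists.
ROWS = [
    {"name": "Alice", "age": 30, "city": "New York"},
    {"name": "Bob", "age": 25, "city": "London"},
    {"name": "Charlie", "age": 35, "city": "Paris"},
]


def scan(source, columns=None, predicate=None):
    """Hypothetical data source that can skip rows and columns while reading.

    Returns the rows it produced and how many values it had to hand over.
    """
    values_read = 0
    out = []
    for row in source:
        if predicate is not None and not predicate(row):
            continue  # predicate pushdown: the row never leaves the source
        # column pruning: keep only the requested columns
        kept = row if columns is None else {c: row[c] for c in columns}
        values_read += len(kept)
        out.append(kept)
    return out, values_read


def over_28(row):
    return row["age"] > 28


# Naive plan: read every row and column, then filter, then project.
naive_rows, naive_cost = scan(ROWS)
naive = [{"name": r["name"]} for r in naive_rows if over_28(r)]

# Optimized plan: the source applies the filter and returns only needed columns.
pushed_rows, pushed_cost = scan(ROWS, columns=["name", "age"], predicate=over_28)
optimized = [{"name": r["name"]} for r in pushed_rows]

print(naive == optimized)       # same answer either way
print(naive_cost, pushed_cost)  # fewer values handed over in the optimized plan
```

Both plans give the same result, but the second hands far less data to the rest of the pipeline. Schema-aware optimizers aim for this kind of saving.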

SparkSession & SparkContext


To interact with Spark, you need to set up a SparkSession or SparkContext. These are the
entry points to Spark's functionality.

SparkSession.builder.getOrCreate()
The SparkSession is the unified entry point for Spark 2.x and later. It subsumes the
functionality of SQLContext and HiveContext and wraps a SparkContext. Structured
Streaming is also reached through it; the older DStream API still uses StreamingContext.

SparkSession.builder.getOrCreate() is the recommended way to initialize a SparkSession:


●​ builder: Returns a SparkSession.Builder object, allowing you to configure various
Spark properties.
●​ getOrCreate():
○​ If a SparkSession instance already exists, it returns the existing one.
○​ If no SparkSession exists, it creates a new one based on the builder's
configuration.

This method ensures that you always have a single, active SparkSession, preventing issues
with multiple sessions.
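
The get-or-create behavior described above can be sketched in plain Python. This is only an illustration of the pattern: `ToySession` and its methods are hypothetical names, not Spark classes, and the real builder does more (for example, it can apply some configuration options to an existing session).

```python
class ToySession:
    """Hypothetical stand-in for a session object; not a Spark class."""

    _active = None  # class-level slot holding the single shared instance

    def __init__(self, app_name):
        self.app_name = app_name

    @classmethod
    def get_or_create(cls, app_name):
        # Reuse the existing instance if there is one; otherwise build it.
        if cls._active is None:
            cls._active = cls(app_name)
        return cls._active

    def stop(self):
        # Clearing the slot lets the next get_or_create() build a fresh instance.
        type(self)._active = None


first = ToySession.get_or_create("MySparkApp")
second = ToySession.get_or_create("AnotherName")

print(first is second)   # the second call returned the existing instance
print(second.app_name)   # the name given on the first call is kept
```

In this sketch the second call's name is silently ignored. The example scripts in this guide each call stop() at the end so that the next script starts from a clean state.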

Example (Python):

Python
from pyspark.sql import SparkSession

# Build a SparkSession
spark = SparkSession.builder \
    .appName("MySparkApp") \
    .master("local[*]") \
    .config("spark.executor.memory", "2g") \
    .getOrCreate()

print("SparkSession created successfully!")

# You can access the SparkContext from the SparkSession
sc = spark.sparkContext
print(f"SparkContext app name: {sc.appName}")

spark.stop()

Configuration: appName, master, config(key, value)


When building a SparkSession, you can configure various properties:
●​ appName(name): Sets a name for your application, which appears in the Spark UI.
This helps identify your application among others running on the cluster.
●​ master(url): Specifies the master URL for the cluster.
○​ local: Runs Spark locally with one thread.
○​ local[*]: Runs Spark locally with as many worker threads as logical cores on
your machine.
○​ spark://host:port: Connects to a standalone Spark cluster.
○​ yarn: Connects to a YARN cluster.
○ mesos://host:port: Connects to a Mesos cluster (Mesos support is deprecated as of
Spark 3.2).
●​ config(key, value): Sets a Spark configuration property. You can use this for a wide
range of settings, from memory allocation to shuffle behavior.
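
As a rough illustration of how the local master URLs above map to worker threads, here is a small pure-Python helper. The function name `local_thread_count` is hypothetical and this is not Spark's own parser. It handles only the `local`, `local[N]`, and `local[*]` forms listed above and treats everything else as a cluster URL.

```python
import os
import re


def local_thread_count(master):
    """Return the worker-thread count implied by a local master URL.

    Returns None for anything that is not one of the local forms,
    e.g. "yarn" or "spark://host:port".
    """
    if master == "local":
        return 1  # a single worker thread
    match = re.fullmatch(r"local\[(\*|\d+)\]", master)
    if match is None:
        return None
    value = match.group(1)
    if value == "*":
        # One thread per logical core; fall back to 1 if the count is unknown.
        return os.cpu_count() or 1
    return int(value)


for url in ["local", "local[4]", "local[*]", "yarn"]:
    print(url, "->", local_thread_count(url))
```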

Example (Python):

Python

from pyspark.sql import SparkSession

# Parentheses allow a multi-line chain with trailing comments;
# a backslash continuation cannot be followed by a comment.
spark = (
    SparkSession.builder
    .appName("ConfiguredSparkApp")
    .master("local[4]")
    .config("spark.sql.shuffle.partitions", "200")  # Sets default shuffle partitions
    .config("spark.driver.memory", "1g")  # Allocates memory for the driver
    .getOrCreate()
)

print("Spark configuration:")
print(f"App Name: {spark.conf.get('spark.app.name')}")
print(f"Master: {spark.conf.get('spark.master')}")
print(f"Shuffle Partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")

spark.stop()

Accessing SparkContext from SparkSession


While SparkSession is the primary entry point, you can still access the underlying
SparkContext instance from a SparkSession using spark.sparkContext.

The SparkContext is responsible for connecting to the Spark cluster, creating RDDs, and
broadcasting variables. Although most new functionalities are exposed through
SparkSession, SparkContext is still important for low-level RDD operations or broadcasting
data.

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AccessSparkContext").master("local").getOrCreate()

# Access SparkContext from SparkSession
sc = spark.sparkContext

print(f"SparkContext application ID: {sc.applicationId}")
print(f"SparkContext version: {sc.version}")

# Example of an RDD operation using sc
rdd_data = sc.parallelize([10, 20, 30])
print("RDD sum:", rdd_data.sum())

spark.stop()

Understanding driver vs executors


In a Spark cluster, there are two main components: the driver and executors.
●​ Driver Program:
○​ Runs on the driver node (or the machine where you launched your Spark
application).
○​ Contains the SparkSession (and SparkContext).
○​ Is responsible for converting user code into a DAG (Directed Acyclic Graph)
of RDD/DataFrame transformations.
○​ Schedules tasks on executors.
○​ Maintains information about the cluster, like available executors.
○​ Aggregates results from executors.
●​ Executors:
○​ Are worker processes that run on worker nodes in the cluster.
○​ Are responsible for executing the actual tasks assigned by the driver.
○​ Perform computations, store data (if caching is involved), and return results to
the driver.
○​ Each executor can run multiple tasks concurrently.

Analogy: Think of the driver as the project manager and executors as the workers. The
project manager breaks down the project into smaller tasks, assigns them to workers, and
collects the results. The workers perform the actual work.

Communication Flow:
1.​ User Code to DAG: Your Spark application code (e.g., DataFrame transformations)
is first translated into a logical plan by the driver.
2.​ Logical to Physical Plan: The Catalyst Optimizer then converts the logical plan into
an optimized physical plan.
3.​ Task Scheduling: The driver breaks down the physical plan into stages and then
into individual tasks. It distributes these tasks to the available executors.
4.​ Task Execution: Executors receive tasks, process their assigned portion of the data,
and return results or status updates to the driver.

This distributed architecture allows Spark to process vast amounts of data in parallel.
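
The driver/executor split can be mimicked on one machine with a thread pool. This is only an analogy under stated assumptions: the "driver" is the main thread, the "executors" are pool workers, and `split_into_partitions` and `task` are hypothetical helpers, not Spark APIs.

```python
from concurrent.futures import ThreadPoolExecutor


def split_into_partitions(data, num_partitions):
    """Driver-side step: cut the data into roughly equal contiguous slices."""
    size, extra = divmod(len(data), num_partitions)
    partitions, start = [], 0
    for i in range(num_partitions):
        end = start + size + (1 if i < extra else 0)
        partitions.append(data[start:end])
        start = end
    return partitions


def task(partition):
    """Executor-side step: double every element and sum one partition."""
    return sum(x * 2 for x in partition)


data = list(range(1, 11))
partitions = split_into_partitions(data, 3)

# The pool plays the role of the executors; each worker handles one task.
with ThreadPoolExecutor(max_workers=3) as executors:
    partial_results = list(executors.map(task, partitions))

# The driver aggregates the partial results.
total = sum(partial_results)

print(partitions)
print(partial_results)
print(total)
```

The structure mirrors the flow above: plan the work, split it into tasks, run the tasks in parallel, and combine the results on the driver.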

DataFrame Basics
DataFrames are central to modern Spark applications. Let's explore how to create them and
perform some fundamental operations.

Creating DataFrames using:


Lists
You can easily create a DataFrame from a Python list (or Scala/Java equivalent) by
providing the data and an optional schema.
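
When no schema is given, column types are inferred from the Python values. The sketch below is a simplified stand-in for that step: `PYTHON_TO_SPARK` and `infer_schema` are hypothetical names, and only the first row is inspected to keep it short. It does reflect one detail worth knowing: PySpark infers Python `int` values as `long`, whereas an explicit `IntegerType()` field is reported as `integer`.

```python
# Simplified mapping from Python types to Spark SQL type names.
PYTHON_TO_SPARK = {bool: "boolean", int: "long", float: "double", str: "string"}


def infer_schema(rows, column_names):
    """Guess a (name, type) pair per column from the first row only."""
    first = rows[0]
    return [
        (name, PYTHON_TO_SPARK[type(value)])
        for name, value in zip(column_names, first)
    ]


data = [("Alice", 1), ("Bob", 2), ("Charlie", 3)]
schema = infer_schema(data, ["Name", "ID"])
print(schema)
```

An explicit schema avoids relying on this guesswork. That matters when the data contains nulls or mixed types, or when a downstream system expects a specific integer width.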

Example (Python):

Python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.appName("CreateDFFromList").getOrCreate()

# 1. Using a list of tuples with inferred schema (less explicit but quicker)
data_inferred = [("Alice", 1), ("Bob", 2), ("Charlie", 3)]
columns_inferred = ["Name", "ID"]
df_inferred = spark.createDataFrame(data_inferred, columns_inferred)
print("DataFrame from list (inferred schema):")
df_inferred.show()
df_inferred.printSchema()

print("-" * 30)

# 2. Using a list of tuples with explicit schema (recommended for robustness)
data_explicit = [
    ("David", 4, "New York"),
    ("Eve", 5, "London"),
    ("Frank", 6, "Paris")
]
schema_explicit = StructType([
    StructField("Name", StringType(), True),
    StructField("Age", IntegerType(), True),
    StructField("City", StringType(), True)
])
df_explicit = spark.createDataFrame(data_explicit, schema=schema_explicit)
print("DataFrame from list (explicit schema):")
df_explicit.show()
df_explicit.printSchema()

spark.stop()

RDDs
You can convert an RDD into a DataFrame, especially if you have existing RDD-based logic
or data. This is often done by converting the RDD to an RDD of Row objects and then
defining a schema.

Example (Python):

Python

from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.appName("CreateDFFromRDD").getOrCreate()
sc = spark.sparkContext

# Create an RDD of tuples
rdd_data = sc.parallelize([
    ("Alice", 30),
    ("Bob", 25),
    ("Charlie", 35)
])

# Define the schema
schema = StructType([
    StructField("Name", StringType(), True),
    StructField("Age", IntegerType(), True)
])

# Two ways to turn the RDD of tuples into a DataFrame
# Method 1: Pass column names only and let Spark infer the types (less control)
df_from_rdd_inferred = spark.createDataFrame(rdd_data, ["Name", "Age"])
print("DataFrame from RDD (inferred schema):")
df_from_rdd_inferred.show()
df_from_rdd_inferred.printSchema()

print("-" * 30)

# Method 2: Map to Row objects and pass the defined schema (recommended)
# The fields of each Row must line up with the fields of the schema
rdd_rows = rdd_data.map(lambda p: Row(Name=p[0], Age=p[1]))
df_from_rdd_explicit = spark.createDataFrame(rdd_rows, schema)
print("DataFrame from RDD (explicit schema):")
df_from_rdd_explicit.show()
df_from_rdd_explicit.printSchema()

spark.stop()

External files
Spark excels at reading data from various external file formats. CSV, JSON, and Parquet are
common choices.

Example (Python):

To run this example, you'd need sample people.csv and people.json files.

people.csv:

CSV

name,age
Alice,30
Bob,25
Charlie,35

people.json:

JSON

{"name":"Alice","age":30}
{"name":"Bob","age":25}
{"name":"Charlie","age":35}

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("CreateDFFromFiles").getOrCreate()

# Create dummy files for demonstration
# Ensure these files exist in your working directory or provide full paths
csv_data = """name,age
Alice,30
Bob,25
Charlie,35
"""
with open("people.csv", "w") as f:
    f.write(csv_data)

json_data = """{"name":"Alice","age":30}
{"name":"Bob","age":25}
{"name":"Charlie","age":35}
"""
with open("people.json", "w") as f:
    f.write(json_data)

# Read CSV file
df_csv = spark.read.csv("people.csv", header=True, inferSchema=True)
print("DataFrame from CSV:")
df_csv.show()
df_csv.printSchema()

print("-" * 30)

# Read JSON file
df_json = spark.read.json("people.json")
print("DataFrame from JSON:")
df_json.show()
df_json.printSchema()

# For Parquet, you'd first write a DataFrame to Parquet, then read it back
# df_csv.write.mode("overwrite").parquet("people.parquet")
# df_parquet = spark.read.parquet("people.parquet")
# print("\nDataFrame from Parquet:")
# df_parquet.show()
# df_parquet.printSchema()

spark.stop()
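
Two details of these sample files matter when reading them. The CSV header row supplies the column names, and CSV values are plain text until a type is inferred or declared. The JSON file holds one object per line (JSON Lines), which is what Spark's JSON reader expects by default. The standard-library sketch below shows both points without Spark; the variable names are illustrative only.

```python
import csv
import io
import json

csv_text = "name,age\nAlice,30\nBob,25\nCharlie,35\n"
json_text = (
    '{"name":"Alice","age":30}\n'
    '{"name":"Bob","age":25}\n'
    '{"name":"Charlie","age":35}\n'
)

# The first CSV line becomes the field names (like header=True).
csv_rows = list(csv.DictReader(io.StringIO(csv_text)))
print(csv_rows[0])  # every value is still a string at this point

# Converting types by hand plays the role that schema inference
# or an explicit schema plays in Spark.
csv_typed = [{"name": r["name"], "age": int(r["age"])} for r in csv_rows]

# JSON Lines: parse each line as its own JSON object.
json_rows = [json.loads(line) for line in json_text.splitlines() if line.strip()]

print(csv_typed == json_rows)
```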

Basic operations:
Once you have a DataFrame, you can perform various fundamental operations.

select
select() is used to select columns from a DataFrame. You can select specific columns, or
use expressions.

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("DFBasicOps").getOrCreate()

data = [("Alice", 30, "New York"), ("Bob", 25, "London"), ("Charlie", 35, "Paris")]
columns = ["Name", "Age", "City"]
df = spark.createDataFrame(data, columns)
df.show()

# Select a single column
df.select("Name").show()

# Select multiple columns
df.select("Name", "City").show()

# Select columns using col() function (recommended for robustness)
df.select(col("Name"), col("Age")).show()

# Select a column and alias it
df.select(df.Name.alias("Full_Name"), df.Age).show()

spark.stop()

filter
filter() (or where()) is used to filter rows based on a given condition. It returns a new
DataFrame containing only the rows that satisfy the condition.

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("DFBasicOps").getOrCreate()

data = [("Alice", 30, "New York"), ("Bob", 25, "London"), ("Charlie", 35, "Paris")]
columns = ["Name", "Age", "City"]
df = spark.createDataFrame(data, columns)
df.show()

# Filter using string expression
df.filter("Age > 28").show()

# Filter using column object
df.filter(col("Age") <= 30).show()

# Filter with multiple conditions (AND)
df.filter((col("Age") > 25) & (col("City") == "New York")).show()

# Filter with multiple conditions (OR)
df.filter((col("Age") < 25) | (col("City") == "Paris")).show()

spark.stop()

where
where() is an alias for filter(). They behave identically.
Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("DFBasicOps").getOrCreate()

data = [("Alice", 30, "New York"), ("Bob", 25, "London"), ("Charlie", 35, "Paris")]
columns = ["Name", "Age", "City"]
df = spark.createDataFrame(data, columns)
df.show()

# Using where()
df.where(col("City") == "London").show()

spark.stop()

show
show() displays the contents of the DataFrame in a tabular format. It's very useful for
inspecting your data.

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DFBasicOps").getOrCreate()

data = [("Alice", 30, "New York"), ("Bob", 25, "London"), ("Charlie", 35, "Paris")]
columns = ["Name", "Age", "City"]
df = spark.createDataFrame(data, columns)

# Show entire DataFrame
print("Entire DataFrame:")
df.show()

# Show first N rows
print("First 2 rows:")
df.show(2)

# Show without truncating column values
print("DataFrame without truncating:")
df.show(truncate=False)

spark.stop()

printSchema
printSchema() displays the schema (column names and their data types) of the DataFrame
in a tree-like format. This is crucial for understanding your data's structure and for
debugging.

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

spark = SparkSession.builder.appName("DFBasicOps").getOrCreate()

# Every row needs one value per schema field; None stands for missing hobbies
data = [("Alice", 30, "New York", None),
        ("Bob", 25, "London", ["reading", "cycling"]),
        ("Charlie", 35, "Paris", [])]
schema = StructType([
    StructField("Name", StringType(), True),
    StructField("Age", IntegerType(), True),
    StructField("City", StringType(), True),
    StructField("Hobbies", ArrayType(StringType()), True)
])
df = spark.createDataFrame(data, schema=schema)

print("DataFrame Schema:")
df.printSchema()

spark.stop()

DataFrame Transformations
DataFrame transformations are operations that return a new DataFrame without modifying
the original. They are lazy, meaning they are not executed until an action (like show(),
count(), collect(), or a write) is called.
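
Lazy evaluation can be modeled in a few lines of plain Python. The `LazyFrame` class below is a hypothetical toy, not a Spark API. Each transformation only records a step and returns a new object; nothing runs until the `collect()` action is called.

```python
class LazyFrame:
    """Toy model of a lazily evaluated, immutable table of dict rows."""

    def __init__(self, rows, plan=()):
        self._rows = rows
        self._plan = plan  # recorded steps, not yet executed

    def filter(self, predicate):
        # Transformation: return a NEW object with one more recorded step.
        return LazyFrame(self._rows, self._plan + (("filter", predicate),))

    def with_column(self, name, fn):
        return LazyFrame(self._rows, self._plan + (("with_column", (name, fn)),))

    def collect(self):
        # Action: only now are the recorded steps executed, in order.
        rows = [dict(r) for r in self._rows]
        for op, arg in self._plan:
            if op == "filter":
                rows = [r for r in rows if arg(r)]
            else:
                name, fn = arg
                rows = [{**r, name: fn(r)} for r in rows]
        return rows


calls = []  # records when the predicate actually runs


def is_over_25(row):
    calls.append(row["Name"])
    return row["Age"] > 25


base = LazyFrame([{"Name": "Alice", "Age": 30}, {"Name": "Bob", "Age": 25}])
planned = base.filter(is_over_25).with_column(
    "Age_in_Months", lambda r: r["Age"] * 12
)

print(calls)              # still empty: no work has been done yet
result = planned.collect()
print(calls)              # the predicate ran only when collect() was called
print(result)
```

Because the whole plan is known before anything executes, an engine built this way can reorder or merge steps first, which is what makes lazy evaluation useful for optimization.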

withColumn, drop, alias, selectExpr, rename


withColumn
withColumn() is used to add a new column or replace an existing column in a DataFrame.

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit

spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)]
columns = ["Name", "Age"]
df = spark.createDataFrame(data, columns)
df.show()

# Add a new column 'City' with a literal value
df_with_city = df.withColumn("City", lit("Unknown"))
print("After adding 'City' column:")
df_with_city.show()

# Add a new column 'Age_in_Months' based on an existing column
df_with_months = df.withColumn("Age_in_Months", col("Age") * 12)
print("After adding 'Age_in_Months' column:")
df_with_months.show()

# Replace an existing column (e.g., convert 'Age' to StringType)
df_replaced_age = df.withColumn("Age", col("Age").cast("string"))
print("After replacing 'Age' column type:")
df_replaced_age.printSchema()
df_replaced_age.show()

spark.stop()

drop
drop() is used to remove one or more columns from a DataFrame.

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

data = [("Alice", 30, "New York"), ("Bob", 25, "London"), ("Charlie", 35, "Paris")]
columns = ["Name", "Age", "City"]
df = spark.createDataFrame(data, columns)
df.show()

# Drop a single column
df_no_city = df.drop("City")
print("After dropping 'City' column:")
df_no_city.show()

# Drop multiple columns
df_no_age_city = df.drop("Age", "City")
print("After dropping 'Age' and 'City' columns:")
df_no_age_city.show()

spark.stop()

alias
alias() is used to rename a column or an expression. It's often used within select or
withColumn.

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

data = [("Alice", 30), ("Bob", 25)]
columns = ["Name", "Age"]
df = spark.createDataFrame(data, columns)
df.show()

# Using alias() in select
df.select(col("Name").alias("Full_Name"), col("Age")).show()

# Using alias() with an expression in select
df.select((col("Age") * 2).alias("Double_Age")).show()

# Using alias() inside withColumn does not rename anything: the resulting column
# takes the name passed to withColumn ('Age'), so the 'YearsOld' alias is not visible.
# Use withColumnRenamed() to rename a column.
df_aliased_age = df.withColumn("Age", col("Age").alias("YearsOld"))
print("Using alias with withColumn (the column is still named 'Age'):")
df_aliased_age.show()
df_aliased_age.printSchema()

spark.stop()

selectExpr
selectExpr() allows you to use SQL-like expressions to select columns and apply
transformations directly. It's very convenient for complex expressions.

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

data = [("Alice", 30, "New York"), ("Bob", 25, "London")]
columns = ["Name", "Age", "City"]
df = spark.createDataFrame(data, columns)
df.show()

# Select specific columns and rename
df.selectExpr("Name", "Age as Person_Age", "City").show()

# Perform calculations using SQL expressions
df.selectExpr("Name", "Age * 12 as Age_in_Months", "upper(City) as City_Upper").show()

# Apply conditional logic
df.selectExpr("Name", "CASE WHEN Age > 28 THEN 'Adult' ELSE 'Young' END as Category").show()

spark.stop()

rename
Spark DataFrames don't have a direct rename method for columns like Pandas. Instead, you
typically use withColumnRenamed() or select() with alias().

withColumnRenamed() is the most common and direct way to rename a single column.

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

data = [("Alice", 30), ("Bob", 25)]
columns = ["Name", "Age"]
df = spark.createDataFrame(data, columns)
df.show()

# Rename a single column using withColumnRenamed
df_renamed = df.withColumnRenamed("Name", "Full_Name")
print("After renaming 'Name' to 'Full_Name':")
df_renamed.show()
df_renamed.printSchema()

# Renaming multiple columns (chaining or using select)
# Chaining withColumnRenamed
df_renamed_multiple = df.withColumnRenamed("Name", "PersonName").withColumnRenamed("Age", "PersonAge")
print("After renaming multiple columns (chaining):")
df_renamed_multiple.show()

# Using select with alias for multiple renames
df_renamed_select = df.select(
    df.Name.alias("Person_Name"),
    df.Age.alias("Person_Age")
)
print("After renaming multiple columns (using select with alias):")
df_renamed_select.show()

spark.stop()

Chaining transformations
One of the most powerful features of Spark DataFrames is the ability to chain
transformations. This makes your code more concise, readable, and allows Spark to
optimize the entire sequence of operations more effectively due to lazy evaluation.

Example (Python):

Python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit

spark = SparkSession.builder.appName("ChainingTransformations").getOrCreate()

data = [("Alice", 30, "NY", 75000),
        ("Bob", 25, "LD", 60000),
        ("Charlie", 35, "NY", 90000),
        ("David", 22, "SF", 55000),
        ("Eve", 40, "LD", 100000)]
columns = ["Name", "Age", "City_Code", "Salary"]
df = spark.createDataFrame(data, columns)
df.show()

# Chain multiple transformations:
# 1. Filter for age > 25
# 2. Add a new column 'Bonus'
# 3. Select specific columns, rename one, and add a constant 'Status' column
# 4. Sort by salary, highest first
df_processed = df.filter(col("Age") > 25) \
    .withColumn("Bonus", col("Salary") * 0.10) \
    .select(col("Name"),
            col("Age").alias("YearsOld"),
            col("Salary"),
            col("Bonus"),
            lit("Processed").alias("Status")) \
    .sort(col("Salary").desc())

# Nothing has executed yet; show() is the action that triggers the chain
print("Chained transformations result:")
df_processed.show()

spark.stop()

Creating calculated columns


Calculated columns are new columns derived from existing columns using expressions or
functions.

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat_ws, lit, when

spark = SparkSession.builder.appName("CalculatedColumns").getOrCreate()

data = [("Alice", "Smith", 30, 75000),
        ("Bob", "Johnson", 25, 60000),
        ("Charlie", "Brown", 35, 90000)]
columns = ["FirstName", "LastName", "Age", "Salary"]
df = spark.createDataFrame(data, columns)
df.show()

# 1. Calculate Full Name
df_full_name = df.withColumn("FullName", concat_ws(" ", col("FirstName"), col("LastName")))
print("After adding 'FullName':")
df_full_name.show()

# 2. Calculate Annual Bonus (10% of salary if age > 30, else 5%)
df_bonus = df.withColumn("AnnualBonus",
                         when(col("Age") > 30, col("Salary") * 0.10)
                         .otherwise(col("Salary") * 0.05))
print("After adding 'AnnualBonus':")
df_bonus.show()

# 3. Combine multiple calculations and selections
df_combined = df.withColumn("FullName", concat_ws(" ", col("FirstName"), col("LastName"))) \
    .withColumn("ExperienceCategory",
                when(col("Age") < 28, "Junior")
                .when(col("Age") < 35, "Mid-level")
                .otherwise("Senior")) \
    .select("FullName", "Age", "Salary", "ExperienceCategory")
print("After combining multiple calculations:")
df_combined.show()

spark.stop()

Column Operations
Spark provides a rich set of functions in pyspark.sql.functions that operate on Column
objects, enabling powerful data manipulations.

Using col, lit, expr, when, otherwise


These are fundamental functions for constructing expressions on DataFrames.
●​ col(column_name): References a column by its name. It's the safest way to refer to
columns, especially when column names might conflict with keywords or contain
special characters.
●​ lit(value): Creates a literal column with the given value. Useful for adding constant
values to a DataFrame.
●​ expr(sql_expression_string): Allows you to use SQL expressions directly within
DataFrame operations. Very powerful for complex logic.
●​ when(condition, value): Implements conditional logic (like IF or CASE in SQL). If
condition is true, it returns value.
● otherwise(value): A method on the Column returned by when(), not a standalone
function in pyspark.sql.functions. It specifies the default value if none of the
when() conditions are met.

Example (Python):

Python

from pyspark.sql import SparkSession
# otherwise() is a Column method, so it is not imported from pyspark.sql.functions
from pyspark.sql.functions import col, lit, expr, when

spark = SparkSession.builder.appName("ColumnOperations").getOrCreate()

data = [("Alice", 75, "A"), ("Bob", 88, "B"), ("Charlie", 62, "C"), ("David", 95, "A")]
columns = ["Name", "Score", "Grade"]
df = spark.createDataFrame(data, columns)
df.show()

# Using col(): Select 'Name' and 'Score'
df.select(col("Name"), col("Score")).show()

# Using lit(): Add a 'Status' column with a constant value
df.withColumn("Status", lit("Processed")).show()

# Using expr(): Add a 'DoubleScore' column using SQL expression
df.withColumn("DoubleScore", expr("Score * 2")).show()

# Using when() and otherwise(): Create a 'PassStatus' column
df.withColumn("PassStatus", when(col("Score") >= 70, "Pass").otherwise("Fail")).show()

# Combining them:
df.select(
    col("Name"),
    col("Score"),
    lit("Exam").alias("AssessmentType"),  # Add a literal column
    expr("Score / 100 * 100").alias("Percentage"),  # Calculate percentage
    when(col("Grade") == "A", "Excellent")  # Conditional logic based on Grade
    .when(col("Grade") == "B", "Good")
    .otherwise("Needs Improvement")
    .alias("Performance")
).show()

spark.stop()

Creating conditional logic


Conditional logic is essential for data cleaning, transformation, and feature engineering.
when() and otherwise() are your go-to tools for this.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, when, lit
●​
●​ spark = SparkSession.builder.appName("ConditionalLogic").getOrCreate()
●​
●​ data = [("Alice", 18, 50000),
●​ ("Bob", 25, 60000),
●​ ("Charlie", 32, 80000),
●​ ("David", 40, 95000),
●​ ("Eve", 16, 40000)]
●​ columns = ["Name", "Age", "Salary"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Create a 'AgeGroup' column based on Age
●​ df_with_age_group = df.withColumn("AgeGroup",
●​ when(col("Age") < 20, "Teenager")
●​ .when(col("Age") >= 20, "Adult")
●​ .otherwise("Senior")) # Reached only when Age is null: the two conditions above cover every non-null age
●​ print("DataFrame with 'AgeGroup':")
●​ df_with_age_group.show()
●​
●​ # Create a 'TaxBracket' column based on Salary
●​ df_with_tax_bracket = df.withColumn("TaxBracket",
●​ when(col("Salary") < 50000, "Low")
●​ .when((col("Salary") >= 50000) & (col("Salary") < 80000), "Medium")
●​ .otherwise("High"))
●​ print("DataFrame with 'TaxBracket':")
●​ df_with_tax_bracket.show()
●​
●​ spark.stop()
Nesting conditions using multiple when
You can chain multiple when() clauses to create more complex conditional logic, similar to
CASE WHEN ... THEN ... WHEN ... THEN ... ELSE ... END in SQL. The conditions are
evaluated in order. The first when() condition that evaluates to true will have its
corresponding value returned.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, when, lit
●​
●​ spark = SparkSession.builder.appName("NestedConditions").getOrCreate()
●​
●​ data = [("Alice", 85),
●​ ("Bob", 72),
●​ ("Charlie", 91),
●​ ("David", 60),
●​ ("Eve", 45)]
●​ columns = ["Student", "Score"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Assign grades based on score using nested when
●​ df_with_grades = df.withColumn("Grade",
●​ when(col("Score") >= 90, "A")
●​ .when(col("Score") >= 80, "B")
●​ .when(col("Score") >= 70, "C")
●​ .when(col("Score") >= 60, "D")
●​ .otherwise("F")) # Default if no other conditions met
●​ print("DataFrame with assigned Grades:")
●​ df_with_grades.show()
●​
●​ # More complex example: Eligibility for a scholarship
●​ df_with_scholarship = df.withColumn("ScholarshipStatus",
●​ when((col("Score") >= 90) & (col("Student") == "Charlie"), lit("Full Scholarship"))
●​ .when(col("Score") >= 85, lit("Partial Scholarship"))
●​ .when(col("Score") >= 70, lit("Eligibility Review"))
●​ .otherwise(lit("Not Eligible")))
●​ print("DataFrame with Scholarship Status:")
●​ df_with_scholarship.show()
●​
●​ spark.stop()
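
The first-match rule described above can be modeled in plain Python, without Spark. This is a minimal sketch, not PySpark's implementation: case_when and ge are hypothetical helper names, and None stands in for a SQL null. It shows why the order of the chained branches matters.

```python
from typing import Any, Callable, Optional, Sequence, Tuple

# A branch pairs a condition with the value to return when it holds.
Branch = Tuple[Callable[[Any], Optional[bool]], Any]


def case_when(value: Any, branches: Sequence[Branch], default: Any = None) -> Any:
    """Return the result of the first branch whose condition is True.

    A condition that yields None (unknown, e.g. a comparison with null)
    counts as not matched, as in SQL CASE WHEN.
    """
    for condition, result in branches:
        if condition(value) is True:
            return result
    return default


def ge(threshold: int) -> Callable[[Any], Optional[bool]]:
    """Build a null-aware '>= threshold' condition."""
    return lambda v: None if v is None else v >= threshold


# Highest threshold first: each score gets the best grade it qualifies for.
ordered = [(ge(90), "A"), (ge(80), "B"), (ge(70), "C"), (ge(60), "D")]
# Broadest threshold first: it shadows every later branch.
shadowed = [(ge(60), "D"), (ge(70), "C"), (ge(80), "B"), (ge(90), "A")]

scores = [85, 72, 91, 60, 45, None]
grades_ordered = [case_when(s, ordered, "F") for s in scores]
grades_shadowed = [case_when(s, shadowed, "F") for s in scores]
print(grades_ordered)
print(grades_shadowed)
```

With the broad condition first, every passing score collapses to "D". The null score falls through to the default in both orderings, which is what otherwise() supplies in a when() chain.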

Data Types & Schema


Understanding and managing data types and schemas is fundamental for robust data
engineering with Spark.

StructType, StructField, ArrayType, MapType


Spark's pyspark.sql.types module provides classes to define schemas programmatically.
●​ StructType: Represents the schema of a DataFrame, which is a list of StructField
objects. It defines the structure of a row.
●​ StructField: Represents a column within a StructType. It takes three arguments:
○​ name (string): The name of the column.
○​ dataType (DataType): The data type of the column (e.g., StringType(),
IntegerType(), BooleanType()).
○​ nullable (boolean): Whether the column can contain null values. True if
nullable, False otherwise.
●​ ArrayType(elementType, containsNull): Represents an array (list) of elements of a
specific type.
○​ elementType: The data type of the elements in the array.
○​ containsNull: Whether the array can contain null elements.
●​ MapType(keyType, valueType, valueContainsNull): Represents a map (dictionary)
with key-value pairs.
○​ keyType: The data type of the keys.
○​ valueType: The data type of the values.
○​ valueContainsNull: Whether the map values can be null.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.types import StructType, StructField, StringType, IntegerType,
ArrayType, MapType
●​
●​ spark = SparkSession.builder.appName("DataTypesSchema").getOrCreate()
●​
●​ # Define a complex schema
●​ schema = StructType([
●​ StructField("employee_id", IntegerType(), False), # Not nullable
●​ StructField("name", StructType([ # Nested StructType
●​ StructField("first", StringType(), True),
●​ StructField("last", StringType(), True)
●​ ]), True),
●​ StructField("skills", ArrayType(StringType(), True), True), # Array of strings
●​ # Map with string keys and values
●​ StructField("contact_info", MapType(StringType(), StringType(), True), True)
●​ ])
●​
●​ # Create data that conforms to the schema
●​ data = [(1, ("Alice", "Smith"), ["Python", "Spark"],
●​ {"email": "alice@example.com", "phone": "123-456-7890"}),
●​ (2, ("Bob", "Johnson"), ["Java", "SQL"], {"email": "bob@example.com"}),
●​ # Null inside the nested struct, empty array, empty map
●​ (3, ("Charlie", None), [], {})
●​ ]
●​
●​ df = spark.createDataFrame(data, schema)
●​
●​ print("DataFrame with complex schema:")
●​ df.show(truncate=False)
●​ df.printSchema()
●​
●​ spark.stop()

inferSchema vs defining schema manually


When reading data from external files, Spark can either infer the schema or you can provide
it manually.
●​ inferSchema=True (Automatic Schema Inference):
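○​ Spark makes an extra pass over the data (all rows by default for CSV; the
samplingRatio option can reduce this) to guess each column's data type. Column
names come from the header row when header=True.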
○​ Pros: Convenient, less manual work, good for quick exploration.
○​ Cons: Can be slow (requires an extra pass over the data). May infer incorrect
types (e.g., all strings if data is dirty, or integer instead of string for IDs). Not
suitable for production pipelines where schema stability is critical.
○​ Usage: spark.read.csv("path/to/file.csv", header=True, inferSchema=True)
●​ Defining Schema Manually:
○​ You explicitly define the StructType and StructField for each column.
○​ Pros: Fast (no extra pass for inference). Robust (prevents unexpected type
changes due to data variations). Enables early detection of schema
mismatches. Recommended for production ETL.
○​ Cons: More verbose, requires knowing the schema beforehand.
○​ Usage: spark.read.csv("path/to/file.csv", header=True, schema=my_schema)
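
To see why a stray token such as NULL can change an inferred type, here is a simplified plain-Python model of type inference over one CSV column. It is an illustration only and not Spark's actual inference algorithm; infer_column_type and its null_value parameter are hypothetical names, with null_value playing the role of Spark's nullValue read option.

```python
from typing import Iterable, Optional


def infer_column_type(values: Iterable[Optional[str]], null_value: str = "") -> str:
    """Guess 'int', 'double', or 'string' for a column of raw CSV tokens.

    Tokens equal to null_value are skipped. Any other token that is not
    numeric forces 'string', the widest type in this toy model.
    """
    widest = None  # no non-null token seen yet
    for raw in values:
        if raw is None or raw == null_value:
            continue
        try:
            int(raw)
            kind = "int"
        except ValueError:
            try:
                float(raw)
                kind = "double"
            except ValueError:
                return "string"
        # 'double' widens 'int'; 'int' never narrows 'double'.
        if widest is None or kind == "double":
            widest = kind
    return widest or "string"


ages = ["30", "25", "35", "NULL"]
print(infer_column_type(ages))                     # the text NULL forces string
print(infer_column_type(ages, null_value="NULL"))  # NULL treated as missing
```

A manually defined schema avoids this guessing step entirely, which is one reason it is preferred for production pipelines.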

Example (Python):
Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.types import StructType, StructField, StringType, IntegerType
●​
●​ spark = SparkSession.builder.appName("InferVsManualSchema").getOrCreate()
●​
●​ # Create a dummy CSV file
●​ csv_content = """id,name,age,city
●​ 1,Alice,30,New York
●​ 2,Bob,25,London
●​ 3,Charlie,35,Paris
●​ 4,David,NULL,Berlin
●​ """
●​ with open("people_data.csv", "w") as f:
●​ f.write(csv_content)
●​
●​ # 1. Infer Schema
●​ df_inferred = spark.read.csv("people_data.csv", header=True, inferSchema=True)
●​ print("DataFrame with Inferred Schema:")
●​ df_inferred.printSchema()
●​ df_inferred.show()
●​
●​ # 'age' is inferred as string here: the literal text NULL is not Spark's default null
●​ # marker for CSV (an empty string). Pass nullValue="NULL" to have it read as null.
●​
●​ # 2. Define Schema Manually
●​ manual_schema = StructType([
●​ StructField("id", IntegerType(), False),
●​ StructField("name", StringType(), True),
●​ StructField("age", IntegerType(), True), # Define as IntegerType, expecting nulls
●​ StructField("city", StringType(), True)
●​ ])
●​
●​ df_manual = spark.read.csv("people_data.csv", header=True, schema=manual_schema, nullValue="NULL")
●​ print("\nDataFrame with Manual Schema:")
●​ df_manual.printSchema()
●​ df_manual.show()
●​
●​ spark.stop()

cast() for type conversions


The cast() function is used to convert a column from one data type to another. This is crucial
for data cleaning and ensuring data integrity.
Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col
●​ from pyspark.sql.types import IntegerType, StringType, DoubleType, DateType
●​
●​ spark = SparkSession.builder.appName("TypeConversions").getOrCreate()
●​
●​ data = [("1", "10.5", "2023-01-15"),
●​ ("2", "20.7", "2023-02-20"),
●​ ("3", "30.9", "2023-03-25")]
●​ columns = ["id_str", "price_str", "date_str"]
●​ df = spark.createDataFrame(data, columns)
●​ print("Original DataFrame Schema:")
●​ df.printSchema()
●​ df.show()
●​
●​ # Cast 'id_str' to IntegerType
●​ df_casted = df.withColumn("id_int", col("id_str").cast(IntegerType()))
●​
●​ # Cast 'price_str' to DoubleType
●​ df_casted = df_casted.withColumn("price_double", col("price_str").cast(DoubleType()))
●​
●​ # Cast 'date_str' to DateType
●​ df_casted = df_casted.withColumn("date_date", col("date_str").cast(DateType()))
●​
●​ print("\nDataFrame After Casts:")
●​ df_casted.printSchema()
●​ df_casted.show()
●​
●​ # Handling invalid casts:
●​ # With ANSI mode off, a value that cannot be cast becomes null;
●​ # with spark.sql.ansi.enabled=true the cast raises an error instead.
●​ data_invalid = [("1", "abc"), ("2", "123")]
●​ df_invalid = spark.createDataFrame(data_invalid, ["id", "num_str"])
●​ print("\nDataFrame with potential invalid casts:")
●​ df_invalid.show()
●​ df_invalid.withColumn("num_int", col("num_str").cast(IntegerType())).show()
●​
●​ spark.stop()

Schema evolution awareness


Schema evolution refers to the ability to handle changes in the schema of data over time.
This is a common challenge in big data systems, especially when dealing with
semi-structured data sources or streaming data.

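Spark can reconcile such changes for self-describing formats like Parquet, which store the
schema with the data. Note that merging differing Parquet file schemas is off by default;
enable it per read with the mergeSchema option (or globally with
spark.sql.parquet.mergeSchema).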

Common Scenarios and How Spark Handles Them:


1.​ Adding new columns: If new columns are added to the source data, Spark includes
them in the DataFrame's schema when schemas are merged (mergeSchema for Parquet) or
when you supply a schema that lists them. Older records (without the new columns) will
have null values for those columns.
2.​ Dropping columns: If columns are removed from the source, reading Parquet with an
explicit schema that still lists them returns null for the files that lack them; with an
inferred schema the columns simply disappear from the result.
3.​ Changing data types: This is the trickiest. If a data type changes in a way that is
compatible (e.g., Integer to Long), Spark might handle it. However, incompatible
changes (e.g., String to Integer where strings aren't valid numbers) will likely lead to
null values or errors.
4.​ Reordering columns: Spark (especially Parquet) generally handles column
reordering gracefully as it identifies columns by name, not by position.

Best Practices for Schema Evolution:


●​ Use Parquet or Avro: These formats are designed with schema evolution in mind.
They store schema information with the data and can handle additive changes
efficiently.
●​ Define Schema Manually (when possible): While inferSchema is convenient,
explicitly defining your schema gives you more control and helps catch unexpected
schema changes early.
●​ Handle Nulls: Be prepared for null values when new columns are added.
●​ Version Control Schemas: In complex data pipelines, consider versioning your
schemas to track changes.
●​ Monitor Data Quality: Implement data quality checks to detect unexpected schema
changes or data type issues.

Example (Conceptual):

Imagine you have a Parquet file with (id INT, name STRING).

Later, new data arrives with (id INT, name STRING, age INT).

When you read the combined data with Spark and schema merging enabled (the mergeSchema
option for Parquet), the resulting DataFrame will have (id INT, name STRING, age INT). The
older records will have null for age.

This "awareness" means you, as the data engineer, should anticipate these changes and
design your pipeline to be resilient to them.
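
The conceptual example can be mimicked in plain Python to make the null-filling explicit. This is a sketch under simplifying assumptions (schemas as name-to-type dicts, rows as dicts), not how Spark or Parquet merge schemas internally; merge_schemas and align_rows are hypothetical helpers.

```python
from typing import Dict, List

Schema = Dict[str, str]  # column name -> type name


def merge_schemas(old: Schema, new: Schema) -> Schema:
    """Union two schemas by column name; refuse conflicting types."""
    merged = dict(old)
    for name, dtype in new.items():
        if name in merged and merged[name] != dtype:
            raise ValueError(
                f"conflicting types for {name!r}: {merged[name]} vs {dtype}"
            )
        merged[name] = dtype
    return merged


def align_rows(rows: List[dict], schema: Schema) -> List[dict]:
    """Project every row onto the schema, filling missing columns with None."""
    return [{name: row.get(name) for name in schema} for row in rows]


old_schema = {"id": "int", "name": "string"}
new_schema = {"id": "int", "name": "string", "age": "int"}

old_rows = [{"id": 1, "name": "Alice"}]
new_rows = [{"id": 2, "name": "Bob", "age": 41}]

merged = merge_schemas(old_schema, new_schema)
combined = align_rows(old_rows + new_rows, merged)
print(merged)
print(combined)
```

Columns are matched by name rather than position, so reordering is harmless here, while a type change raises an error instead of silently producing nulls.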

Reading & Writing Files


Spark provides robust capabilities for reading from and writing to various file formats, which
are essential for data ingestion and persistence.

File formats: CSV, JSON, Parquet


Spark supports a multitude of file formats, with CSV, JSON, and Parquet being among the
most common.
●​ CSV (Comma Separated Values):
○​ Pros: Human-readable, widely supported, simple.
○​ Cons: No schema enforcement, no built-in compression, difficult to handle
nested data, performance issues with large datasets (requires full scan to
infer schema).
○​ Use cases: Small to medium datasets, data exchange with other systems,
ad-hoc analysis.
●​ JSON (JavaScript Object Notation):
○​ Pros: Semi-structured, human-readable, good for nested data.
○​ Cons: No schema enforcement (can be flexible but also error-prone), less
efficient for large-scale analytical queries compared to columnar formats,
requires more parsing overhead.
○​ Use cases: Log data, API responses, data interchange where schema is
flexible.
●​ Parquet:
○​ Pros: Columnar storage format (stores data column by column), highly
efficient for analytical queries (reading only necessary columns), built-in
compression and encoding, supports complex nested data structures,
schema evolution support, optimized for Spark.
○​ Cons: Not human-readable, requires Spark or other Parquet-aware engines
to read efficiently.
○​ Use cases: Highly recommended for big data storage and analytical
workloads in Spark. Data lakes, persistent storage for ETL intermediate
stages.

Read options:
When reading files, you can specify various options to control how Spark interprets the data.
●​ header: True if the first row is a header, False otherwise. (Default: False)
●​ inferSchema: True to let Spark infer the schema, False to read all columns as
StringType (or provide manual schema). (Default: False)
●​ multiLine: True to read a record that spans multiple lines (for JSON, a single object
or array spread over several lines), False for one JSON object per line. (Default:
False). CSV also accepts it, for quoted fields that contain line breaks.
●​ sep: Specifies the column delimiter for CSV files. (Default: ,)
●​ compression: Codec to use when writing, e.g., gzip, snappy, lz4, bzip2. When reading,
Spark detects the codec from the file extension, so no option is needed.
Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​
●​ spark = SparkSession.builder.appName("ReadWriteFiles").getOrCreate()
●​
●​ # Create dummy files for demonstration
●​ # CSV
●​ csv_data = """id|name|age
●​ 1|Alice|30
●​ 2|Bob|25
●​ 3|Charlie|35
●​ """
●​ with open("data.csv", "w") as f:
●​ f.write(csv_data)
●​
●​ # JSON Lines (one JSON object per line)
●​ json_data = """{"id":1, "name":"Alice", "details":{"age":30, "city":"NY"}}
●​ {"id":2, "name":"Bob", "details":{"age":25, "city":"LD"}}
●​ """
●​ with open("data.json", "w") as f:
●​ f.write(json_data)
●​
●​ # CSV read options
●​ print("Reading CSV with header and custom separator:")
●​ df_csv = spark.read.csv("data.csv", header=True, inferSchema=True, sep="|")
●​ df_csv.printSchema()
●​ df_csv.show()
●​
●​ # JSON read options
●​ print("\nReading JSON (one object per line, so multiLine=False):")
●​ # The keyword argument is spelled multiLine; multiline= raises a TypeError
●​ df_json = spark.read.json("data.json", multiLine=False)
●​ df_json.printSchema()
●​ df_json.show(truncate=False)
●​
●​ # To demonstrate compression, write a gzipped CSV, then read it back.
●​ # The output path is a directory; Spark names the part files inside it *.csv.gz.
●​ df_csv.write.option("compression", "gzip").csv("data_compressed.csv.gz", mode="overwrite", header=True)
●​ print("\nReading compressed CSV:")
●​ # No compression option is needed on read; the codec is detected from the .gz extension
●​ df_compressed = spark.read.csv("data_compressed.csv.gz", header=True, inferSchema=True)
●​ df_compressed.printSchema()
●​ df_compressed.show()
●​
●​ spark.stop()
Write options:
When writing DataFrames to files, you have control over the output behavior.
●​ mode(overwrite|append|ignore|errorIfExists): Specifies the save mode.
○​ overwrite: Overwrites the existing data (if any).
○​ append: Appends the new data to the existing data.
○​ ignore: If data already exists, the write operation does nothing.
○​ errorIfExists: If data already exists, an error is thrown. (Default)
●​ partitionBy(column_names): Partitions the output data by the values of specified
columns. This creates subdirectories in the output path. Improves query performance
by allowing Spark to skip scanning irrelevant partitions.
●​ bucketBy(num_buckets, column_names): Buckets the output data by hashing the
specified columns into a fixed number of buckets. Improves join performance and
sampling. Requires saving as a table (e.g., Hive table).
●​ .option() vs .options():
○​ .option(key, value): Sets a single option.
○​ .options(**kwargs): Sets multiple options using keyword arguments (Python
specific).
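
partitionBy writes one subdirectory per distinct value, named column=value, and leaves the partition column out of the data files themselves. The following plain-Python sketch models that layout; partition_layout is a hypothetical helper, and real output also contains part files and metadata that are omitted here.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def partition_layout(
    rows: List[dict], partition_cols: Sequence[str], base_path: str
) -> Dict[str, List[dict]]:
    """Map each output directory to the rows that would be written under it."""
    layout = defaultdict(list)
    for row in rows:
        # Directory names encode the partition values, e.g. City=NY
        subdir = "/".join(f"{c}={row[c]}" for c in partition_cols)
        # Partition columns live in the path, not in the stored rows.
        remaining = {k: v for k, v in row.items() if k not in partition_cols}
        layout[f"{base_path}/{subdir}"].append(remaining)
    return dict(layout)


rows = [
    {"Name": "Alice", "Age": 30, "City": "NY"},
    {"Name": "Bob", "Age": 25, "City": "LD"},
    {"Name": "Charlie", "Age": 35, "City": "NY"},
    {"Name": "David", "Age": 22, "City": "SF"},
]

layout = partition_layout(rows, ["City"], "output_data/parquet_partitioned")
for path in sorted(layout):
    print(path, layout[path])
```

A query that filters on City == "NY" only needs to open the City=NY directory, which is the partition skipping mentioned above.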

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​
●​ spark = SparkSession.builder.appName("WriteOptions").getOrCreate()
●​
●​ data = [("Alice", 30, "NY"),
●​ ("Bob", 25, "LD"),
●​ ("Charlie", 35, "NY"),
●​ ("David", 22, "SF")]
●​ columns = ["Name", "Age", "City"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ output_base_path = "output_data"
●​
●​ # Write to CSV with overwrite mode
●​ print("Writing CSV with overwrite mode...")
●​ df.write.mode("overwrite").csv(f"{output_base_path}/csv_output_overwrite", header=True)
●​
●​ # Write to Parquet, partitioned by 'City'
●​ print("\nWriting Parquet partitioned by 'City'...")
●​ df.write.mode("overwrite").partitionBy("City").parquet(f"{output_base_path}/parquet_partitioned")
●​
●​ # Write to JSON with append mode. Run this twice to see the effect:
●​ # the first run creates the output directory, the second adds new part files to it.
●​ print("\nWriting JSON with append mode...")
●​ df.write.mode("append").json(f"{output_base_path}/json_output_append")
●​
●​
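●​ # Using .options() (Python specific). The save mode is not a data source option,
●​ # so it must be set with .mode() rather than passed to .options().
●​ print("\nWriting CSV using .options():")
●​ df.write.mode("overwrite").options(header=True, sep=",").csv(f"{output_base_path}/csv_output_options")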
●​
●​ print("\nWrite operations completed. Check your 'output_data' directory.")
●​
●​ spark.stop()

DataFrame Joins
Joining DataFrames is a fundamental operation in data engineering to combine data from
different sources based on common columns.

Types of joins:
Spark supports various types of joins, similar to SQL.
●​ inner (default): Returns rows when there is a match in both DataFrames.
●​ left (or left_outer): Returns all rows from the left DataFrame, and the matched rows
from the right DataFrame. If no match, nulls are introduced for columns from the right
DataFrame.
●​ right (or right_outer): Returns all rows from the right DataFrame, and the matched
rows from the left DataFrame. If no match, nulls are introduced for columns from the
left DataFrame.
●​ outer (or full, full_outer): Returns all rows from both DataFrames. Where a row has
no match on the other side, nulls are introduced for that side's columns.
●​ left_semi: Returns rows from the left DataFrame where there is a match in the right
DataFrame. Only returns columns from the left DataFrame. It's like an INNER
JOIN restricted to the left columns, except that a left row appears once even if it
matches several right rows.
●​ left_anti: Returns rows from the left DataFrame where there is no match in the right
DataFrame. Useful for finding missing records or non-existent relationships.
Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col
●​
●​ spark = SparkSession.builder.appName("DataFrameJoins").getOrCreate()
●​
●​ # Create two DataFrames
●​ employees_data = [("Alice", 1, "HR"),
●​ ("Bob", 2, "Sales"),
●​ ("Charlie", 3, "IT"),
●​ ("David", 4, "Marketing")]
●​ employees_columns = ["Name", "EmpID", "Department"]
●​ employees_df = spark.createDataFrame(employees_data, employees_columns)
●​ employees_df.show()
●​
●​ departments_data = [(1, "HR", "New York"),
●​ (2, "Sales", "London"),
●​ (5, "Finance", "Paris")] # Finance has no matching employee
●​ departments_columns = ["DeptID", "DeptName", "Location"]
●​ departments_df = spark.createDataFrame(departments_data, departments_columns)
●​ departments_df.show()
●​
●​ # Inner Join
●​ print("\nInner Join:")
●​ employees_df.join(departments_df, employees_df.Department ==
departments_df.DeptName, "inner").show()
●​
●​ # Left Join
●​ print("\nLeft Join:")
●​ employees_df.join(departments_df, employees_df.Department ==
departments_df.DeptName, "left").show()
●​
●​ # Right Join
●​ print("\nRight Join:")
●​ employees_df.join(departments_df, employees_df.Department ==
departments_df.DeptName, "right").show()
●​
●​ # Full Outer Join
●​ print("\nFull Outer Join:")
●​ employees_df.join(departments_df, employees_df.Department ==
departments_df.DeptName, "full_outer").show()
●​
●​ # Left Semi Join
●​ print("\nLeft Semi Join (only columns from left, where match exists):")
●​ employees_df.join(departments_df, employees_df.Department ==
departments_df.DeptName, "left_semi").show()
●​
●​ # Left Anti Join
●​ print("\nLeft Anti Join (rows in left NOT in right):")
●​ employees_df.join(departments_df, employees_df.Department ==
departments_df.DeptName, "left_anti").show()
●​
●​ spark.stop()
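
The difference between inner, left_semi, and left_anti is easiest to see when the right side holds more than one match for a key. The sketch below models the three joins in plain Python over lists of dicts; the function names and sample data are hypothetical, and null keys are not modeled.

```python
from typing import List


def inner_join(left: List[dict], right: List[dict], lkey: str, rkey: str) -> List[dict]:
    """One output row per matching (left, right) pair, with columns from both."""
    return [{**l, **r} for l in left for r in right if l[lkey] == r[rkey]]


def left_semi_join(left: List[dict], right: List[dict], lkey: str, rkey: str) -> List[dict]:
    """Left rows that have at least one match; left columns only, no duplication."""
    right_keys = {r[rkey] for r in right}
    return [l for l in left if l[lkey] in right_keys]


def left_anti_join(left: List[dict], right: List[dict], lkey: str, rkey: str) -> List[dict]:
    """Left rows that have no match on the right."""
    right_keys = {r[rkey] for r in right}
    return [l for l in left if l[lkey] not in right_keys]


employees = [
    {"Name": "Alice", "Department": "HR"},
    {"Name": "Bob", "Department": "Sales"},
    {"Name": "Charlie", "Department": "IT"},
]
# HR appears twice, so Alice matches two right-hand rows.
offices = [
    {"DeptName": "HR", "Location": "New York"},
    {"DeptName": "HR", "Location": "Boston"},
    {"DeptName": "Sales", "Location": "London"},
]

inner_names = [r["Name"] for r in inner_join(employees, offices, "Department", "DeptName")]
semi_names = [r["Name"] for r in left_semi_join(employees, offices, "Department", "DeptName")]
anti_names = [r["Name"] for r in left_anti_join(employees, offices, "Department", "DeptName")]
print(inner_names, semi_names, anti_names)
```

The inner join repeats Alice once per matching office, the semi join lists her once, and the anti join returns only Charlie.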

Handling duplicate columns


When joining DataFrames on common column names, Spark can produce duplicate
columns in the result if not handled carefully.
●​ Specify join condition explicitly: The safest way is to use a boolean expression
(e.g., df1.col == df2.col) as the join condition. This keeps both columns. You can then
drop one.
●​ Use on parameter with a string or list of strings: If you join on columns with the
same name in both DataFrames, Spark will automatically combine them into a single
column.
●​ Alias columns before joining: Rename one of the conflicting columns before the
join.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col
●​
●​ spark = SparkSession.builder.appName("HandlingDuplicateColumns").getOrCreate()
●​
●​ df1_data = [("A", 10), ("B", 20)]
●​ df1_cols = ["ID", "Value1"]
●​ df1 = spark.createDataFrame(df1_data, df1_cols)
●​ df1.show()
●​ df1.printSchema()
●​
●​ df2_data = [("A", 100), ("B", 200)]
●​ df2_cols = ["ID", "Value2"]
●​ df2 = spark.createDataFrame(df2_data, df2_cols)
●​ df2.show()
●​ df2.printSchema()
●​
●​ # Case 1: Join on common column name using a string. Spark handles it.
●​ print("\nJoin on 'ID' string (Spark handles duplicate 'ID'):")
●​ df_joined_str = df1.join(df2, "ID") # or on=["ID"]
●​ df_joined_str.show()
●​ df_joined_str.printSchema() # Notice 'ID' appears only once
●​
●​ # Case 2: Join on common column name using a boolean expression. 'ID' is duplicated.
●​ print("\nJoin on 'ID' using boolean expression (duplicates 'ID'):")
●​ df_joined_bool = df1.join(df2, df1.ID == df2.ID)
●​ df_joined_bool.show()
●​ df_joined_bool.printSchema() # Notice 'ID' appears twice (df1.ID, df2.ID)
●​
●​ # To resolve duplicate columns after boolean join:
●​ # Option A: Drop one of the duplicate columns
●​ print("\nAfter dropping one of the duplicate 'ID' columns:")
●​ df_joined_bool.drop(df2.ID).show()
●​
●​ # Option B: Select specific columns and alias
●​ print("\nSelecting and aliasing to resolve duplicates:")
●​ df_resolved = df1.join(df2, df1.ID == df2.ID) \
●​ .select(df1.ID, df1.Value1, df2.Value2)
●​ df_resolved.show()
●​ df_resolved.printSchema()
●​
●​ # Case 3: Alias column before joining
●​ print("\nAlias 'ID' in df2 before joining:")
●​ df2_renamed = df2.withColumnRenamed("ID", "ID_df2")
●​ df_aliased_join = df1.join(df2_renamed, df1.ID == df2_renamed.ID_df2)
●​ df_aliased_join.show()
●​ df_aliased_join.printSchema() # Now ID and ID_df2 are distinct
●​
●​ spark.stop()

Broadcast joins (and when it triggers)


A broadcast join (also known as a map-side join) is an optimization technique in Spark for
joins involving a small DataFrame and a large DataFrame.

How it works:

Spark "broadcasts" the smaller DataFrame to all executor nodes. Each executor then holds
a copy of the smaller DataFrame in memory. When the join operation occurs, each partition
of the larger DataFrame can directly join with the broadcasted smaller DataFrame without
requiring a shuffle of the large DataFrame.
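
The mechanism can be sketched in plain Python as a hash join in which every partition of the large side probes its own copy of a lookup table built from the small side. This is a conceptual model only; broadcast_hash_join is a hypothetical name, and Spark's physical operator is considerably more involved.

```python
from typing import Dict, List, Tuple

Row = Tuple[int, str]


def broadcast_hash_join(
    large_partitions: List[List[Row]], small_rows: List[Row]
) -> List[List[Tuple[int, str, str]]]:
    """Inner-join each large-side partition against a shared lookup table."""
    # Build step: done once, then shipped ("broadcast") to every executor.
    lookup: Dict[int, List[str]] = {}
    for key, value in small_rows:
        lookup.setdefault(key, []).append(value)

    joined_partitions = []
    for partition in large_partitions:
        # Probe step: purely local, no rows move between partitions.
        joined = [
            (key, payload, match)
            for key, payload in partition
            for match in lookup.get(key, [])
        ]
        joined_partitions.append(joined)
    return joined_partitions


small = [(1, "Apple"), (2, "Banana"), (3, "Orange")]
large = [
    [(1, "User_0"), (2, "User_1")],  # partition 0
    [(3, "User_2"), (4, "User_3")],  # partition 1 (key 4 has no match)
]

result = broadcast_hash_join(large, small)
print(result)
```

The output keeps the partitioning of the large side, which is why no shuffle is needed; the cost is one full copy of the small side per executor.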

Benefits:
●​ Significant performance improvement: Eliminates the expensive shuffle phase for
the larger DataFrame, which is typically the bottleneck in joins.
●​ Reduced network I/O: Less data needs to be transferred across the network.

When it triggers:

Spark's Catalyst Optimizer automatically decides whether to perform a broadcast join based
on the size of the DataFrames involved.

●​ spark.sql.autoBroadcastJoinThreshold: This configuration property determines
the maximum size (in bytes) of a DataFrame that will be broadcast. The default value
is 10MB (10 * 1024 * 1024 bytes); setting it to -1 disables automatic broadcasting.
○​ If a DataFrame's estimated size is less than or equal to this threshold, Spark
might broadcast it.
●​ Manually forcing broadcast: You can explicitly tell Spark to broadcast a DataFrame
using pyspark.sql.functions.broadcast(). This is useful if Spark's automatic threshold
isn't ideal for your specific use case, or if you know a DataFrame is small.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import broadcast
●​
●​ spark = SparkSession.builder.appName("BroadcastJoin").getOrCreate()
●​
●​ # Configure the broadcast join threshold (e.g., to 1MB for demonstration)
●​ spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 1 * 1024 * 1024) # 1MB
●​
●​ # Create a small DataFrame (will be broadcasted)
●​ small_df_data = [(1, "Apple"), (2, "Banana"), (3, "Orange")]
●​ small_df_cols = ["ID", "Fruit"]
●​ small_df = spark.createDataFrame(small_df_data, small_df_cols)
●​
●​ # Create a large DataFrame (simulated by generating more rows)
●​ large_df_data = [(i % 3 + 1, f"User_{i}") for i in range(100000)] # 100,000 rows
●​ large_df_cols = ["FruitID", "UserName"]
●​ large_df = spark.createDataFrame(large_df_data, large_df_cols)
●​
●​ # Perform the join - Spark will likely broadcast small_df automatically
●​ print("Automatic Broadcast Join (check Spark UI for details):")
●​ joined_df_auto = large_df.join(small_df, large_df.FruitID == small_df.ID, "inner")
●​ joined_df_auto.show(5)
●​ # Look for "BroadcastHashJoin" or "BroadcastNestedLoopJoin" in the plan
●​ joined_df_auto.explain()
●​
●​ # Manually force broadcast (even if it's larger than threshold, use with caution)
●​ print("\nManual Broadcast Join (forced):")
●​ joined_df_forced = large_df.join(broadcast(small_df), large_df.FruitID == small_df.ID,
"inner")
●​ joined_df_forced.show(5)
●​ joined_df_forced.explain() # Should explicitly show BroadcastHashJoin
●​
●​ # Reset the threshold to default or stop SparkSession
●​ spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024) # Reset to the 10MB default
●​ spark.stop()

Key Takeaway: Prefer broadcast joins when one side of the join is small enough to fit
comfortably in memory on the driver and on every executor; being smaller than the other
side is not sufficient on its own. Monitor Spark UI to confirm if broadcast joins are
happening as expected.

Aggregations & GroupBy


Aggregations allow you to summarize data, and groupBy() is used to perform these
aggregations on groups of rows.

groupBy().agg(), count(), sum(), avg(), min(), max()


●​ groupBy(*cols): Groups the DataFrame by one or more columns. This returns a
GroupedData object, on which you can apply aggregation functions.
●​ agg(*exprs): Applies one or more aggregation functions to the grouped data. You
can pass aggregate functions (like count(), sum(), avg(), min(), max()) directly or
define custom aggregations.
●​ count(): Returns the number of items in a group.
●​ sum(col): Returns the sum of values in a numeric column.
●​ avg(col): Returns the average of values in a numeric column.
●​ min(col): Returns the minimum value in a column.
●​ max(col): Returns the maximum value in a column.

These aggregate functions are available in pyspark.sql.functions.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, count, sum, avg, min, max
●​
●​ spark = SparkSession.builder.appName("Aggregations").getOrCreate()
●​
●​ data = [("A", "Sales", 1000),
●​ ("B", "Sales", 1500),
●​ ("A", "HR", 800),
●​ ("C", "IT", 2000),
●​ ("B", "HR", 1200),
●​ ("C", "Sales", 1800)]
●​ columns = ["Employee", "Department", "Salary"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Count of employees per department
●​ print("\nCount of employees per department:")
●​ df.groupBy("Department").count().show()
●​
●​ # Sum of salary per department
●​ print("\nSum of salary per department:")
●​ df.groupBy("Department").sum("Salary").show()
●​
●​ # Average salary per department
●​ print("\nAverage salary per department:")
●​ df.groupBy("Department").avg("Salary").show()
●​
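●​ # Min and Max salary per department
●​ # (the shortcut methods cannot be chained: min() already returns a DataFrame, so use agg())
●​ print("\nMin and Max salary per department:")
●​ df.groupBy("Department").agg(min("Salary"), max("Salary")).show()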
●​
●​ # Using agg() for single aggregation (equivalent to direct agg function)
●​ print("\nUsing agg() for sum of salary:")
●​ df.groupBy("Department").agg(sum("Salary")).show()
●​
●​ spark.stop()

Aggregation over multiple columns


You can group by multiple columns and perform aggregations.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, count, sum, avg
●​
●​ spark = SparkSession.builder.appName("MultiColumnAggregation").getOrCreate()
●​
●​ data = [("A", "Sales", "NY", 1000),
●​ ("B", "Sales", "LD", 1500),
●​ ("A", "HR", "NY", 800),
●​ ("C", "IT", "SF", 2000),
●​ ("B", "HR", "LD", 1200),
●​ ("C", "Sales", "NY", 1800),
●​ ("A", "Sales", "NY", 900)]
●​ columns = ["Employee", "Department", "City", "Salary"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Count of employees per Department and City
●​ print("\nCount of employees per Department and City:")
●​ df.groupBy("Department", "City").count().show()
●​
●​ # Sum and average of salary per Department and City
●​ print("\nSum and average of salary per Department and City:")
●​ df.groupBy("Department", "City") \
●​ .agg(sum("Salary").alias("TotalSalary"),
●​ avg("Salary").alias("AverageSalary")) \
●​ .show()
●​
●​ spark.stop()

Using agg() with multiple aggregations


The agg() function is powerful because it allows you to apply multiple aggregation functions
in a single call, and also to rename the output columns.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, count, sum, avg, min, max
●​
●​ spark = SparkSession.builder.appName("MultipleAggregations").getOrCreate()
●​
●​ data = [("A", "Sales", 1000),
●​ ("B", "Sales", 1500),
●​ ("A", "HR", 800),
●​ ("C", "IT", 2000),
●​ ("B", "HR", 1200),
●​ ("C", "Sales", 1800)]
●​ columns = ["Employee", "Department", "Salary"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Perform multiple aggregations on 'Salary' grouped by 'Department'
●​ print("\nMultiple aggregations on 'Salary' per 'Department':")
●​ df.groupBy("Department") \
●​ .agg(count("Employee").alias("EmployeeCount"),
●​ sum("Salary").alias("TotalSalary"),
●​ avg("Salary").alias("AverageSalary"),
●​ min("Salary").alias("MinSalary"),
●​ max("Salary").alias("MaxSalary")) \
●​ .show()
●​
●​ # You can also aggregate without a groupBy (applies to the whole DataFrame)
●​ print("\nAggregations on entire DataFrame:")
●​ df.agg(count("Employee").alias("TotalEmployees"),
●​ sum("Salary").alias("GrandTotalSalary"),
●​ avg("Salary").alias("OverallAverageSalary")).show()
●​
●​ spark.stop()

Pivot operations
Pivot operations (or cross-tabulations) turn the distinct values of one column into
separate columns. They are used to aggregate data by one column and spread the unique
values of another column into new columns, with aggregated values.

Steps for Pivoting:


1.​ groupBy(): Group the DataFrame by the column(s) that will remain as rows.
2.​ pivot(pivot_column): Specify the column whose unique values will become new
columns.
3.​ agg(agg_function): Apply an aggregation function to the values that will populate the
new pivoted columns.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import sum, avg
●​
●​ spark = SparkSession.builder.appName("PivotOperations").getOrCreate()
●​
●​ data = [("RegionA", "ProductX", 100),
●​ ("RegionA", "ProductY", 150),
●​ ("RegionB", "ProductX", 200),
●​ ("RegionB", "ProductY", 120),
●​ ("RegionA", "ProductX", 50),
●​ ("RegionC", "ProductY", 300)]
●​ columns = ["Region", "Product", "Sales"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Pivot 'Product' column to new columns, aggregating 'Sales' by 'Region'
●​ print("\nPivot on 'Product' to show sales per region per product:")
●​ df.groupBy("Region").pivot("Product").agg(sum("Sales")).show()
●​
●​ # You can also specify the exact pivot values if you know them,
●​ # which can improve performance for large number of unique pivot column values.
●​ print("\nPivot with specific product values:")
●​ df.groupBy("Region").pivot("Product", ["ProductX", "ProductY",
"ProductZ"]).agg(sum("Sales")).show()
●​
●​ # Pivoting with another aggregation (e.g., average)
●​ print("\nPivot with average sales:")
●​ df.groupBy("Region").pivot("Product").agg(avg("Sales")).show()
●​
●​ # pivot() also accepts several aggregations in a single agg() call. Spark then creates
●​ # one column per pivot value and aggregation, prefixed with the pivot value.
●​ print("\nPivot with multiple aggregations:")
●​ df.groupBy("Region").pivot("Product") \
●​ .agg(sum("Sales").alias("total"), avg("Sales").alias("average")).show()
●​
●​ spark.stop()

Window Functions
Window functions perform calculations across a set of table rows that are somehow related
to the current row. Unlike aggregate functions that return a single value for a group, window
functions return a value for each row.

Window specification: partitionBy, orderBy


To use a window function, you first need to define a window specification using
Window.partitionBy() and Window.orderBy().
●​ partitionBy(*cols): Divides the rows into groups (partitions) based on the specified
columns. The window function is applied independently to each partition. If omitted,
the entire DataFrame is treated as a single partition.
●​ orderBy(*cols): Orders the rows within each partition. This ordering is crucial for
functions like row_number(), rank(), lag(), lead(), which depend on the order of rows.
●​ rowsBetween(start, end) / rangeBetween(start, end): Defines the frame of rows
within a partition that the window function operates on.
○​ Window.unboundedPreceding: From the beginning of the partition.
○​ Window.currentRow: The current row.
○​ Window.unboundedFollowing: To the end of the partition.
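
To make the frame idea concrete, here is a plain-Python sketch of a running sum over the frame from Window.unboundedPreceding to Window.currentRow. It only models the semantics; the function name running_sum and the sample rows are assumptions for illustration.

```python
from itertools import groupby

# (Department, Year, Sales)
rows = [
    ("Sales", 2023, 100),
    ("Sales", 2023, 150),
    ("Sales", 2024, 120),
    ("Sales", 2024, 160),
    ("HR", 2023, 80),
    ("HR", 2024, 90),
]


def running_sum(rows):
    """Model a sum over partitionBy(department).orderBy(year) with a growing frame."""
    # Sorting by (department, year) makes each partition contiguous and ordered.
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    result = []
    for department, partition in groupby(ordered, key=lambda r: r[0]):
        total = 0  # the frame restarts at the beginning of every partition
        for _, year, sales in partition:
            total += sales  # frame = first row of the partition .. current row
            result.append((department, year, sales, total))
    return result


running = running_sum(rows)
for row in running:
    print(row)
```

Python's sort is stable, so rows with the same year keep their input order here. Spark gives no such guarantee for ties; add a tie-breaking column to orderBy() when the running total must be reproducible.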

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, row_number, rank, dense_rank, lag, lead, sum, avg
●​ from pyspark.sql.window import Window
●​
●​ spark = SparkSession.builder.appName("WindowFunctions").getOrCreate()
●​
●​ data = [("Sales", "Alice", 2023, 100),
●​ ("Sales", "Bob", 2023, 150),
●​ ("Sales", "Alice", 2024, 120),
●​ ("HR", "Charlie", 2023, 80),
●​ ("HR", "David", 2024, 90),
●​ ("Sales", "Bob", 2024, 160)]
●​ columns = ["Department", "Employee", "Year", "Sales"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Define a window specification: Partition by Department, order by Year and Sales
●​ window_spec = Window.partitionBy("Department").orderBy("Year", "Sales")
●​
●​ # Row Number: Assigns a unique, sequential number to each row within its partition,
●​ # based on the order defined in the window.
●​ print("\nRow Number by Department and Year/Sales:")
●​ df.withColumn("row_num", row_number().over(window_spec)).show()
●​
●​ # Define a window spec for sum of sales per department over all years up to current
●​ window_sum = Window.partitionBy("Department").orderBy("Year") \
●​ .rowsBetween(Window.unboundedPreceding, Window.currentRow)
●​ print("\nRunning Sum of Sales per Department:")
●​ df.withColumn("Running_Sum_Sales", sum("Sales").over(window_sum)).show()
●​
●​ spark.stop()

Functions: row_number, rank, dense_rank, lag, lead


These are common window functions:
●​ row_number(): Assigns a unique, sequential integer to each row within its partition,
starting from 1. Rows that tie on the orderBy columns still receive different numbers, in
no guaranteed order.
●​ rank(): Assigns a rank to each row within its partition. If there are ties (same values
for orderBy columns), they get the same rank, and a gap is left in the ranking
sequence.
●​ dense_rank(): Similar to rank(), but no gaps are left in the ranking sequence when
there are ties.
●​ lag(column, offset, default_value): Returns the value of column from a row that is
offset rows before the current row within the partition. default_value is used if the
offset goes beyond the partition start.
●​ lead(column, offset, default_value): Returns the value of column from a row that is
offset rows after the current row within the partition. default_value is used if the offset
goes beyond the partition end.
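
The difference between the three ranking functions is easiest to see on tied values. The following plain-Python sketch models their numbering for one partition ordered by score descending; rank_rows is a hypothetical helper, not a Spark API.

```python
def rank_rows(scores):
    """Return (score, row_number, rank, dense_rank) tuples, ordered by score descending."""
    ordered = sorted(scores, reverse=True)
    result = []
    rank = 0
    dense = 0
    previous = object()  # sentinel that equals no real score
    for position, score in enumerate(ordered, start=1):
        if score != previous:
            rank = position  # rank jumps to the row position, which leaves gaps after ties
            dense += 1       # dense_rank only ever advances by one
            previous = score
        # row_number is simply the position, so tied rows still differ
        result.append((score, position, rank, dense))
    return result


dept_a = rank_rows([100, 120, 120, 150])
dept_b = rank_rows([80, 80, 90])
for row in dept_a:
    print(row)
```

For the two rows scoring 120, rank and dense_rank both give 2, but the next row gets rank 4 and dense_rank 3.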

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, row_number, rank, dense_rank, lag, lead
●​ from pyspark.sql.window import Window
●​
●​ spark = SparkSession.builder.appName("WindowRankingFunctions").getOrCreate()
●​
●​ data = [("DeptA", "Alice", 100),
●​ ("DeptA", "Bob", 120),
●​ ("DeptA", "Charlie", 120), # Tie for Bob and Charlie
●​ ("DeptA", "David", 150),
●​ ("DeptB", "Eve", 80),
●​ ("DeptB", "Frank", 80), # Tie for Eve and Frank
●​ ("DeptB", "Grace", 90)]
●​ columns = ["Department", "Employee", "Score"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Define a window specification: Partition by Department, order by Score (descending)
●​ window_spec = Window.partitionBy("Department").orderBy(col("Score").desc())
●​
●​ print("\nRanking functions with ties:")
●​ df.withColumn("row_num", row_number().over(window_spec)) \
●​ .withColumn("rank", rank().over(window_spec)) \
●​ .withColumn("dense_rank", dense_rank().over(window_spec)) \
●​ .show()
●​
●​ # Lag and Lead Example:
●​ # Find previous and next year's sales for each employee
●​ sales_data = [("Alice", 2022, 1000),
●​ ("Alice", 2023, 1200),
●​ ("Alice", 2024, 1100),
●​ ("Bob", 2022, 1500),
●​ ("Bob", 2023, 1600)]
●​ sales_cols = ["Employee", "Year", "Sales"]
●​ sales_df = spark.createDataFrame(sales_data, sales_cols)
●​ sales_df.show()
●​
●​ # Window for lag/lead: partition by Employee, order by Year
●​ window_lag_lead = Window.partitionBy("Employee").orderBy("Year")
●​
●​ print("\nLag and Lead for Sales:")
●​ sales_df.withColumn("Prev_Year_Sales", lag("Sales", 1).over(window_lag_lead)) \
●​ .withColumn("Next_Year_Sales", lead("Sales", 1).over(window_lag_lead)) \
●​ .show()
●​
●​ spark.stop()

Use cases: De-duplication, ranking, time series logic


Window functions are incredibly versatile for various data engineering tasks:
●​ De-duplication:
○​ Find duplicate rows based on a subset of columns.
○​ Use row_number() to number the rows within each duplicate group. (rank() gives tied
rows the same number, so it cannot guarantee a single survivor.)
○​ Filter to keep only the first (or last) occurrence.
○​ Example: Remove duplicate customer records, keeping the most recent one.
●​ Python
●​ from pyspark.sql import SparkSession
●​ from pyspark.sql.functions import col, row_number
●​ from pyspark.sql.window import Window
●​
●​ spark = SparkSession.builder.appName("Deduplication").getOrCreate()
●​
●​ duplicate_data = [("Alice", "alice@email.com", "NY", 1),
●​ ("Bob", "bob@email.com", "LD", 2),
●​ ("Alice", "alice@email.com", "LA", 3), # Duplicate Alice, different city
●​ ("Charlie", "charlie@email.com", "SF", 4),
●​ ("Bob", "bob@email.com", "LD", 5)] # Exact duplicate Bob
●​ dup_cols = ["Name", "Email", "City", "RecordID"]
●​ dup_df = spark.createDataFrame(duplicate_data, dup_cols)
●​ print("Original DataFrame with duplicates:")
●​ dup_df.show()
●​
●​ # Deduplicate keeping the record with the higher RecordID (more recent)
●​ window_spec_dedup = Window.partitionBy("Name",
"Email").orderBy(col("RecordID").desc())
●​ deduplicated_df = dup_df.withColumn("row_num",
row_number().over(window_spec_dedup)) \
●​ .filter(col("row_num") == 1) \
●​ .drop("row_num")
●​ print("\nDeduplicated DataFrame (keeping latest record):")
●​ deduplicated_df.show()
●​ spark.stop()
●​
●​
●​ Ranking:
○​ Assign ranks to items within categories (e.g., top 3 students per class, top 10
products per region).
○​ Use rank(), dense_rank(), row_number().
○​ Example: Find the top 2 sales employees per department. (See previous
examples for rank, dense_rank).
●​ Time Series Logic:
○​ Calculate moving averages, cumulative sums.
○​ Compare current values with previous or subsequent values (lag, lead).
○​ Example: Calculate month-over-month growth, identify trends.
●​ Python
●​ from pyspark.sql import SparkSession
●​ from pyspark.sql.functions import col, lag, sum
●​ from pyspark.sql.window import Window
●​
●​ spark = SparkSession.builder.appName("TimeSeries").getOrCreate()
●​
●​ time_series_data = [("ProductA", "2023-01-01", 100),
●​ ("ProductA", "2023-02-01", 120),
●​ ("ProductA", "2023-03-01", 110),
●​ ("ProductB", "2023-01-01", 200),
●​ ("ProductB", "2023-02-01", 230)]
●​ ts_cols = ["Product", "Date", "Sales"]
●​ ts_df = spark.createDataFrame(time_series_data, ts_cols)
●​ print("Original Time Series DataFrame:")
●​ ts_df.show()
●​
●​ window_ts = Window.partitionBy("Product").orderBy("Date")
●​
●​ # Calculate previous month's sales and month-over-month growth
●​ ts_df.withColumn("PreviousMonthSales", lag("Sales", 1).over(window_ts)) \
●​ .withColumn("MoM_Growth", (col("Sales") - col("PreviousMonthSales")) /
col("PreviousMonthSales")) \
●​ .show()
●​
●​ # Calculate cumulative sum of sales for each product
●​ ts_df.withColumn("CumulativeSales",
sum("Sales").over(window_ts.rowsBetween(Window.unboundedPreceding,
Window.currentRow))) \
●​ .show()
●​
●​ spark.stop()
●​
●​

Window functions are incredibly powerful and often lead to more efficient and readable code
compared to self-joins or complex aggregations for similar tasks.
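
For intuition, the lag/lead logic behind the month-over-month calculation can be modelled on a single, already ordered partition in plain Python. The helpers lag and lead below are illustrative stand-ins that assume a non-negative offset; they are not the Spark functions themselves.

```python
def lag(values, offset=1, default=None):
    """Value from `offset` rows earlier, or `default` before the partition start."""
    return [values[i - offset] if i >= offset else default
            for i in range(len(values))]


def lead(values, offset=1, default=None):
    """Value from `offset` rows later, or `default` past the partition end."""
    return [values[i + offset] if i + offset < len(values) else default
            for i in range(len(values))]


monthly_sales = [100, 120, 110]  # one product, ordered by month
previous = lag(monthly_sales)
following = lead(monthly_sales)

# Growth is undefined for the first month because there is no previous value.
growth = [None if prev is None else (cur - prev) / prev
          for cur, prev in zip(monthly_sales, previous)]
print(previous, following, growth)
```

The first row has no predecessor, so its growth stays empty, which matches the null Spark produces when lag() has no default.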

Handling Nulls / Missing Data


Missing data (represented as null in Spark) is a common problem in real-world datasets.
Spark DataFrames provide convenient methods to handle them.

na.fill(), na.drop(), na.replace()


These methods are part of the DataFrameNaFunctions (accessed via df.na).
●​ na.fill(value, subset=None): Replaces null values with a specified value.
○​ value: The value to replace nulls with. Can be a single value (applied to all
numeric/string columns) or a dictionary mapping column names to values.
○​ subset: A list of column names to apply the fill operation to. If None, applies to
all columns of compatible type.
●​ na.drop(how='any', thresh=None, subset=None): Drops rows containing null
values.
○​ how:
■​ 'any' (default): Drop a row if it contains any null values.
■​ 'all': Drop a row if all its values are null.
○​ thresh: An integer. Drop a row if it has fewer than thresh non-null values.
○​ subset: A list of column names to consider for dropping. If None, considers all
columns.
●​ na.replace(to_replace, value, subset=None): Replaces a specific value (not just nulls)
with another value.
○​ to_replace: The value to be replaced. It can also be a list of values, or a dictionary
mapping old values to new values (in which case value is omitted).
○​ value: The replacement value.
○​ subset: An optional column name or list of column names to apply the replacement to.
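
How how, thresh and subset interact in na.drop() can be sketched in plain Python. This is a simplified model under the assumption that rows are dictionaries; drop_nulls is a made-up name, not part of PySpark.

```python
people = [
    {"Name": "Alice", "Age": None, "City": "New York"},
    {"Name": "Bob", "Age": 25, "City": "London"},
    {"Name": "Charlie", "Age": 35, "City": None},
    {"Name": "David", "Age": None, "City": None},
    {"Name": None, "Age": 40, "City": "Paris"},
    {"Name": "Eve", "Age": 28, "City": "Berlin"},
]


def drop_nulls(rows, how="any", thresh=None, subset=None):
    """Keep rows according to the na.drop() rules described above."""
    kept = []
    for row in rows:
        columns = subset if subset is not None else list(row)
        non_null = sum(1 for c in columns if row[c] is not None)
        if thresh is not None:
            keep = non_null >= thresh  # thresh takes precedence over how
        elif how == "any":
            keep = non_null == len(columns)  # drop if any value is null
        elif how == "all":
            keep = non_null > 0  # drop only if every value is null
        else:
            raise ValueError("how must be 'any' or 'all'")
        if keep:
            kept.append(row)
    return kept


complete = drop_nulls(people)
age_and_city = drop_nulls(people, subset=["Age", "City"])
at_least_two = drop_nulls(people, thresh=2)
print(len(complete), len(age_and_city), len(at_least_two))
```

With thresh=2 only David is dropped, because his row has a single non-null value.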

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col
●​
●​ spark = SparkSession.builder.appName("HandlingNulls").getOrCreate()
●​
●​ data = [("Alice", None, "New York"),
●​ ("Bob", 25, "London"),
●​ ("Charlie", 35, None),
●​ ("David", None, None),
●​ (None, 40, "Paris"),
●​ ("Eve", 28, "Berlin")]
●​ columns = ["Name", "Age", "City"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # 1. Fill nulls
●​ print("\nAfter filling nulls:")
●​ # Fill nulls in 'Age' with 0 and nulls in 'City' with "Unknown" (null Names are left as-is)
●​ df.na.fill(0, subset=["Age"]) \
●​ .na.fill("Unknown", subset=["City"]) \
●​ .show()
●​
●​ # You can also use a dictionary for multiple types/columns in one go
●​ fill_values = {"Age": 99, "City": "N/A"}
●​ df.na.fill(fill_values).show()
●​
●​ # 2. Drop nulls
●​ print("\nAfter dropping nulls:")
●​ # Drop rows with any null value
●​ df.na.drop(how='any').show()
●​
●​ # Drop rows if 'Age' or 'City' is null
●​ df.na.drop(subset=["Age", "City"]).show()
●​
●​ # Drop rows if less than 2 non-null values (threshold)
●​ df.na.drop(thresh=2).show()
●​
●​ # 3. Replace specific values (e.g., 'Red' with 'Crimson', or 25 with 99)
●​ print("\nAfter replacing specific values:")
●​ df_replace_example = spark.createDataFrame([("Apple", "Red"), ("Banana", "Yellow"),
("Grape", "Red")], ["Fruit", "Color"])
●​ df_replace_example.show()
●​ df_replace_example.na.replace("Red", "Crimson", "Color").show()
●​
●​ # Replace 25 with 99 for 'Age' column in original df
●​ df.na.replace(25, 99, "Age").show()
●​
●​ spark.stop()
Fill based on specific columns
As shown in the na.fill() example above, you can specify the subset parameter to target
specific columns for null filling. This is crucial when different columns require different default
values or strategies.
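
A per-column fill strategy can be sketched in plain Python as follows. The helper names mean_of_non_nulls and fill_nulls are assumptions for illustration; the point is that each column gets its own default and that the mean is computed from the non-null values only.

```python
rows = [
    {"Name": "Alice", "Age": None, "City": "New York", "Score": 100.0},
    {"Name": "Bob", "Age": 25, "City": "London", "Score": None},
    {"Name": "Charlie", "Age": 35, "City": None, "Score": 120.5},
    {"Name": "David", "Age": None, "City": None, "Score": None},
]


def mean_of_non_nulls(rows, column):
    """Average a column while skipping missing values."""
    present = [row[column] for row in rows if row[column] is not None]
    return sum(present) / len(present) if present else None


def fill_nulls(rows, defaults):
    """Replace None with the column's default; columns without a default are untouched."""
    return [
        {col: (defaults[col] if value is None and col in defaults else value)
         for col, value in row.items()}
        for row in rows
    ]


avg_score = mean_of_non_nulls(rows, "Score")  # (100.0 + 120.5) / 2
filled = fill_nulls(rows, {"Age": 0, "City": "Unknown", "Score": avg_score})
for row in filled:
    print(row)
```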

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, avg
●​
●​ spark = SparkSession.builder.appName("FillSpecificColumns").getOrCreate()
●​
●​ data = [("Alice", None, "New York", 100.0),
●​ ("Bob", 25, "London", None),
●​ ("Charlie", 35, None, 120.5),
●​ ("David", None, None, None)]
●​ columns = ["Name", "Age", "City", "Score"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ print("Filling nulls in 'Age' with 0 and 'City' with 'Unknown' and 'Score' with average:")
●​
●​ # It's common to fill numeric nulls with the mean, median, or mode.
●​ # Here the mean of the non-null scores is computed first (avg() ignores nulls).
●​ avg_score = df.select(avg("Score")).collect()[0][0]
●​
●​ df_filled = df.na.fill(0, subset=["Age"]) \
●​ .na.fill("Unknown", subset=["City"]) \
●​ .na.fill(avg_score, subset=["Score"]) # Fill score with calculated average
●​
●​ df_filled.show()
●​
●​ spark.stop()

Conditional handling of nulls with when().otherwise()


For more complex null handling logic, when().otherwise() is highly effective. This allows you
to fill or modify values based on conditions involving other columns or more intricate rules.
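
One subtlety of when().otherwise() is that a condition involving a null evaluates to null rather than true, so that branch is skipped. The plain-Python sketch below models this three-valued logic for a salary-filling rule keyed on age; all function names and the fill amounts are illustrative assumptions, with None standing in for null.

```python
def null_and(a, b):
    """SQL-style AND over True / False / None (None means unknown)."""
    if a is False or b is False:
        return False
    if a is None or b is None:
        return None
    return True


def less_than(value, bound):
    return None if value is None else value < bound


def at_least(value, bound):
    return None if value is None else value >= bound


def fill_salary(age, salary):
    """Model when(cond1, 50000).when(cond2, 70000).otherwise(salary)."""
    branches = [
        (null_and(salary is None, less_than(age, 30)), 50000),
        (null_and(salary is None, at_least(age, 30)), 70000),
    ]
    for condition, result in branches:
        if condition is True:  # an unknown (None) condition does not match
            return result
    return salary  # the otherwise() branch


print(fill_salary(22, None), fill_salary(40, None), fill_salary(None, None))
```

A row with both age and salary missing falls through to otherwise() and keeps its null salary, so such rows need an extra branch if they must be filled too.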

Example (Python):

Python
●​ from pyspark.sql import SparkSession
●​ from pyspark.sql.functions import col, when, lit
●​
●​ spark = SparkSession.builder.appName("ConditionalNullHandling").getOrCreate()
●​
●​ data = [("Alice", 30, "M", None),
●​ ("Bob", 25, "M", 60000),
●​ ("Charlie", None, "F", 80000), # Age is null
●​ ("David", 40, "M", None), # Salary is null
●​ ("Eve", 22, "F", 55000),
●​ ("Frank", None, None, None)] # Age, Gender, Salary are null
●​ columns = ["Name", "Age", "Gender", "Salary"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Conditional filling for 'Salary':
●​ # If Salary is null AND Age is less than 30, fill with 50000.
●​ # If Salary is null AND Age is 30 or more, fill with 70000.
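●​ # Otherwise, keep original Salary. A row whose Age is also null matches neither
●​ # branch (a comparison with null is not true), so its Salary stays null.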
●​ df_salary_filled = df.withColumn("Salary_Filled",
●​ when(col("Salary").isNull() & (col("Age") < 30), lit(50000))
●​ .when(col("Salary").isNull() & (col("Age") >= 30), lit(70000))
●​ .otherwise(col("Salary")))
●​ print("\nConditional filling for Salary:")
●​ df_salary_filled.show()
●​
●​
●​ # Conditional filling for 'Age':
●​ # If Age is null, check Gender. If Gender is 'M', fill with 30. If 'F', fill with 28.
●​ # If Gender is also null, fill with a default (e.g., 0)
●​ df_age_gender_filled = df.withColumn("Age_Filled",
●​ when(col("Age").isNull(),
●​ when(col("Gender") == "M", lit(30))
●​ .when(col("Gender") == "F", lit(28))
●​ .otherwise(lit(0))) # Default if Gender is also null
●​ .otherwise(col("Age"))) \
●​ .withColumn("Gender_Filled",
●​ when(col("Gender").isNull(), lit("Unknown"))
●​ .otherwise(col("Gender")))
●​ print("\nConditional filling for Age and Gender:")
●​ df_age_gender_filled.show()
●​
●​
●​ spark.stop()

Sorting, Filtering, Distinct, Drop Duplicates
These are common transformations for cleaning and organizing data.

sort, orderBy, distinct, dropDuplicates


●​ sort(*cols, ascending=True) / orderBy(*cols, ascending=True): Sorts the DataFrame by
one or more columns.
○​ sort() and orderBy() are aliases and perform the same operation.
○​ By default, sorting is ascending. Use col("column_name").desc() for
descending.
○​ This is a wide transformation as it typically requires shuffling data across
partitions.
●​ distinct(): Returns a new DataFrame containing only the distinct rows from the
current DataFrame. This is also a wide transformation as it needs to compare all
rows globally.
●​ dropDuplicates(subset=None): Returns a new DataFrame with duplicate rows
removed.
○​ If subset is None, it considers all columns for duplication.
○​ If subset is a list of column names, it considers only those columns for
identifying duplicates and keeps one row per distinct combination. Which row is
kept is not guaranteed on a distributed DataFrame. Use this instead of distinct()
when you want to de-duplicate based on a subset of columns, since distinct()
always compares all columns.
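
The difference between distinct() and dropDuplicates(subset=...) can be modelled in a few lines of plain Python. This is a single-process sketch with hypothetical helper names; it keeps the first row it encounters, whereas Spark makes no promise about which duplicate survives.

```python
columns = ["Name", "Age", "City"]
rows = [
    ("Alice", 30, "NY"),
    ("Bob", 25, "LD"),
    ("Charlie", 35, "NY"),
    ("Alice", 30, "NY"),    # exact duplicate
    ("David", 22, "SF"),
    ("Charlie", 35, "LA"),  # same Name/Age, different City
]


def drop_duplicates(rows, columns, subset=None):
    """Keep one row per distinct value of the key columns."""
    key_positions = [columns.index(c) for c in (subset or columns)]
    seen = set()
    kept = []
    for row in rows:
        key = tuple(row[i] for i in key_positions)
        if key not in seen:
            seen.add(key)
            kept.append(row)
    return kept


def distinct(rows, columns):
    """distinct() is de-duplication keyed on every column."""
    return drop_duplicates(rows, columns, subset=None)


all_columns = distinct(rows, columns)
name_and_age = drop_duplicates(rows, columns, subset=["Name", "Age"])
print(len(all_columns), len(name_and_age))
```

Keyed on every column, both Charlie rows survive because their cities differ; keyed on Name and Age, only one of them does.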

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col
●​
●​ spark = SparkSession.builder.appName("SortFilterDistinct").getOrCreate()
●​
●​ data = [("Alice", 30, "NY"),
●​ ("Bob", 25, "LD"),
●​ ("Charlie", 35, "NY"),
●​ ("Alice", 30, "NY"), # Duplicate row
●​ ("David", 22, "SF"),
●​ ("Charlie", 35, "LA")] # Charlie is duplicated on Name/Age, but not Name/Age/City
●​ columns = ["Name", "Age", "City"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show()
●​
●​ # Sorting by a single column
●​ print("\nSorted by Age (ascending):")
●​ df.sort("Age").show()
●​
●​ # Sorting by multiple columns (Age ascending, then Name descending)
●​ print("\nSorted by Age (asc), then Name (desc):")
●​ df.orderBy("Age", col("Name").desc()).show()
●​
●​ # Filter (recap)
●​ print("\nFiltered for Age > 28:")
●​ df.filter(col("Age") > 28).show()
●​
●​ # Distinct
●​ print("\nDistinct rows (considering all columns):")
●​ df.distinct().show()
●​
●​ # Drop Duplicates (on all columns)
●​ print("\nDrop duplicates (on all columns):")
●​ df.dropDuplicates().show()
●​
●​ # Drop Duplicates on a subset of columns (e.g., Name and Age)
●​ print("\nDrop duplicates on 'Name' and 'Age' (keep first occurrence):")
●​ df.dropDuplicates(subset=["Name", "Age"]).show()
●​
●​ spark.stop()

Understanding wide vs narrow transformations


Spark transformations are categorized into two types based on how they affect data locality:
●​ Narrow Transformations:
○​ Each input partition contributes to at most one output partition.
○​ No data shuffling is required across partitions.
○​ Examples: filter(), map(), withColumn(), select(), unionAll().
○​ Performance: Generally faster because they don't involve network I/O or
data movement between nodes.
●​ Wide Transformations (Shuffles):
○​ Each input partition can contribute to multiple output partitions.
○​ Requires data to be shuffled across the network between different
nodes/executors. This is an expensive operation.
○​ Examples: groupBy(), orderBy(), sort(), distinct(), dropDuplicates(),
repartition(), join() (unless it's a broadcast join).
○​ Performance: Can be significantly slower due to network I/O,
serialization/deserialization, and disk I/O (if data doesn't fit in memory).
Shuffles can also lead to data skew, where some partitions become much
larger than others, creating bottlenecks.
Why it matters: As a data engineer, you should be aware of which transformations trigger
shuffles. Aim to minimize shuffles where possible and optimize them when they are
unavoidable (e.g., by ensuring proper partitioning, managing spark.sql.shuffle.partitions, or
using broadcast joins).
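
The distinction can be illustrated with lists standing in for partitions. In this plain-Python sketch (all names are assumptions, and zlib.crc32 merely stands in for Spark's partitioning hash), the filter touches each partition independently, while grouping first has to move rows so that equal keys share a partition.

```python
import zlib

# Three input partitions of (department, sales) rows.
partitions = [
    [("Sales", 100), ("HR", 80)],
    [("Sales", 150), ("HR", 90)],
    [("Sales", 120)],
]


def narrow_filter(partitions, predicate):
    """Narrow: every output partition is computed from exactly one input partition."""
    return [[row for row in part if predicate(row)] for part in partitions]


def shuffle_by_key(partitions, num_output):
    """Wide: rows are redistributed so that equal keys end up in the same partition."""
    shuffled = [[] for _ in range(num_output)]
    for part in partitions:
        for key, value in part:
            target = zlib.crc32(key.encode("utf-8")) % num_output  # stable hash
            shuffled[target].append((key, value))
    return shuffled


def sum_per_key(partitions):
    """After the shuffle, each partition can be aggregated on its own."""
    totals = {}
    for part in partitions:
        for key, value in part:
            totals[key] = totals.get(key, 0) + value
    return totals


filtered = narrow_filter(partitions, lambda row: row[1] >= 100)
shuffled = shuffle_by_key(partitions, num_output=2)
totals = sum_per_key(shuffled)
print(filtered, totals)
```

In a real cluster the loop inside shuffle_by_key corresponds to rows being serialized and sent between executors, which is where the cost of a wide transformation comes from.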

Performance: Why dropDuplicates() is preferred over distinct()
distinct() and dropDuplicates() return the same result when de-duplicating on all columns,
and Spark plans both as an aggregation keyed on every column, so neither is faster in that
case. dropDuplicates() is generally preferred for its flexibility: it can de-duplicate on a
subset of columns, which distinct() cannot.
●​ distinct():
○​ Conceptually, it's equivalent to grouping by all columns and keeping one row per
group, or a global DISTINCT in SQL.
○​ It always operates on all columns of the DataFrame.
○​ It involves a shuffle so that identical rows land in the same partition, followed by
an aggregation that keeps one copy of each.
●​ dropDuplicates(subset=None):
○​ When subset is specified, only the subset columns form the grouping key, so rows
are hashed and compared on fewer columns.
○​ It's designed specifically for de-duplication based on selected columns.
○​ If no subset is provided, its behavior is the same as distinct() (operating on all
columns).

Practical Preference:

When you need to de-duplicate based on a specific set of identifying columns (e.g., (user_id,
transaction_id)), use dropDuplicates(subset=["user_id", "transaction_id"]): Spark only needs
to hash and compare these columns. distinct() cannot express this. It compares every column,
so two rows with the same key but a different value in any other column are both kept.

Therefore, for targeted de-duplication, dropDuplicates(subset=...) is the recommended
approach. Which row of each duplicate group survives is not guaranteed; use a window function
with row_number() when a specific row (e.g., the most recent) must be kept.

Explode, Arrays, Maps, Structs


Spark DataFrames excel at handling semi-structured and nested data types like arrays,
maps, and structs, which are common in JSON or Parquet data.

Exploding arrays into rows using explode()


The explode() function transforms an array column (or map column) into individual rows for
each element in the array (or key-value pair in the map). If an array column has N elements,
it will generate N rows for that array, duplicating other column values.
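
The row-multiplying behaviour, and the way null or empty arrays are handled, can be sketched in plain Python. The function below is a simplified model with an assumed outer flag that mimics the difference between explode() and explode_outer(); None stands in for null.

```python
rows = [
    ("Alice", ["reading", "hiking"]),
    ("Bob", None),     # null array
    ("Charlie", []),   # empty array
]


def explode(rows, outer=False):
    """Emit one (name, element) row per array element."""
    result = []
    for name, hobbies in rows:
        if hobbies:
            # N elements produce N rows, repeating the other column values.
            result.extend((name, hobby) for hobby in hobbies)
        elif outer:
            # The outer variant keeps the row and emits a single null element.
            result.append((name, None))
        # Otherwise the row disappears from the output.
    return result


exploded = explode(rows)
exploded_outer = explode(rows, outer=True)
print(exploded)
print(exploded_outer)
```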

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import explode, col
●​
●​ spark = SparkSession.builder.appName("ExplodeArrays").getOrCreate()
●​
●​ data = [("Alice", ["reading", "hiking", "cooking"]),
●​ ("Bob", ["coding", "gaming"]),
●​ ("Charlie", [])] # Empty array
●​ columns = ["Name", "Hobbies"]
●​ df = spark.createDataFrame(data, columns)
●​ df.show(truncate=False)
●​ df.printSchema()
●​
●​ # Explode the 'Hobbies' array column
●​ print("\nDataFrame after exploding 'Hobbies':")
●​ df.withColumn("Hobby", explode(col("Hobbies"))).show()
●​
●​ # What happens if a column is null or empty?
●​ data_with_null_array = [("Alice", ["reading", "hiking"]),
●​ ("Bob", None), # Null array
●​ ("Charlie", [])] # Empty array
●​ df_null_array = spark.createDataFrame(data_with_null_array, columns)
●​ df_null_array.show(truncate=False)
●​
●​ # When exploding a null or empty array, the row for that record is dropped by default.
●​ print("\nDataFrame after exploding with null/empty arrays:")
●​ df_null_array.withColumn("Hobby", explode(col("Hobbies"))).show()
●​
●​ # To keep rows with null/empty arrays, use `explode_outer` (Spark 2.3+)
●​ from pyspark.sql.functions import explode_outer
●​ print("\nDataFrame after exploding with explode_outer:")
●​ df_null_array.withColumn("Hobby", explode_outer(col("Hobbies"))).show()
●​
●​ spark.stop()

Creating and querying ArrayType, MapType, and StructType
You can create DataFrames with these complex types and then query them using dot
notation or specific functions.

Creating these types:


●​ array(): Creates an array column.
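●​ create_map(): Creates a map column from alternating key and value columns.
map_from_entries() and map_from_arrays() build a map from existing array columns.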
●​ struct(): Creates a struct column.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import array, map_from_entries, struct, col, lit
●​ from pyspark.sql.types import (StructType, StructField, StringType, IntegerType,
●​ ArrayType, MapType)
●​
●​ spark = SparkSession.builder.appName("ComplexTypes").getOrCreate()
●​
●​ # Create DataFrame with ArrayType, MapType, StructType
●​ data = [("Alice", [10, 20], {"city": "NY", "zip": "10001"}, ("Red", "Car")),
●​ ("Bob", [30], {"city": "LD"}, ("Blue", "Bike"))]
●​ columns = ["Name", "Scores", "AddressMap", "Vehicle"]
●​
●​ # Manual schema definition for clarity
●​ schema = StructType([
●​ StructField("Name", StringType(), True),
●​ StructField("Scores", ArrayType(IntegerType()), True),
●​ # MapType values are supplied as Python dicts
●​ StructField("AddressMap", MapType(StringType(), StringType()), True),
●​ StructField("Vehicle", StructType([
●​ StructField("Color", StringType(), True),
●​ StructField("Type", StringType(), True)
●​ ]), True)
●​ ])
●​
●​ df = spark.createDataFrame(data, schema=schema)
●​ df.show(truncate=False)
●​ df.printSchema()
●​
●​ # Another way to create complex types using functions
●​ df_created = spark.createDataFrame([
●​ (1, "Alice"),
●​ (2, "Bob")
●​ ]).select(
●​ col("_1").alias("ID"),
●​ col("_2").alias("Name"),
●​ # ArrayType
●​ array(lit("Apple"), lit("Banana")).alias("Fruits"),
●​ # MapType
●​ map_from_entries(array(struct(lit("key1"), lit("value1")),
●​ struct(lit("key2"), lit("value2")))).alias("Properties"),
●​ # StructType
●​ struct(lit("Main St").alias("Street"), lit("Anytown").alias("City")).alias("Location")
●​ )
●​ df_created.show(truncate=False)
●​ df_created.printSchema()
●​
●​ spark.stop()

Flattening nested schemas


Flattening involves transforming nested structures (StructType, ArrayType of StructType) into
a flatter structure with top-level columns.
●​ For StructType: Access elements using dot notation and then select them as new
columns.
●​ For ArrayType of StructType: Use explode() first, then access elements.
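
The two flattening rules can be sketched on nested Python dictionaries. This is only an analogy for the DataFrame operations: the helper names and the parent_child naming convention for flattened columns are assumptions made for the example.

```python
def flatten(record, prefix=""):
    """Promote nested dict fields to top-level keys named parent_child."""
    flat = {}
    for key, value in record.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Struct-like value: recurse and prefix the child names.
            flat.update(flatten(value, prefix=f"{name}_"))
        else:
            flat[name] = value
    return flat


def explode_then_flatten(record, array_field):
    """Array of structs: emit one record per element first, then flatten each."""
    base = {k: v for k, v in record.items() if k != array_field}
    return [flatten({**base, array_field: element})
            for element in record[array_field]]


person = {"Name": "Alice", "Address": {"street": "Main St", "city": "NY"}}
product = {
    "Product": "ProductA",
    "Features": [
        {"feature_name": "Color", "value": "Red"},
        {"feature_name": "Size", "value": "M"},
    ],
}

flat_person = flatten(person)
flat_features = explode_then_flatten(product, "Features")
print(flat_person)
print(flat_features)
```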

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, explode
●​ from pyspark.sql.types import StructType, StructField, StringType, ArrayType
●​
●​ spark = SparkSession.builder.appName("FlatteningSchemas").getOrCreate()
●​
●​ # Data with nested struct
●​ data_struct = [("Alice", {"street": "Main St", "city": "NY"}),
●​ ("Bob", {"street": "Elm St", "city": "LD"})]
●​ schema_struct = StructType([
●​ StructField("Name", StringType(), True),
●​ StructField("Address", StructType([
●​ StructField("street", StringType(), True),
●​ StructField("city", StringType(), True)
●​ ]), True)
●​ ])
●​ df_struct = spark.createDataFrame(data_struct, schema_struct)
●​ print("Original DataFrame with StructType:")
●​ df_struct.show(truncate=False)
●​ df_struct.printSchema()
●​
●​ # Flattening StructType
●​ print("\nFlattening StructType:")
●​ df_flattened_struct = df_struct.select(
●​ col("Name"),
●​ col("Address.street").alias("Street"),
●​ col("Address.city").alias("City")
●​ )
●​ df_flattened_struct.show(truncate=False)
●​ df_flattened_struct.printSchema()
●​
●​ # Data with ArrayType of StructType
●​ data_array_struct = [("ProductA", [{"feature_name": "Color", "value": "Red"},
{"feature_name": "Size", "value": "M"}]),
●​ ("ProductB", [{"feature_name": "Weight", "value": "1kg"}])]
●​ schema_array_struct = StructType([
●​ StructField("Product", StringType(), True),
●​ StructField("Features", ArrayType(StructType([
●​ StructField("feature_name", StringType(), True),
●​ StructField("value", StringType(), True)
●​ ])), True)
●​ ])
●​ df_array_struct = spark.createDataFrame(data_array_struct, schema_array_struct)
●​ print("\nOriginal DataFrame with ArrayType of StructType:")
●​ df_array_struct.show(truncate=False)
●​ df_array_struct.printSchema()
●​
●​ # Flattening ArrayType of StructType (requires explode first)
●​ print("\nFlattening ArrayType of StructType:")
●​ df_flattened_array_struct = df_array_struct.withColumn("exploded_features",
explode(col("Features"))) \
●​ .select(
●​ col("Product"),
●​ col("exploded_features.feature_name"),
●​ col("exploded_features.value")
●​ )
●​ df_flattened_array_struct.show(truncate=False)
●​ df_flattened_array_struct.printSchema()
●​
●​ spark.stop()

Accessing fields: col("struct.col")


You can access fields within a StructType column using dot notation:
col("parent_struct.child_field").
For ArrayType and MapType, you need specific functions or an explode() operation to
access elements.
●​ MapType access: Use col("map_col")["key"] to get a value by key.
●​ ArrayType access: Use col("array_col")[index] (zero-based) for direct element access.
With ANSI mode off, an out-of-bounds index returns null instead of raising an error, so
check array lengths if that matters. Use explode() to iterate over elements as rows.

Example (Python):

Python

●​ from pyspark.sql import SparkSession


●​ from pyspark.sql.functions import col, explode, element_at
●​ from pyspark.sql.types import StructType, StructField, StringType, ArrayType, MapType
●​
●​ spark = SparkSession.builder.appName("AccessingNestedFields").getOrCreate()
●​
●​ data = [("Alice", {"street": "Main St", "city": "NY"}, ["A", "B"], {"email": "a@ex.com",
"phone": "111"}),
●​ ("Bob", {"street": "Elm St", "city": "LD"}, ["C"], {"email": "b@ex.com"})]
●​ schema = StructType([
●​ StructField("Name", StringType(), True),
●​ StructField("Address", StructType([
●​ StructField("street", StringType(), True),
●​ StructField("city", StringType(), True)
●​ ]), True),
●​ StructField("Grades", ArrayType(StringType()), True),
●​ StructField("Contact", MapType(StringType(), StringType()), True)
●​ ])
●​ df = spark.createDataFrame(data, schema=schema)
●​ df.show(truncate=False)
●​ df.printSchema()
●​
●​ # Accessing StructType fields
●​ print("\nAccessing StructType fields:")
●​ df.select(col("Name"),
●​ col("Address.street").alias("Street"),
●​ col("Address.city").alias("City")).show()
●​
●​ # Accessing MapType fields
●​ print("\nAccessing MapType fields:")
●​ df.select(col("Name"),
●​ # Dot notation also resolves map keys
●​ col("Contact.email").alias("Email"),
●​ # element_at() looks up a map value by key
●​ element_at(col("Contact"), "phone").alias("Phone")).show()
●​
●​ # Accessing ArrayType fields (specific index - use carefully if array length varies)
●​ print("\nAccessing ArrayType fields by index:")
●​ df.select(col("Name"),
●​ col("Grades")[0].alias("FirstGrade")).show() # Accessing first element
●​
●​ # Iterating ArrayType fields using explode (common for all elements)
●​ print("\nAccessing ArrayType fields using explode:")
●​ df.withColumn("Grade", explode(col("Grades"))).show()
●​
●​ spark.stop()

Working with Dates and Timestamps


Date and timestamp manipulation are crucial in many data engineering scenarios, from
time-series analysis to data warehousing. Spark provides a rich set of built-in functions for
this.

Functions: to_date, current_date, current_timestamp, datediff, add_months, date_format,
year, month, dayofweek
These functions are available in pyspark.sql.functions.
●​ to_date(column, format=None): Converts a StringType column to a DateType
column. An optional format string can be provided.
●​ current_date(): Returns the current date as a DateType literal.
●​ current_timestamp(): Returns the current timestamp as a TimestampType literal.
●​ datediff(end_date, start_date): Returns the number of days between two DateType
columns.
●​ months_between(timestamp1, timestamp2, roundOff=True): Returns the number
of months between two timestamps.
●​ add_months(start_date, num_months): Returns the date that is num_months after
start_date.
●​ date_add(start_date, num_days): Adds num_days to start_date.
●​ date_sub(start_date, num_days): Subtracts num_days from start_date.
●​ date_format(date, format): Formats a DateType or TimestampType column to a
StringType according to the specified format.
●​ year(column): Extracts the year from a date/timestamp.
●​ month(column): Extracts the month from a date/timestamp.
●​ dayofmonth(column): Extracts the day of the month.
●​ dayofweek(column): Extracts the day of the week (1=Sunday, 7=Saturday).
●​ hour(column), minute(column), second(column): Extract time components.
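
Two of these conventions are easy to get wrong: the argument order of datediff and the Sunday-based numbering of dayofweek. The sketch below reproduces both with Python's standard datetime module; the helper names are chosen for the example and only mirror the semantics described above.

```python
from datetime import date


def datediff(end, start):
    """Days from start to end; negative when end is earlier than start."""
    return (end - start).days


def spark_style_dayofweek(day):
    """Map Python's isoweekday (Mon=1 .. Sun=7) to the 1=Sunday .. 7=Saturday scheme."""
    return day.isoweekday() % 7 + 1


new_year = date(2023, 1, 1)    # a Sunday
mid_feb = date(2023, 2, 15)    # a Wednesday
end_of_jan = date(2023, 1, 31)

print(datediff(end_of_jan, new_year))
print(spark_style_dayofweek(new_year), spark_style_dayofweek(mid_feb))
```

Swapping the two arguments of datediff flips the sign of the result, which is a common source of off-by-sign bugs in date filters.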

Example (Python):
Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, current_date, current_timestamp, \
    datediff, add_months, date_format, year, month, \
    dayofweek, dayofmonth, hour, minute, second

spark = SparkSession.builder.appName("DateTimestampFunctions").getOrCreate()

data = [("2023-01-01", "2023-01-31 10:30:00"),
        ("2023-02-15", "2023-02-15 14:00:00"),
        ("2022-12-25", "2022-12-25 23:59:59")]
columns = ["event_date_str", "event_ts_str"]
df = spark.createDataFrame(data, columns)
df.show(truncate=False)
df.printSchema()

# Convert string to DateType and TimestampType
df_converted = df.withColumn("event_date", to_date(col("event_date_str"))) \
    .withColumn("event_ts", col("event_ts_str").cast("timestamp"))
print("\nAfter converting to DateType and TimestampType:")
df_converted.show(truncate=False)
df_converted.printSchema()

# Current Date and Timestamp
print("\nCurrent Date and Timestamp:")
spark.range(1).select(current_date().alias("Today"),
                      current_timestamp().alias("Now")).show(truncate=False)

# Date Difference
print("\nDate difference (days between current date and event_date):")
df_converted.withColumn("days_since_event",
                        datediff(current_date(), col("event_date"))).show()

# Add Months
print("\nDate after adding 3 months:")
df_converted.withColumn("date_plus_3_months",
                        add_months(col("event_date"), 3)).show()

# Date Formatting
print("\nFormatted Date and Timestamp:")
df_converted.withColumn("formatted_date", date_format(col("event_date"), "yyyy/MM/dd")) \
    .withColumn("formatted_ts", date_format(col("event_ts"), "MM-dd-yyyy HH:mm:ss")) \
    .show(truncate=False)

# Extracting components
print("\nDate/Time Components:")
df_converted.select(
    col("event_date"),
    year(col("event_date")).alias("Year"),
    month(col("event_date")).alias("Month"),
    dayofmonth(col("event_date")).alias("DayOfMonth"),
    dayofweek(col("event_date")).alias("DayOfWeek"),
    hour(col("event_ts")).alias("Hour"),
    minute(col("event_ts")).alias("Minute"),
    second(col("event_ts")).alias("Second")
).show()

spark.stop()

Date arithmetic and filtering


You can perform arithmetic operations on dates (e.g., adding/subtracting days/months) and
filter DataFrames based on date ranges.

Example (Python):

Python

from pyspark.sql import SparkSession
# lit is needed below to build the literal target date
from pyspark.sql.functions import col, lit, to_date, date_add, date_sub, datediff

spark = SparkSession.builder.appName("DateArithmeticFiltering").getOrCreate()

data = [("TaskA", "2023-01-10"),
        ("TaskB", "2023-02-20"),
        ("TaskC", "2023-01-05"),
        ("TaskD", "2023-03-01"),
        ("TaskE", "2023-02-28")]
columns = ["Task", "CompletionDate"]
df = spark.createDataFrame(data, columns)
df = df.withColumn("CompletionDate", to_date(col("CompletionDate")))
df.show()
df.printSchema()

# Add/Subtract days
print("\nDate Arithmetic (add/subtract days):")
df.withColumn("DueIn5Days", date_add(col("CompletionDate"), 5)) \
    .withColumn("Started5DaysBefore", date_sub(col("CompletionDate"), 5)).show()

# Filtering by date range (both bounds inclusive)
print("\nFiltering for tasks completed in February 2023:")
df.filter((col("CompletionDate") >= "2023-02-01") &
          (col("CompletionDate") <= "2023-02-28")).show()

# Filtering for tasks completed within 30 days of a specific date
target_date = to_date(lit("2023-02-15"))
print("\nFiltering for tasks completed within 30 days of 2023-02-15:")
df.filter(datediff(col("CompletionDate"), target_date).between(-30, 30)).show()

spark.stop()

Handling time zones and formats


Time zone handling is crucial for data consistency, especially when dealing with data from
different geographical locations. Spark supports explicit time zone configuration.
●​ spark.sql.session.timeZone: This Spark configuration property determines the
session's default time zone. All timestamp operations that don't explicitly specify a
time zone will use this. (Default is JVM's local time zone).
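●​ to_timestamp(column, format): Converts a string to a timestamp. It takes no time zone
argument; a string without a zone is interpreted in the session time zone. To shift a
timestamp between zones explicitly, use from_utc_timestamp(ts, tz) or
to_utc_timestamp(ts, tz).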
●​ from_unixtime(unixtime_col, format): Converts Unix timestamp (seconds since
epoch) to a formatted string.
●​ unix_timestamp(timestamp_col, format): Converts a timestamp string with a
format to a Unix timestamp.
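As a rough illustration of what the last two functions compute, here is a plain-Python model of the round trip between a formatted string and epoch seconds. It is a sketch, not Spark code: it assumes the session time zone is UTC (Spark uses the session time zone for both directions), and the helper names only mirror the Spark functions.

```python
from datetime import datetime, timezone

FMT = "%Y-%m-%d %H:%M:%S"  # Python counterpart of the Spark pattern yyyy-MM-dd HH:mm:ss


def unix_timestamp(ts_str: str, fmt: str = FMT) -> int:
    """Parse a wall-clock string as UTC and return seconds since the epoch."""
    parsed = datetime.strptime(ts_str, fmt).replace(tzinfo=timezone.utc)
    return int(parsed.timestamp())


def from_unixtime(seconds: int, fmt: str = FMT) -> str:
    """Render epoch seconds as a formatted UTC wall-clock string."""
    return datetime.fromtimestamp(seconds, tz=timezone.utc).strftime(fmt)


epoch_seconds = unix_timestamp("2023-07-19 10:00:00")
print(epoch_seconds)
print(from_unixtime(epoch_seconds))  # 2023-07-19 10:00:00
print(from_unixtime(0))              # 1970-01-01 00:00:00
```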

Important Note: Spark stores timestamps internally as UTC (Coordinated Universal Time).
When you read or write timestamps, Spark converts them to/from the session's configured
time zone.
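The distinction between the stored instant and its displayed wall-clock value can be modeled in plain Python. The sketch below is not Spark code; as a simplifying assumption it uses fixed offsets that match America/Los_Angeles (UTC-7) and Europe/London (UTC+1) in July 2023 instead of full zone rules.

```python
from datetime import datetime, timedelta, timezone

# Assumption: fixed offsets stand in for the real zone rules in July 2023.
LOS_ANGELES = timezone(timedelta(hours=-7), "PDT")
LONDON = timezone(timedelta(hours=1), "BST")

# A zone-less string, read while the "session zone" is Los Angeles
wall_clock = datetime.strptime("2023-07-19 10:00:00", "%Y-%m-%d %H:%M:%S")
instant = wall_clock.replace(tzinfo=LOS_ANGELES)


def render(tz: timezone) -> str:
    """Display the same stored instant as a wall-clock string in the given zone."""
    return instant.astimezone(tz).strftime("%Y-%m-%d %H:%M:%S")


print(render(LOS_ANGELES))   # 2023-07-19 10:00:00
print(render(timezone.utc))  # 2023-07-19 17:00:00
print(render(LONDON))        # 2023-07-19 18:00:00
```

The instant never changes; only the rendering does, which is why switching the session time zone alters what show() prints without altering the data.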

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, date_format, from_utc_timestamp

spark = SparkSession.builder.appName("TimezoneHandling").getOrCreate()

# Set session time zone for demonstration
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
print(f"Session Time Zone: {spark.conf.get('spark.sql.session.timeZone')}")

data = [("2023-07-19 10:00:00",),
        ("2023-07-19 15:00:00",)]
columns = ["timestamp_str"]
df = spark.createDataFrame(data, columns)
df.show(truncate=False)

# Convert string to timestamp (a string without a zone is read in the session time zone)
df_ts = df.withColumn("timestamp", col("timestamp_str").cast("timestamp"))
print("\nOriginal timestamps (parsed and displayed in the session time zone):")
df_ts.show(truncate=False)
df_ts.printSchema()

# to_timestamp() and cast() take no time zone argument. To treat the wall-clock
# string as UTC and render it in another zone, use from_utc_timestamp().
df_tz_gmt = df.withColumn(
    "ts_from_utc",
    from_utc_timestamp(col("timestamp_str").cast("timestamp"), "America/Los_Angeles"))
print("\nWall-clock value treated as UTC, rendered in America/Los_Angeles:")
df_tz_gmt.show(truncate=False)

# Switch the session time zone to UTC for comparison
spark.conf.set("spark.sql.session.timeZone", "UTC")
print(f"\nSession Time Zone changed to: {spark.conf.get('spark.sql.session.timeZone')}")

# Now observe how the same internal timestamp value is displayed differently
print("\nOriginal timestamps (displayed in new session TZ - UTC):")
df_ts.show(truncate=False)

# date_format() takes only (column, format) and renders in the session time zone.
# The session zone is UTC at this point, so from_utc_timestamp() can shift the
# value to another zone's wall-clock time before formatting.
print("\nFormatting timestamp for specific time zones:")
df_ts.withColumn("formatted_utc",
                 date_format(col("timestamp"), "yyyy-MM-dd HH:mm:ss z")) \
    .withColumn("formatted_la", date_format(
        from_utc_timestamp(col("timestamp"), "America/Los_Angeles"),
        "yyyy-MM-dd HH:mm:ss")) \
    .withColumn("formatted_london", date_format(
        from_utc_timestamp(col("timestamp"), "Europe/London"),
        "yyyy-MM-dd HH:mm:ss")) \
    .show(truncate=False)

spark.stop()

Key principle: Always be explicit about time zones if your data originates from or needs to
be presented in different time zones. Prefer to store data in UTC to avoid ambiguity and
convert only at the edges of your system (ingestion and presentation).

String Functions
Spark SQL provides a comprehensive set of functions for manipulating string data, which are
crucial for data cleaning, parsing, and feature engineering. These functions are available in
pyspark.sql.functions.

Functions: substring, concat, concat_ws, split, trim, lpad, rpad, lower, upper
●​ substring(str, pos, len): Extracts a substring from str starting at pos (1-based index)
with len characters.
●​ concat(*cols): Concatenates multiple string columns into a single string.
●​ concat_ws(sep, *cols): Concatenates multiple string columns into a single string,
separated by sep.
●​ split(str, delimiter): Splits a string column into an array of strings based on a
delimiter.
●​ trim(str): Removes leading and trailing spaces from a string.
●​ ltrim(str): Removes leading spaces.
●​ rtrim(str): Removes trailing spaces.
●​ lpad(str, len, pad): Pads str on the left with pad until it reaches len.
●​ rpad(str, len, pad): Pads str on the right with pad until it reaches len.
●​ lower(str): Converts a string to lowercase.
●​ upper(str): Converts a string to uppercase.
●​ length(str): Returns the length of a string.
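The 1-based indexing of substring and the truncating behavior of lpad and rpad are common sources of off-by-one mistakes. The sketch below models them in plain Python for positive positions and a non-empty pad string; it is an illustration of the semantics, not Spark code, and the helper names only mirror the Spark functions.

```python
def substring(s: str, pos: int, length: int) -> str:
    """Model of substring(str, pos, len) for pos >= 1 (positions are 1-based)."""
    return s[pos - 1 : pos - 1 + length]


def lpad(s: str, length: int, pad: str) -> str:
    """Pad on the left with pad, or truncate when s is already longer than length."""
    if len(s) >= length:
        return s[:length]
    return (pad * length)[: length - len(s)] + s


def rpad(s: str, length: int, pad: str) -> str:
    """Pad on the right with pad, or truncate when s is already longer than length."""
    return (s + pad * length)[:length]


print(substring("alice.smith@example.com", 6, 10))  # .smith@exa
print(substring("Alice", 1, 3))                     # Ali
print(lpad("123-456", 10, "0"))                     # 000123-456
print(rpad("Alice", 10, "*"))                       # Alice*****
print(lpad("Charlie", 3, "*"))                      # Cha (truncated, not padded)
```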

Example (Python):

Python

from pyspark.sql import SparkSession
# lit is needed below for the literal strings passed to concat and concat_ws
from pyspark.sql.functions import col, lit, substring, concat, concat_ws, split, \
    trim, ltrim, rtrim, lpad, rpad, lower, upper, length

spark = SparkSession.builder.appName("StringFunctions").getOrCreate()

data = [("Alice", "alice.smith@example.com", " NYC ", "123-456"),
        ("Bob", "bob.jones@example.com", "LA", "789-012"),
        ("Charlie", "charlie@gmail.com", "SFO", "345-678")]
columns = ["Name", "Email", "City", "PhonePrefix"]
df = spark.createDataFrame(data, columns)
df.show(truncate=False)

# Substring (positions are 1-based)
print("\nSubstring:")
# 10 characters starting at position 6, e.g. '.smith@exa' for Alice's email
df.withColumn("EmailSlice", substring(col("Email"), 6, 10)).show()
df.withColumn("First3CharsName", substring(col("Name"), 1, 3)).show()

# Concat
print("\nConcat / Concat_WS:")
df.withColumn("FullEmail",
              concat(col("Name"), lit("<"), col("Email"), lit(">"))).show(truncate=False)
df.withColumn("FormattedCity", concat_ws("-", lit("City"), col("City"))).show()

# Split
print("\nSplit Email by '@':")
df.withColumn("EmailParts", split(col("Email"), "@")).show(truncate=False)
# Access an element of the resulting array (index 1 is the part after '@')
df.withColumn("Domain", split(col("Email"), "@")[1]).show(truncate=False)

# Trim, Ltrim, Rtrim
print("\nTrim:")
df.withColumn("TrimmedCity", trim(col("City"))).show()
df.withColumn("LTrimmedCity", ltrim(col("City"))).show()
df.withColumn("RTrimmedCity", rtrim(col("City"))).show()

# Lpad, Rpad
print("\nLpad / Rpad:")
# Pad with '0' to length 10
df.withColumn("PaddedPhone", lpad(col("PhonePrefix"), 10, "0")).show()
# Pad with '*' to length 10
df.withColumn("PaddedName", rpad(col("Name"), 10, "*")).show()

# Lower, Upper
print("\nLower / Upper:")
df.withColumn("LowerName", lower(col("Name"))) \
    .withColumn("UpperCity", upper(col("City"))).show()

# Length
print("\nLength of Name:")
df.withColumn("NameLength", length(col("Name"))).show()

spark.stop()

Regex functions: regexp_extract, regexp_replace


Regular expressions (regex) are powerful for complex string pattern matching and
manipulation.
●​ regexp_extract(str, pattern, idx): Extracts a string that matches a pattern from str.
idx specifies which capture group to return.
●​ regexp_replace(str, pattern, replacement): Replaces all substrings in str that
match the pattern with replacement.
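A quick way to prototype a pattern before running it on a cluster is to try it with Python's re module. The sketch below models the two functions, including the fact that regexp_extract yields an empty string when nothing matches. It is an approximation: Spark evaluates patterns with Java's regex engine, so syntax differences (for example in group references inside the replacement string) are possible.

```python
import re


def regexp_extract(s: str, pattern: str, idx: int) -> str:
    """Return capture group idx of the first match, or "" when nothing matches."""
    match = re.search(pattern, s)
    if match is None or match.group(idx) is None:
        return ""
    return match.group(idx)


def regexp_replace(s: str, pattern: str, replacement: str) -> str:
    """Replace every match of pattern in s with replacement."""
    return re.sub(pattern, replacement, s)


print(regexp_extract("Order #123456 - Product X (Qty: 2)", r"#(\d+)", 1))  # 123456
print(repr(regexp_extract("No Order Info", r"#(\d+)", 1)))                 # ''
print(regexp_replace("john.doe@email.com", r"^[^@]+@", "***@"))            # ***@email.com
```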

Example (Python):
Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, regexp_extract, regexp_replace

spark = SparkSession.builder.appName("RegexFunctions").getOrCreate()

data = [("Alice Smith", "Order #123456 - Product X (Qty: 2)", "john.doe@email.com"),
        ("Bob Johnson", "Product Y - REF_789012", "jane_doe@gmail.com"),
        ("Charlie Brown", "No Order Info", "charlie@org.net")]
columns = ["CustomerName", "OrderDetails", "Email"]
df = spark.createDataFrame(data, columns)
df.show(truncate=False)

# regexp_extract: Extract Order ID (e.g., numbers after '#')
# r"#(\d+)" : "#" literal, then capture group "(\d+)" one or more digits.
# Index 1 selects the capture group.
print("\nExtracting Order ID:")
df.withColumn("OrderID",
              regexp_extract(col("OrderDetails"), r"#(\d+)", 1)).show(truncate=False)

# regexp_extract: Extract domain from email
# r"@([a-zA-Z0-9.-]+)" : "@" literal, then capture group with letters, numbers, dot, dash.
print("\nExtracting Email Domain:")
df.withColumn("EmailDomain",
              regexp_extract(col("Email"), r"@([a-zA-Z0-9.-]+)", 1)).show(truncate=False)

# regexp_replace: Mask part of an email (e.g., username with ***)
# r"^[^@]+@" : start of string "^", one or more characters that are NOT "@" "[^@]+",
# then "@" literal.
print("\nMasking Email Username:")
df.withColumn("MaskedEmail",
              regexp_replace(col("Email"), r"^[^@]+@", "***@")).show(truncate=False)

# regexp_replace: Remove non-alphanumeric characters from Customer Name
print("\nRemoving non-alphanumeric from Customer Name:")
df.withColumn("CleanCustomerName",
              regexp_replace(col("CustomerName"), r"[^a-zA-Z0-9 ]", "")).show(truncate=False)

spark.stop()

Repartition vs Coalesce
These two functions are used to control the number of partitions in a DataFrame, which
directly impacts parallelism and performance. The key difference lies in whether they trigger
a full data shuffle.

repartition(n) — increases/decreases partitions with shuffle
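●​ Purpose: Redistributes data across n partitions. repartition(n) spreads rows roughly
evenly (round-robin); repartition(n, cols) hash-partitions by the given columns, so
partition sizes then depend on how the key values are distributed.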
●​ Behavior: Always involves a full shuffle of data. Each row can move to any
partition.
●​ Use cases:
○​ Increasing partitions: When you have too few partitions and want to
increase parallelism for computationally intensive tasks.
○​ Decreasing partitions (but with even distribution): If you need a specific
number of partitions for downstream operations and want to ensure even
distribution, even if it means a shuffle.
○​ Balancing data skew: If some partitions are much larger than others,
repartition(n) without columns can help rebalance the data.
○​ Changing partitioning key: When you need to repartition by a different set
of columns for future joins or aggregations to avoid shuffles on those
operations.
●​ Overhead: High network I/O and disk I/O due to the shuffle.
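To see why the choice between repartition(n) and repartition(n, cols) matters for skewed data, the following plain-Python sketch compares the partition sizes produced by round-robin assignment with those produced by hashing a key. It is a simplified model, not Spark's implementation: zlib.crc32 stands in for Spark's internal hash function, and the country-code keys are hypothetical sample data.

```python
import zlib
from typing import List


def round_robin_sizes(keys: List[str], n: int) -> List[int]:
    """Assign rows to partitions in turn, ignoring the key."""
    sizes = [0] * n
    for index, _ in enumerate(keys):
        sizes[index % n] += 1
    return sizes


def hash_sizes(keys: List[str], n: int) -> List[int]:
    """Assign each row to the partition chosen by a hash of its key."""
    sizes = [0] * n
    for key in keys:
        sizes[zlib.crc32(key.encode("utf-8")) % n] += 1
    return sizes


# Hypothetical skewed data: 90% of the rows share one key
keys = ["US"] * 90 + ["DE"] * 5 + ["FR"] * 5

print(round_robin_sizes(keys, 4))  # [25, 25, 25, 25]
print(hash_sizes(keys, 4))         # one partition holds at least the 90 "US" rows
```

Hash partitioning keeps all rows with the same key together, which is what joins and aggregations need, but it cannot split a single hot key across partitions.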

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("RepartitionCoalesce").getOrCreate()

data = [(i,) for i in range(100)]  # 100 rows
df = spark.createDataFrame(data, ["value"])

# Check initial number of partitions (might vary based on local setup)
print(f"Initial partitions: {df.rdd.getNumPartitions()}")

# Repartition to 10 partitions
df_repartitioned = df.repartition(10)
print(f"Partitions after repartition(10): {df_repartitioned.rdd.getNumPartitions()}")

# Repartition to 2 partitions
df_repartitioned_small = df.repartition(2)
print(f"Partitions after repartition(2): {df_repartitioned_small.rdd.getNumPartitions()}")

# Explain to see the shuffle (Exchange)
print("\nRepartition explain plan:")
df_repartitioned.explain()  # Look for 'Exchange'

spark.stop()

coalesce(n) — decreases partitions without shuffle


●​ Purpose: Decreases the number of partitions.
●​ Behavior: Attempts to combine existing partitions on the same nodes to achieve n
partitions, avoiding a full shuffle. When n is smaller than the current number of
partitions, existing partitions are merged, which makes coalesce() a narrow
transformation. When n is larger than the current number of partitions, coalesce() is
a no-op and the DataFrame keeps its current partitioning, because the partition count
cannot grow without a shuffle.
●​ Use cases:
○​ Reducing partitions for writing: When writing to a single file or a small
number of files to avoid creating many small files (small files can be inefficient
in HDFS/cloud storage).
○​ Minimizing shuffle: If you need to reduce partitions, but want to avoid the
high cost of a full shuffle, coalesce() is preferred.
●​ Overhead: Lower overhead than repartition() as it avoids a full shuffle.
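The merging behavior described above can be sketched in plain Python. This is a simplified model rather than Spark's algorithm (Spark also takes data locality into account when grouping partitions), and the partition contents are hypothetical. It shows the two key properties: whole partitions are merged without moving individual rows, and the resulting sizes can be uneven.

```python
from typing import List


def coalesce(partitions: List[List[int]], n: int) -> List[List[int]]:
    """Merge whole partitions into n groups; rows never move individually."""
    if n >= len(partitions):
        return partitions  # the count cannot grow without a shuffle
    merged: List[List[int]] = [[] for _ in range(n)]
    for index, part in enumerate(partitions):
        merged[index * n // len(partitions)].extend(part)
    return merged


# Hypothetical input: six partitions of very different sizes
partitions = [list(range(40)), [40, 41], [42], list(range(43, 63)), [63], [64, 65]]

print([len(p) for p in coalesce(partitions, 3)])  # [42, 21, 3]
print(len(coalesce(partitions, 10)))              # 6 (unchanged)
```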

Example (Python):

Python

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("RepartitionCoalesce").getOrCreate()

data = [(i,) for i in range(100)]
# Create a DataFrame with a known higher number of initial partitions
df = spark.sparkContext.parallelize(data, 10).toDF(["value"])

print(f"Initial partitions: {df.rdd.getNumPartitions()}")  # Should be 10

# Coalesce to 5 partitions (decreases without full shuffle)
df_coalesced = df.coalesce(5)
print(f"Partitions after coalesce(5): {df_coalesced.rdd.getNumPartitions()}")

# Coalesce to 1 partition (useful for writing a single file)
df_coalesced_single = df.coalesce(1)
print(f"Partitions after coalesce(1): {df_coalesced_single.rdd.getNumPartitions()}")

# Try to coalesce to a number higher than initial partitions - it won't increase
df_coalesced_larger = df.coalesce(15)
print("Partitions after coalesce(15) (will not increase): "
      f"{df_coalesced_larger.rdd.getNumPartitions()}")

# Explain to see no shuffle (no Exchange for valid coalesces)
print("\nCoalesce explain plan:")
df_coalesced.explain()  # Should not show 'Exchange' for decreasing partitions

spark.stop()

When to use which (write optimization, shuffle control)


| Feature | repartition(n) | coalesce(n) |
|---|---|---|
| Shuffle | Always triggers a full shuffle | Avoids a full shuffle if possible |
| Partitions | Can increase or decrease partitions | Can only decrease or maintain partitions (keeps the original count if n is higher) |
| Distribution | Roughly even with repartition(n); depends on the key distribution with repartition(n, cols) | May result in unevenly sized partitions |
| Performance | Higher overhead due to network I/O | Lower overhead, faster for reducing partitions |
| Use Cases | Increasing parallelism; balancing skewed data; changing partitioning key for joins; when a precise number of partitions is crucial | Reducing number of small files when writing; minimizing network I/O when decreasing partitions; when exact partition count is less critical than avoiding shuffle |

General Guidelines:
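●​ Write Optimization: When writing a large DataFrame to a few files (e.g., one file),
use coalesce(1) or coalesce(N) where N is a small number. This prevents many tiny
output files, which can hurt performance and management in distributed file systems.
Note that a drastic coalesce such as coalesce(1) can also reduce the parallelism of the
upstream computation; if that stage is expensive, repartition(N) keeps it parallel at
the cost of a shuffle.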
●​ Shuffle Control: Be mindful of repartition() as it's an expensive operation. Only use it
when a full data redistribution is necessary.
●​ Increasing Parallelism: If your tasks are CPU-bound and your current number of
partitions is too low, use repartition() to increase parallelism.
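●​ Data Skew: If you observe data skew in your Spark UI (some tasks taking much
longer than others during shuffles), repartition(n) without columns redistributes rows
evenly. Repartitioning by the skewed column(s) does not help, because all rows with
the same key still land in the same partition; for skewed join or aggregation keys,
consider salting the key or relying on adaptive query execution.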

Best Practice:

Understand your data and workload. Monitor Spark UI to see the number of tasks, task
durations, and shuffle read/write bytes to determine if your partitioning strategy is optimal.
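One simple way to turn that monitoring into a number is to compare the largest partition with the median partition. The sketch below is a hypothetical helper in plain Python, not a Spark API; it assumes the per-partition row counts have already been collected (for instance by grouping on spark_partition_id() and counting).

```python
from statistics import median
from typing import List


def skew_ratio(partition_sizes: List[int]) -> float:
    """Largest partition size divided by the median size (1.0 means perfectly even)."""
    med = median(partition_sizes)
    return max(partition_sizes) / med if med else float("inf")


balanced = [25, 25, 25, 25]
skewed = [90, 5, 5, 0]

print(skew_ratio(balanced))  # 1.0
print(skew_ratio(skewed))    # 18.0
```

What ratio counts as a problem depends on the workload; treat any threshold as something to calibrate against task durations in the Spark UI.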

UDFs and Pandas UDFs


UDFs (User-Defined Functions) allow you to extend Spark's functionality by defining your
own custom functions in Python, Scala, or Java, which can then be applied to DataFrame
columns.

Regular UDFs:
Regular UDFs operate row by row, similar to a standard Python function.
●​ Definition:
○​ You define a Python function.
○​ You wrap it with pyspark.sql.functions.udf and specify the returnType (it sets
the data type of the result column and defaults to StringType if omitted).

Example (Python):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, IntegerType

spark = SparkSession.builder.appName("RegularUDFs").getOrCreate()

data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)]
columns = ["Name", "Age"]
df = spark.createDataFrame(data, columns)
df.show()

# Define a Python function
def age_category(age):
    if age < 25:
        return "Young"
    elif age < 35:
        return "Mid"
    else:
        return "Senior"

# Register the Python function as a UDF
# Specify the returnType: StringType() in this case
age_category_udf = udf(age_category, StringType())

# Apply the UDF to the DataFrame
print("\nDataFrame with Age Category (using Regular UDF):")
df.withColumn("AgeCategory", age_category_udf(col("Age"))).show()

# Another UDF example: simple addition
def add_one(value):
    return value + 1

add_one_udf = udf(add_one, IntegerType())
print("\nDataFrame with Age + 1 (using Regular UDF):")
df.withColumn("AgePlusOne", add_one_udf(col("Age"))).show()

spark.stop()

Performance drawbacks (serialization, optimization barrier)


Despite their flexibility, regular UDFs have significant performance drawbacks:
1.​ Serialization/Deserialization Overhead: Spark's core execution engine operates on
optimized binary data formats (Tungsten). When a regular Python UDF is called, data
from Spark's internal format must be deserialized into Python objects for the UDF to
process, and then the results must be serialized back into Spark's format. This
constant conversion introduces overhead.
2.​ Optimization Barrier: The Catalyst Optimizer, which is responsible for optimizing
DataFrame operations, cannot "see inside" a regular UDF. It treats the UDF as a
black box. This means Spark cannot apply its powerful optimizations (like predicate
pushdown, column pruning, code generation) to operations inside the UDF.
3.​ Python Process Overhead: Each executor runs a Python process to execute
Python UDFs. There's overhead associated with launching and managing these
separate processes and communicating between the JVM (Spark core) and Python.
4.​ No Vectorization: Regular UDFs process data row by row, which is inefficient
compared to vectorized operations that process data in batches.

Consequence: Regular Python UDFs can be a performance bottleneck for large datasets.
They should be used sparingly and only when a built-in Spark SQL function cannot achieve
the desired logic.

Pandas UDFs:
Pandas UDFs (also known as Vectorized UDFs) leverage Apache Arrow and Pandas to
significantly improve the performance of Python UDFs. They process data in batches (as
Pandas Series or DataFrames) rather than row by row.
●​ Decorator: @pandas_udf(returnType)
○​ You define a Python function that takes one or more Pandas Series as input.
○​ It must return a Pandas Series of the same length as the input.
○​ You decorate the function with @pandas_udf and specify the returnType (a
Spark SQL data type).
●​ Series-to-Series transformations:
○​ The UDF operates on pandas.Series objects, enabling vectorized operations
that are much faster than row-by-row processing.

Performance advantage (vectorized, Arrow-based)


●​ Vectorized Execution: Instead of processing one row at a time, Pandas UDFs
process batches of rows (as Pandas Series). This allows efficient use of Pandas'
optimized operations, which are often implemented in C.
●​ Apache Arrow Optimization: Spark uses Apache Arrow, an in-memory columnar
data format, to efficiently transfer data between the JVM and Python processes. This
minimizes serialization/deserialization overhead.
●​ Catalyst Integration (Improved): While still an optimization barrier, the batch
processing nature and Arrow integration allow Spark to manage the data transfer
more efficiently, leading to better performance than regular UDFs.
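The effect of batching on the number of Python function invocations can be illustrated without Spark. The sketch below is a plain-Python model: the batch size of 10,000 mirrors what is, to my knowledge, the default of spark.sql.execution.arrow.maxRecordsPerBatch, and the counters and helper names are hypothetical.

```python
from typing import Dict, Iterable, Iterator, List

calls: Dict[str, int] = {"row": 0, "batch": 0}


def batches(rows: Iterable[int], batch_size: int) -> Iterator[List[int]]:
    """Group rows into lists of at most batch_size elements."""
    batch: List[int] = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def times_ten_row(value: int) -> int:
    """Row-at-a-time style: invoked once per row."""
    calls["row"] += 1
    return value * 10


def times_ten_batch(values: List[int]) -> List[int]:
    """Batch style: invoked once per batch of rows."""
    calls["batch"] += 1
    return [v * 10 for v in values]


ages = list(range(25_000))
row_result = [times_ten_row(a) for a in ages]
batch_result = [v for b in batches(ages, 10_000) for v in times_ten_batch(b)]

print(calls)  # {'row': 25000, 'batch': 3}
```

Both approaches produce the same values; the batch version simply crosses the function-call boundary far less often, and inside a real Pandas UDF each batch is processed by vectorized Pandas operations rather than a Python loop.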

When to use Pandas UDFs:


●​ When existing Python libraries (like NumPy, Pandas, Scikit-learn) have functions that
are well-suited for vectorized operations and not available as Spark SQL functions.
●​ For complex custom logic that is hard to express with native Spark functions.

Example (Python - Pandas UDF):

Python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType, StringType
import numpy as np  # numpy is typically used with pandas for such operations
import pandas as pd

spark = SparkSession.builder.appName("PandasUDFs").getOrCreate()

data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)]
columns = ["Name", "Age"]
df = spark.createDataFrame(data, columns)
df.show()

# Define a Pandas UDF (Series to Series)
@pandas_udf(LongType())  # Return type is LongType
def multiply_by_ten(series: pd.Series) -> pd.Series:
    return series * 10

print("\nDataFrame with Age * 10 (using Pandas UDF):")
df.withColumn("AgeTimesTen", multiply_by_ten(col("Age"))).show()

# Another Pandas UDF: apply a more complex string transformation
@pandas_udf(StringType())
def categorize_age_pandas(ages: pd.Series) -> pd.Series:
    conditions = [
        ages < 25,
        (ages >= 25) & (ages < 35),
        ages >= 35
    ]
    choices = ["Young", "Mid", "Senior"]
    # np.select evaluates the conditions for the whole batch at once
    return pd.Series(np.select(conditions, choices, default="Unknown"))

print("\nDataFrame with Age Category (using Pandas UDF):")
df.withColumn("AgeCategoryPandas", categorize_age_pandas(col("Age"))).show()

spark.stop()

New Python UDFs (Spark 3.5+): Arrow-optimized Python UDFs (more optimized path)
Spark 3.5 introduced Arrow-optimized Python UDFs. The function is still written row by row,
like a regular UDF, but Spark uses Apache Arrow to move data between the JVM and the
Python worker, which reduces serialization overhead. Python type annotations alone do not
turn a function into a UDF: the function must still be wrapped with udf() (or decorated
with @udf), and the return type must still be declared.

This feature keeps the familiar UDF syntax and in many cases provides better performance
than traditional UDFs.

Key Features:
●​ Opt-in per UDF: Pass useArrow=True to udf() or @udf.
●​ Opt-in per session: Set spark.sql.execution.pythonUDF.arrow.enabled to true to
apply the Arrow-based path to all Python UDFs in the session.
●​ Same programming model: The function body is ordinary row-at-a-time Python; only the
data transfer changes. PyArrow must be installed on the driver and the executors.

Example (Python - Spark 3.5+ style):

Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import LongType, StringType

spark = SparkSession.builder.appName("NewPythonUDFs").getOrCreate()

data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)]
columns = ["Name", "Age"]
df = spark.createDataFrame(data, columns)
df.show()

# Arrow-optimized UDF (Spark 3.5+): same row-at-a-time function, Arrow-based transfer.
# The type annotations are documentation only; returnType is still required.
@udf(returnType=LongType(), useArrow=True)
def increment_age(age: int) -> int:
    return age + 1

print("\nDataFrame with Age + 1 (using Arrow-optimized Python UDF):")
df.withColumn("AgePlusOne", increment_age(col("Age"))).show()

# Another example with conditional logic
@udf(returnType=StringType(), useArrow=True)
def get_age_status(age: int) -> str:
    if age < 25:
        return "Young"
    elif age < 35:
        return "Adult"
    else:
        return "Senior"

print("\nDataFrame with Age Status (using Arrow-optimized Python UDF):")
df.withColumn("AgeStatus", get_age_status(col("Age"))).show()

spark.stop()

Important: Calling an undecorated Python function on a Column, such as
increment_age(col("Age")) without @udf, does not create a UDF. The function body runs once
on the driver against the Column object, so age + 1 merely builds a native column
expression, and a statement like "if age < 25" raises an error because a Column cannot be
converted to a boolean. Arrow-optimized UDFs still process one row at a time, so for heavy
numerical computation, explicitly using @pandas_udf for vectorized operations is still a
strong choice. Always profile your UDFs to determine the most efficient approach for your
specific workload.

This concludes our deep dive into Apache Spark for data engineering! I hope this detailed
explanation, complete with examples, provides a solid foundation for your journey in the
world of big data. Keep practicing, keep learning, and remember that mastering Spark is an
ongoing process of building, testing, and optimizing.
