How I Used Spark Broadcast Joins to Speed Up Cross-Border Analysis
Efficiency is the name of the game when you’re dealing with massive data volumes and the kind of cross-border and regional aggregations you find in fintech.
As part of a project to simulate the data challenges faced by major payment processors, I built a local data pipeline using PySpark. My goal: Ingest 1 million rows of mock financial transaction data, perform complex aggregations, and optimize the performance using Spark Broadcast Joins.
Generating Mock Data
Before analyzing a single row of data, I had to first create it! I wrote a simple Python script to help me do that.
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

def generate_data(num_rows=1_000_000):
    countries = ['USA', 'GBR', 'FRA', 'DEU', 'AUS', 'JPN', 'CAN', 'BRA', 'IND', 'CHN']
    currencies = {'USA': 'USD', 'GBR': 'GBP', 'FRA': 'EUR', 'DEU': 'EUR',
                  'AUS': 'AUD', 'JPN': 'JPY', 'CAN': 'CAD', 'BRA': 'BRL', 'IND': 'INR', 'CHN': 'CNY'}

    # Generate random card IDs (10 digit numbers)
    card_ids = np.random.randint(1000000000, 9999999999, size=num_rows, dtype=np.int64)

    # Randomly select issuer countries
    issuer_countries = np.random.choice(countries, size=num_rows)

    # Randomly select merchant countries
    merchant_countries = np.random.choice(countries, size=num_rows)

    # Random amounts between 5 and 500
    amounts = np.round(np.random.uniform(5, 500, size=num_rows), 2)

    # Random date within the last year
    start_date = datetime.now() - timedelta(days=365)
    random_days = np.random.randint(0, 365, size=num_rows)
    dates = [start_date + timedelta(days=int(x)) for x in random_days]
    dates_str = [d.strftime('%Y-%m-%d') for d in dates]

    # Match currency to merchant country
    txn_currencies = [currencies[ctry] for ctry in merchant_countries]

    # Merge into a dataframe
    df = pd.DataFrame({
        'card_id': card_ids,
        'issuer_country': issuer_countries,
        'merchant_country': merchant_countries,
        'amount': amounts,
        'currency': txn_currencies,
        'transaction_date': dates_str
    })

    # Save to CSV
    df.to_csv('transactions.csv', index=False)

if __name__ == "__main__":
    generate_data()
This script creates a CSV file named transactions.csv with 1 million rows of
mock data including fake card IDs, issuer countries, merchant countries,
amounts, currencies, and dates.
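If you want to double-check the output before firing up Spark, a quick pandas read does the job (a small optional check):
import pandas as pd

# Quick verification of the generated file
check_df = pd.read_csv('transactions.csv')
print(check_df.shape)   # expect (1000000, 6)
print(check_df.head())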
Setting Up PySpark Locally
Since I’m working on my local machine rather than a full cluster, I simulated a real scenario inside a Jupyter notebook. In the first cell of the notebook, I imported the necessary libraries, set up the environment variables, and started a Spark session.
import sys
import os
from pyspark.sql import SparkSession
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
os.environ['SPARK_LOCAL_IP'] = '127.0.0.1'
# Start Spark session
print(f"Targeting Java at: {os.environ['JAVA_HOME']}")
print("Starting Spark Engine...")
spark = SparkSession.builder \
    .appName("TransactionAnalysis") \
    .master("local[*]") \
    .config("spark.driver.memory", "2g") \
    .getOrCreate()
print("Spark Engine is running!")
Next, I ingested the CSV data into a Spark DataFrame.
# Read the CSV
df = spark.read.csv(
    "transactions.csv",
    header=True,
    inferSchema=True
)
# Show the first 5 rows
df.show(5)
# Check the schema
df.printSchema()
Non-SQL DataFrame Queries
Before moving on to Spark SQL, let’s take a look at how we can perform some simple analysis without SQL.
Question: What is the total transaction amount per issuer country?
from pyspark.sql.functions import col, sum, format_number
# 1. Group by Country
# 2. Sum the amount
# 3. Rename the column to 'total_amount'
# 4. Sort by highest amount
analysis_df = df.groupBy("issuer_country") \
    .agg(sum("amount").alias("total_amount")) \
    .orderBy(col("total_amount").desc())

# Show top 10 results
analysis_df.select(
    "issuer_country",
    format_number("total_amount", 2).alias("formatted_total")
).show(10)
| issuer_country | formatted_total |
|---|---|
| CHN | 25,372,065.99 |
| GBR | 25,315,401.06 |
| AUS | 25,310,252.68 |
| USA | 25,224,832.77 |
| DEU | 25,224,543.69 |
| JPN | 25,215,778.19 |
| CAN | 25,198,698.23 |
| FRA | 25,128,502.91 |
| BRA | 25,109,452.75 |
| IND | 25,073,186.29 |
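Those near-identical totals are themselves a useful sanity check. Amounts are drawn uniformly between 5 and 500 (a mean of 252.50), and each of the 10 countries receives roughly 100,000 of the 1 million transactions, so each country should total around 25.25 million. A quick back-of-the-envelope check:
# Rough expected total per issuer country for uniformly random data
rows_per_country = 1_000_000 / 10       # ~100,000 transactions per country
mean_amount = (5 + 500) / 2             # uniform(5, 500) has a mean of 252.50
print(rows_per_country * mean_amount)   # ~25,250,000, close to the observed totals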
Implementing Business Logic with Spark SQL
To enable more complex queries, I registered the DataFrame as a temporary view so I could use Spark SQL. This plays the role Hive would in a production environment, layering a table structure over the raw file.
# Register the DataFrame as a temporary SQL view
df.createOrReplaceTempView("transactions")
Say we have a business requirement to calculate cross-border volume (where the issuer country differs from the merchant country). Using Spark SQL, I treated the DataFrame like a Hive table and summed the volume by “corridor”, the path from Country A to Country B.
cross_border_query = """
SELECT
    issuer_country as origin,
    merchant_country as destination,
    count(*) as transaction_count,
    round(sum(amount), 2) as total_volume
FROM transactions
WHERE issuer_country <> merchant_country
GROUP BY issuer_country, merchant_country
ORDER BY total_volume DESC
"""
cross_border_df = spark.sql(cross_border_query)
cross_border_df.show(5)
| origin | destination | transaction_count | total_volume |
|---|---|---|---|
| CAN | AUS | 10260 | 2602635.12 |
| GBR | CAN | 10224 | 2584589.64 |
| CHN | IND | 10097 | 2580652.0 |
| CHN | BRA | 10186 | 2574164.64 |
| CHN | GBR | 10096 | 2572917.47 |
Optimizing with Broadcast Joins
Now say there is a business requirement to aggregate regional totals (e.g., NAM, EMEA, AP).
The Naive Approach
Initially, you might think to use a CASE WHEN statement inside the SQL query to map countries to regions.
WITH regional_data AS (
    SELECT
        CASE
            WHEN issuer_country IN ('USA', 'CAN', 'MEX') THEN 'NAM'
            WHEN issuer_country IN ('GBR', 'FRA', 'DEU', 'ITA', 'ESP') THEN 'EMEA'
            WHEN issuer_country IN ('CHN', 'IND', 'AUS', 'JPN', 'SGP') THEN 'AP'
            ELSE 'Other'
        END as region,
        amount
    FROM transactions
)
SELECT
    region,
    count(*) as TxnCount,
    round(sum(amount), 2) as total_volume
FROM regional_data
GROUP BY region
ORDER BY total_volume DESC
While functional, this is hard to maintain. If a country changes regions or a new region is added, you have to rewrite the query logic.
Lookup Tables and the Broadcast Join
I refactored the solution to use a lookup table mapping country_code to region. However, joining a 1-million-row transaction table with a tiny region lookup table can trigger a shuffle join, causing unnecessary network traffic across the cluster.
The solution, at least in this case, is a Spark broadcast join, which copies the tiny lookup table to every worker node in the cluster. The join can then happen locally in memory, eliminating the need to shuffle the massive transaction data across the network.
First, we create the small lookup DataFrame. This would normally be ingested from a CSV or a database, but for simplicity, I hardcoded it here.
from pyspark.sql.types import StructType, StructField, StringType
region_data = [
    ("USA", "NAM"), ("CAN", "NAM"), ("MEX", "NAM"),
    ("GBR", "EMEA"), ("FRA", "EMEA"), ("DEU", "EMEA"), ("ITA", "EMEA"),
    ("CHN", "AP"), ("IND", "AP"), ("AUS", "AP"), ("JPN", "AP")
]

schema = StructType([
    StructField("country_code", StringType(), True),
    StructField("region_name", StringType(), True)
])
region_lookup_df = spark.createDataFrame(region_data, schema)
Now we can use a broadcast join to speed up the enrichment process.
from pyspark.sql.functions import broadcast
# Use 'left' join in case a country in the transactions isn't in our lookup table
enriched_df = df.join(
    broadcast(region_lookup_df),
    df.issuer_country == region_lookup_df.country_code,
    "left"
)

# Aggregate on the new column, `region_name`
result_df = enriched_df.groupBy("region_name") \
    .sum("amount") \
    .withColumnRenamed("sum(amount)", "total_volume")
result_df.show()
| region_name | total_volume |
|---|---|
| NULL | 2.510945275000003E7 |
| NAM | 5.0423531000000134E7 |
| EMEA | 7.566844766000012E7 |
| AP | 1.009712831499999E8 |
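The NULL row is the left join doing its job: BRA appears in the transactions but not in the lookup table, so those rows have no region. The raw sums also print in scientific notation; if you prefer the comma formatting used earlier, format_number can be applied at display time (an optional formatting step):
from pyspark.sql.functions import format_number

# Display the regional totals with comma formatting instead of scientific notation
result_df.select(
    "region_name",
    format_number("total_volume", 2).alias("formatted_volume")
).show()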
Verification via the Execution Plan
To prove the optimization worked, I ran result_df.explain(). The physical plan confirmed the strategy:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[region_name#162], functions=[sum(amount#20)])
   +- Exchange hashpartitioning(region_name#162, 200), ENSURE_REQUIREMENTS, [plan_id=688]
      +- HashAggregate(keys=[region_name#162], functions=[partial_sum(amount#20)])
         +- Project [amount#20, region_name#162]
            +- BroadcastHashJoin [issuer_country#18], [country_code#161], LeftOuter, BuildRight, false
               :- FileScan csv [issuer_country#18,amount#20] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/transactions.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<issuer_country:string,amount:double>
               +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false]),false), [plan_id=683]
                  +- Filter isnotnull(country_code#161)
                     +- Scan ExistingRDD[country_code#161,region_name#162]
The presence of BroadcastHashJoin confirms that Spark avoided the expensive
shuffle, resulting in a much more efficient aggregation.
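If you want to see the contrast for yourself, one option is to temporarily disable automatic broadcasting and drop the broadcast() hint; the resulting plan will typically show a SortMergeJoin with an extra Exchange shuffling the transaction data (a sketch, not part of the pipeline above):
# Temporarily disable automatic broadcasting so the small table isn't broadcast on its own
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# The same join and aggregation, but without the broadcast() hint
shuffle_df = df.join(
    region_lookup_df,
    df.issuer_country == region_lookup_df.country_code,
    "left"
)
shuffle_df.groupBy("region_name").sum("amount").explain()

# Restore the default threshold (10 MB)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024)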
Data Lake Storage
Finally, you’ll usually want to do more with the output than just display it in the console, so I stored the results in Parquet format. Parquet is a columnar storage format optimized for read-heavy analytics, the kind you’ll find in Hive/Impala environments.
result_df.coalesce(1).write.mode("overwrite").parquet("output_regional")
coalesce(1) forces Spark to collapse the results into a single partition, so the output is a single file. That makes it easier to open and inspect, but slower for massive data. Just be aware of that trade-off.
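For larger outputs, you would usually skip coalesce(1) and let Spark write one file per partition, optionally organizing the directory layout by a column (a sketch; the output path is illustrative):
# Let Spark write multiple part files, organized on disk by region
result_df.write.mode("overwrite") \
    .partitionBy("region_name") \
    .parquet("output_regional_partitioned")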
Check for the existence of a new output_regional directory. It contains the
data in Parquet format.
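If you want to confirm the round trip, the Parquet output can be read straight back into Spark (an optional verification step):
# Read the Parquet output back and confirm it matches the aggregated results
parquet_df = spark.read.parquet("output_regional")
parquet_df.show()
parquet_df.printSchema()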
Understanding Big Data
This project demonstrates that efficient big data analysis requires an understanding of the underlying architecture. By leveraging broadcast joins for dimension tables and Parquet for storage, you can create a scalable pipeline capable of handling complex financial logic efficiently.
You can find the complete code for this project on GitHub: Spark SQL Analysis.
Travis Horn