Travis Horn

How I Used Spark Broadcast Joins to Speed Up Cross-Border Analysis

2026-01-23

Efficiency is the name of the game when you’re running cross-border and regional aggregations over massive data volumes, the kind you’ll find in fintech.

As part of a project to simulate the data challenges faced by major payment processors, I built a local data pipeline using PySpark. My goal: Ingest 1 million rows of mock financial transaction data, perform complex aggregations, and optimize the performance using Spark Broadcast Joins.

Generating Mock Data

Before analyzing a single row of data, I first had to create it! I wrote a simple Python script to do exactly that.

import pandas as pd
import numpy as np
from datetime import datetime, timedelta

def generate_data(num_rows=1_000_000):
    countries = ['USA', 'GBR', 'FRA', 'DEU', 'AUS', 'JPN', 'CAN', 'BRA', 'IND', 'CHN']
    currencies = {'USA': 'USD', 'GBR': 'GBP', 'FRA': 'EUR', 'DEU': 'EUR',
                  'AUS': 'AUD', 'JPN': 'JPY', 'CAN': 'CAD', 'BRA': 'BRL', 'IND': 'INR', 'CHN': 'CNY'}

    # Generate random card IDs (10 digit numbers)
    card_ids = np.random.randint(1000000000, 9999999999, size=num_rows, dtype=np.int64)

    # Randomly select issuer countries
    issuer_countries = np.random.choice(countries, size=num_rows)

    # Randomly select merchant countries
    merchant_countries = np.random.choice(countries, size=num_rows)

    # Random amounts between 5 and 500
    amounts = np.round(np.random.uniform(5, 500, size=num_rows), 2)

    # Random date within the last year
    start_date = datetime.now() - timedelta(days=365)
    random_days = np.random.randint(0, 365, size=num_rows)
    dates = [start_date + timedelta(days=int(x)) for x in random_days]
    dates_str = [d.strftime('%Y-%m-%d') for d in dates]

    # Match currency to merchant country
    txn_currencies = [currencies[ctry] for ctry in merchant_countries]

    # Merge into a dataframe
    df = pd.DataFrame({
        'card_id': card_ids,
        'issuer_country': issuer_countries,
        'merchant_country': merchant_countries,
        'amount': amounts,
        'currency': txn_currencies,
        'transaction_date': dates_str
    })

    # Save to CSV
    df.to_csv('transactions.csv', index=False)

if __name__ == "__main__":
    generate_data()

This script creates a CSV file named transactions.csv with 1 million rows of mock data including fake card IDs, issuer countries, merchant countries, amounts, currencies, and dates.
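
Before spinning up Spark, it’s worth a quick sanity check on the file. Here’s a minimal sketch (assuming the script above has already been run) that uses pandas to peek at the first few rows without loading the whole million:

import pandas as pd

# Preview only the first few rows instead of loading all 1 million
preview = pd.read_csv('transactions.csv', nrows=5)
print(preview)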

Setting Up PySpark Locally

Since I’m working on my local machine and not a big cluster, I simulated a real scenario inside a Jupyter notebook. In the first cell of the notebook, I imported the necessary libraries, set up the environment variables, and started a Spark session.

import sys
import os
from pyspark.sql import SparkSession


os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
os.environ['SPARK_LOCAL_IP'] = '127.0.0.1'

# Start Spark session
print(f"Targeting Java at: {os.environ['JAVA_HOME']}")
print("Starting Spark Engine...")

spark = SparkSession.builder \
    .appName("TransactionAnalysis") \
    .master("local[*]") \
    .config("spark.driver.memory", "2g") \
    .getOrCreate()

print("Spark Engine is running!")

Next, I ingested the CSV data into a Spark DataFrame.

# Read the CSV
df = spark.read.csv(
    "transactions.csv",
    header=True,
    inferSchema=True
)

# Show the first 5 rows
df.show(5)

# Check the schema
df.printSchema()
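
One note on inferSchema=True: it makes Spark scan the file an extra time to guess column types. That’s fine here, but for larger files you’d typically declare the schema up front. A rough sketch of that alternative (the field names match the CSV generated earlier; the schema variable name is my own):

from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType

# Declaring the schema explicitly skips the inference pass over the CSV
txn_schema = StructType([
    StructField("card_id", LongType(), True),
    StructField("issuer_country", StringType(), True),
    StructField("merchant_country", StringType(), True),
    StructField("amount", DoubleType(), True),
    StructField("currency", StringType(), True),
    StructField("transaction_date", StringType(), True)
])

df = spark.read.csv("transactions.csv", header=True, schema=txn_schema)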

Non-SQL DataFrame Queries

Before moving on to Spark SQL, let’s take a look at how we can perform some simple analysis without SQL.

Question: What is the total transaction amount per issuer country?

from pyspark.sql.functions import col, sum, format_number

# 1. Group by Country
# 2. Sum the amount
# 3. Rename the column to 'total_amount'
# 4. Sort by highest amount
analysis_df = df.groupBy("issuer_country") \
    .agg(sum("amount").alias("total_amount")) \
    .orderBy(col("total_amount").desc())

# Show top 10 results
analysis_df.select("issuer_country", format_number("total_amount", 2).alias("formatted_total")).show(10)
issuer_country  formatted_total
CHN             25,372,065.99
GBR             25,315,401.06
AUS             25,310,252.68
USA             25,224,832.77
DEU             25,224,543.69
JPN             25,215,778.19
CAN             25,198,698.23
FRA             25,128,502.91
BRA             25,109,452.75
IND             25,073,186.29
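
The same DataFrame API handles simpler questions just as easily. For example, a quick sketch counting transactions per merchant country:

# Count transactions per merchant country, highest first
df.groupBy("merchant_country") \
    .count() \
    .orderBy(col("count").desc()) \
    .show(10)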

Implementing Business Logic with Spark SQL

To enable more complex queries, I created a temporary view so I could use Spark SQL. This simulates what Hive does, giving structure to the raw file.

# Register the DataFrame as a temporary SQL view
df.createOrReplaceTempView("transactions")
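
Once the view exists, any Spark SQL statement can reference it by name. A quick sanity check, for example, confirming the row count:

# The view should contain all 1 million rows
spark.sql("SELECT count(*) AS row_count FROM transactions").show()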

Say we have a business requirement to calculate cross-border volume (where the issuer country is different from the merchant country). Using Spark SQL, I treated the DataFrame like a Hive table and summed up the volume by “corridor”, which is the path from Country A to Country B.

cross_border_query = """
    SELECT
        issuer_country as origin,
        merchant_country as destination,
        count(*) as transaction_count,
        round(sum(amount), 2) as total_volume
    FROM transactions
    WHERE issuer_country <> merchant_country
    GROUP BY issuer_country, merchant_country
    ORDER BY total_volume DESC
"""

cross_border_df = spark.sql(cross_border_query)

cross_border_df.show(5)
origin  destination  transaction_count  total_volume
CAN     AUS          10260              2602635.12
GBR     CAN          10224              2584589.64
CHN     IND          10097              2580652.0
CHN     BRA          10186              2574164.64
CHN     GBR          10096              2572917.47
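
For comparison, the same corridor aggregation can be expressed with the DataFrame API instead of SQL. A rough sketch:

from pyspark.sql.functions import col, count, round as spark_round, sum as spark_sum

# Same logic as the SQL above: keep cross-border rows, group by corridor
corridor_df = df.filter(col("issuer_country") != col("merchant_country")) \
    .groupBy(col("issuer_country").alias("origin"),
             col("merchant_country").alias("destination")) \
    .agg(count("*").alias("transaction_count"),
         spark_round(spark_sum("amount"), 2).alias("total_volume")) \
    .orderBy(col("total_volume").desc())

corridor_df.show(5)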

Optimizing with Broadcast Joins

Now say there is a business requirement to aggregate regional totals (e.g., NAM, EMEA, AP).

The Naive Approach

Initially, you might think to use a CASE WHEN statement inside the SQL query to map countries to regions.

WITH regional_data AS (
    SELECT
        CASE
            WHEN issuer_country IN ('USA', 'CAN', 'MEX') THEN 'NAM'
            WHEN issuer_country IN ('GBR', 'FRA', 'DEU', 'ITA', 'ESP') THEN 'EMEA'
            WHEN issuer_country IN ('CHN', 'IND', 'AUS', 'JPN', 'SGP') THEN 'AP'
            ELSE 'Other'
        END as region,
        amount
    FROM transactions
)
SELECT
    region,
    count(*) as TxnCount,
    round(sum(amount), 2) as total_volume
FROM regional_data
GROUP BY region
ORDER BY total_volume DESC

While functional, this is hard to maintain. If a country changes regions or a new region is added, you have to rewrite the query logic.

Lookup Tables and the Broadcast Join

I refactored the solution to use a Lookup Table (country_code to region). However, joining a 1-million-row transaction table with a tiny region lookup table of a dozen or so rows can trigger a shuffle join, causing unnecessary network traffic across nodes.

The solution, at least in this case, is the Spark broadcast join. This copies the tiny lookup table to every worker node in the cluster. This allows the join to happen locally in memory, eliminating the need to shuffle the massive transaction data across the network.
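
Worth knowing: Spark will broadcast small tables automatically when their estimated size falls below spark.sql.autoBroadcastJoinThreshold (10 MB by default); the explicit broadcast() hint used below just makes the intent unambiguous. A sketch of inspecting and adjusting that setting:

# Check the current auto-broadcast threshold (in bytes; -1 disables it)
print(spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))

# Example: raise the threshold to 50 MB so slightly larger dimension
# tables are still broadcast automatically
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)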

First, we create the small lookup DataFrame. This would normally be ingested from a CSV or a database, but for simplicity, I hardcoded it here.

from pyspark.sql.types import StructType, StructField, StringType

region_data = [
    ("USA", "NAM"), ("CAN", "NAM"), ("MEX", "NAM"),
    ("GBR", "EMEA"), ("FRA", "EMEA"), ("DEU", "EMEA"), ("ITA", "EMEA"),
    ("CHN", "AP"), ("IND", "AP"), ("AUS", "AP"), ("JPN", "AP")
]

schema = StructType([
    StructField("country_code", StringType(), True),
    StructField("region_name", StringType(), True)
])

region_lookup_df = spark.createDataFrame(region_data, schema)

Now we can use a broadcast join to speed up the enrichment process.

from pyspark.sql.functions import broadcast

# Use 'left' join in case a country in the transactions isn't in our lookup table
enriched_df = df.join(
    broadcast(region_lookup_df),
    df.issuer_country == region_lookup_df.country_code,
    "left"
)

# Aggregate on the new column, `region_name`
result_df = enriched_df.groupBy("region_name") \
    .sum("amount") \
    .withColumnRenamed("sum(amount)", "total_volume")

result_df.show()
region_name  total_volume
NULL         2.510945275000003E7
NAM          5.0423531000000134E7
EMEA         7.566844766000012E7
AP           1.009712831499999E8

The NULL row corresponds to BRA, which isn’t in the lookup table; the left join keeps those transactions but leaves region_name empty.
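
If you’d rather bucket those unmapped countries under 'Other' instead of NULL, one option (sketched below with a variable name of my own) is to fill in the missing region before aggregating:

# Bucket unmapped countries under 'Other' rather than leaving region_name null
other_df = enriched_df.fillna({"region_name": "Other"})

other_df.groupBy("region_name") \
    .sum("amount") \
    .withColumnRenamed("sum(amount)", "total_volume") \
    .show()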

Verification via the Execution Plan

To prove the optimization worked, I ran result_df.explain(). The physical plan confirmed the strategy:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[region_name#162], functions=[sum(amount#20)])
   +- Exchange hashpartitioning(region_name#162, 200), ENSURE_REQUIREMENTS, [plan_id=688]
      +- HashAggregate(keys=[region_name#162], functions=[partial_sum(amount#20)])
         +- Project [amount#20, region_name#162]
            +- BroadcastHashJoin [issuer_country#18], [country_code#161], LeftOuter, BuildRight, false
               :- FileScan csv [issuer_country#18,amount#20] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/transactions.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<issuer_country:string,amount:double>
               +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false]),false), [plan_id=683]
                  +- Filter isnotnull(country_code#161)
                     +- Scan ExistingRDD[country_code#161,region_name#162]

The presence of BroadcastHashJoin confirms that Spark avoided the expensive shuffle, resulting in a much more efficient aggregation.
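
The same behavior is available from the SQL side via a join hint, if you’d rather keep the enrichment in Spark SQL. A sketch (the lookup DataFrame gets its own temp view first):

# Register the lookup table so it can be referenced in SQL
region_lookup_df.createOrReplaceTempView("region_lookup")

regional_sql = """
    SELECT /*+ BROADCAST(r) */
        r.region_name,
        round(sum(t.amount), 2) as total_volume
    FROM transactions t
    LEFT JOIN region_lookup r
        ON t.issuer_country = r.country_code
    GROUP BY r.region_name
    ORDER BY total_volume DESC
"""

spark.sql(regional_sql).show()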

Data Lake Storage

Finally, you’ll usually want to do more with the output than just display it in the console. We want to store the results in Parquet format. Parquet is a columnar storage format optimized for read-heavy analytics, the kind you’ll find in Hive/Impala environments.

result_df.coalesce(1).write.mode("overwrite").parquet("output_regional")

coalesce(1) forces Spark to merge the results into a single file. That makes the output easier to open and inspect, but it’s slower for massive data, so be aware of the trade-off.
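
For larger outputs, a common alternative (sketched here with a path name of my own) is to skip coalesce(1) and instead partition the enriched transaction-level data by region, so downstream queries that filter on a region only read the relevant directories:

# Write transaction-level data partitioned by region
enriched_df.write \
    .mode("overwrite") \
    .partitionBy("region_name") \
    .parquet("output_transactions_by_region")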

Check for the existence of a new output_regional directory. It contains the data in Parquet format.
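
A quick way to confirm the write succeeded is to read the Parquet back into a new DataFrame:

# Read the Parquet output back and confirm the contents round-tripped
verify_df = spark.read.parquet("output_regional")
verify_df.show()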

Understanding Big Data

This project demonstrates that efficient big data analysis requires understanding the underlying architecture. By leveraging broadcast joins for dimension tables and Parquet for storage, you can create a scalable pipeline capable of handling complex financial logic efficiently.

You can find the complete code for this project on GitHub: Spark SQL Analysis.

Cover photo by George C on Unsplash.
