# notebook parameters
import os
spark_master = "local[*]"
app_name = "augment"
input_file = os.path.join("data", "WA_Fn-UseC_-Telco-Customer-Churn-.csv")
output_prefix = ""
output_mode = "overwrite"
output_kind = "parquet"
driver_memory = "12g"
executor_memory = "8g"
dup_times = 100
import churn.augment
churn.augment.register_options(
    spark_master = spark_master,
    app_name = app_name,
    input_file = input_file,
    output_prefix = output_prefix,
    output_mode = output_mode,
    output_kind = output_kind,
    driver_memory = driver_memory,
    executor_memory = executor_memory,
    dup_times = dup_times,
    use_decimal = True
)
We're going to make sure we're running with a compatible JVM first — if we run on macOS, we might get one that doesn't work with Scala.
from os import getenv
getenv("JAVA_HOME")
import pyspark
session = pyspark.sql.SparkSession.builder \
    .master(spark_master) \
    .appName(app_name) \
    .config("spark.driver.memory", driver_memory) \
    .config("spark.executor.memory", executor_memory) \
    .getOrCreate()
session
Most of the fields are strings representing booleans or categoricals, but a few (tenure, MonthlyCharges, and TotalCharges) are numeric.
from churn.augment import load_supplied_data
df = load_supplied_data(session, input_file)
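If the numeric columns come back from the CSV as strings, a cast along these lines coerces them. This is only a sketch: load_supplied_data may already apply a proper schema, and the DecimalType(8, 2) precision here is an assumption (the use_decimal option above suggests decimal types are used downstream).
import pyspark.sql.functions as F
from pyspark.sql.types import DecimalType

# Illustrative only: coerce the three numeric columns; everything else stays a string.
typed = (
    df.withColumn("tenure", F.col("tenure").cast("int"))
      .withColumn("MonthlyCharges", F.col("MonthlyCharges").cast(DecimalType(8, 2)))
      .withColumn("TotalCharges", F.col("TotalCharges").cast(DecimalType(8, 2)))
)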
The training data schema looks like this:
df.printSchema()
We want to divide the data frame into several frames that we can join together in an ETL job; the cells that follow generate each of those frames.
We'll start by generating a series of monthly charges, then a series of account creation events, and finally a series of churn events. billingEvents is the data frame containing all of these events: account activation, account termination, and individual payment events.
from churn.augment import billing_events
billingEvents = billing_events(df)
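The real logic lives in churn.augment.billing_events; the sketch below only illustrates the idea, with assumed column names (kind, month, value) and integer month offsets rather than whatever representation billing_events actually uses.
import pyspark.sql.functions as F

# Rough sketch of the billing-events idea, not the churn.augment implementation.
# One payment event per month of tenure...
payments = (
    df.where(F.col("tenure").cast("int") > 0)
      .select(
          "customerID",
          F.explode(F.sequence(F.lit(1), F.col("tenure").cast("int"))).alias("month"),
          F.col("MonthlyCharges").cast("double").alias("value"),
      )
      .withColumn("kind", F.lit("Charge"))
)

# ...one account-activation event per customer...
activations = df.select(
    "customerID",
    F.lit(0).alias("month"),
    F.lit(0.0).alias("value"),
    F.lit("AccountCreation").alias("kind"),
)

# ...and an account-termination event for customers who churned.
churns = df.where(F.col("Churn") == "Yes").select(
    "customerID",
    F.col("tenure").cast("int").alias("month"),
    F.lit(0.0).alias("value"),
    F.lit("AccountTermination").alias("kind"),
)

billing_sketch = payments.unionByName(activations).unionByName(churns)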
Our next step is to generate customer metadata, including a synthetic date of birth. We'll calculate the date of birth by using the hash of the customer ID as a pseudorandom number, assuming that ages are uniformly distributed between 18 and 65 and exponentially distributed over 65.
from churn.augment import customer_meta
customerMeta = customer_meta(df)
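The sketch below illustrates that date-of-birth trick; it is not the implementation in churn.augment.customer_meta, and the mean of the over-65 exponential tail (ten years here) is an arbitrary choice for illustration.
import pyspark.sql.functions as F

# Treat a hash of the customer ID as a pseudo-uniform draw on [0, 1)...
u = (F.abs(F.hash("customerID")) % 1000000) / 1000000.0

# ...then map it to an age: uniform on [18, 65) for non-seniors,
# and an exponential tail above 65 for seniors (tail mean is assumed).
age = F.when(
    F.col("SeniorCitizen") == 0,
    18 + u * (65 - 18)
).otherwise(
    65 - F.log1p(-u) * 10
)

df.select("customerID", age.alias("approx_age")).show(5)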
Now we can generate customer phone features:
from churn.augment import phone_features
customerPhoneFeatures = phone_features(df)
Customer internet features come next:
from churn.augment import internet_features
customerInternetFeatures = internet_features(df)
Finally, we generate customer account features:
from churn.augment import account_features
customerAccountFeatures = account_features(df)
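Each of these helpers presumably projects a group of related source columns per customer; the groupings below are a sketch based on the source schema, not the actual churn.augment logic.
# Illustrative column groupings only.
phone_sketch = df.select("customerID", "PhoneService", "MultipleLines")
internet_sketch = df.select(
    "customerID", "InternetService", "OnlineSecurity", "OnlineBackup",
    "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies"
)
account_sketch = df.select("customerID", "Contract", "PaperlessBilling", "PaymentMethod")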
%%time
from churn.augment import write_df
write_df(billingEvents, "billing_events", partition_by="month")
write_df(customerMeta, "customer_meta", skip_replication=True)
write_df(customerPhoneFeatures, "customer_phone_features")
write_df(customerInternetFeatures.orderBy("customerID"), "customer_internet_features")
write_df(customerAccountFeatures, "customer_account_features")
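write_df writes each frame according to the options registered at the top of the notebook (output_prefix, output_mode, output_kind, dup_times). The function below is a hypothetical sketch of the replication idea implied by dup_times and skip_replication, not the actual churn.augment implementation; in particular, suffixing customerID to keep the copies distinct is an assumption.
from functools import reduce
import pyspark.sql.functions as F

def write_df_sketch(frame, name, partition_by=None, skip_replication=False):
    # Hypothetical: replicate the frame dup_times times, giving each copy a
    # distinct customerID suffix, then write it with the registered options.
    out = frame
    if not skip_replication and dup_times > 1:
        copies = [
            frame.withColumn("customerID",
                             F.concat(F.col("customerID"), F.lit("-%d" % i)))
            for i in range(dup_times)
        ]
        out = reduce(lambda l, r: l.unionByName(r), copies)
    writer = out.write.mode(output_mode)
    if partition_by is not None:
        writer = writer.partitionBy(partition_by)
    writer.format(output_kind).save("%s%s.%s" % (output_prefix, name, output_kind))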
# Sanity check: count the distinct customers that landed in each output table.
for f in ["billing_events", "customer_meta", "customer_phone_features", "customer_internet_features", "customer_account_features"]:
    output_df = session.read.parquet("%s.parquet" % f)
    print(f, output_df.select("customerID").distinct().count())
import pyspark.sql.functions as F
from functools import reduce
# Tag each table's customerIDs with the table name so we can compare
# per-table and overall approximate distinct-customer counts.
output_dfs = []
for f in ["billing_events", "customer_meta", "customer_phone_features", "customer_internet_features", "customer_account_features"]:
    output_dfs.append(
        session.read.parquet("%s.parquet" % f).select(
            F.lit(f).alias("table"),
            "customerID"
        )
    )
all_customers = reduce(lambda l, r: l.union(r), output_dfs)
each_table = all_customers.groupBy("table").agg(F.approx_count_distinct("customerID").alias("approx_unique_customers"))
overall = all_customers.groupBy(F.lit("all").alias("table")).agg(F.approx_count_distinct("customerID").alias("approx_unique_customers"))
each_table.union(overall).show()
rows = each_table.union(overall).collect()
{row[0]: row[1] for row in rows}