Predicting user churn with PySpark

Tamuno-omi Jaja
7 min read · Apr 1, 2021


Customers are the lifeblood of every business; irrespective of the industry, they determine the business's size and its future direction. The cost of acquiring new customers is usually higher than that of keeping existing ones, so it is necessary to keep the churn rate, the rate at which customers discontinue a service within a period, to a minimum. Churn can be reduced by identifying at-risk customer groups and offering them incentives such as discounts and special offers. The code for the project can be found here.

In this post I look at predicting user churn using PySpark, working through data wrangling, exploration, and feature engineering, and finally modelling with logistic regression, a support vector machine, a random forest classifier, and gradient-boosted trees.

The dataset was provided by Udacity. It contains simulated user logs from 'Sparkify', a music streaming service similar to Spotify or Audiomack. The service has both free and paid tiers, and users can upgrade, downgrade, or cancel at any time. As users interact with the service they generate data, which I will use to gain insights; in particular, I will identify users who are at risk of churning, i.e. cancelling their service.

A look at the data

The medium subset of the data contains 351,033 user events. I checked for missing values in the userId and sessionId columns, discovered and removed records where the user wasn't logged in, and then extracted additional information from existing columns (a minimal sketch follows the list), such as:

  • date, year, month, weekday, number of days since registration from the ts and registration columns.
  • device/os and browser from userAgent.
  • Total number of songs listened from song column.
  • State of user from the Location.
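
To make this concrete, here is a minimal sketch of the cleaning and extraction step, assuming the log has been loaded into a DataFrame named df:

from pyspark.sql import functions as F

# drop events with no logged-in user (null or empty userId)
df = df.filter(F.col('userId').isNotNull() & (F.col('userId') != ''))

# ts and registration are in milliseconds; derive calendar columns from ts
df = df.withColumn('date', F.to_date(F.from_unixtime((F.col('ts') / 1000).cast('long')))) \
    .withColumn('year', F.year('date')) \
    .withColumn('month', F.month('date')) \
    .withColumn('weekday', F.date_format('date', 'E')) \
    .withColumn('daysSinceReg', ((F.col('ts') - F.col('registration')) / (1000 * 86400)).cast('int'))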

Exploratory Data Analysis

Here I explored the data to compare the behavior of users who stayed with that of users who churned. I defined churn using the Cancellation Confirmation page type, as sketched below.
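
A minimal sketch of this definition, assuming the cleaned DataFrame from above: flag the Cancellation Confirmation event, then propagate the flag to every record belonging to that user.

from pyspark.sql import Window
from pyspark.sql import functions as F

# mark the cancellation event itself
df = df.withColumn('churn_event', F.when(F.col('page') == 'Cancellation Confirmation', 1).otherwise(0))
# propagate the flag to all of a user's rows
user_window = Window.partitionBy('userId')
df = df.withColumn('churn', F.max('churn_event').over(user_window))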

Churn by gender

Chart showing churn by gender.

It seems male customers are more likely to churn.
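
For reference, the counts behind a chart like this can be produced with a simple aggregation over distinct users (a sketch, assuming the churn flag defined above):

# one row per user, then count churners and non-churners by gender
df.select('userId', 'gender', 'churn').dropDuplicates(['userId']) \
    .groupBy('gender', 'churn').count().show()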

Churn by State

Chart showing churn by state

Certain states have a higher churn rate; this can be due to varying factors.

Churn by browser

chart showing churn by browser

Certain browsers show a higher churn rate, which may be due to compatibility issues, among other factors.

Churn by Level

Churn by tier

Users seem more likely to churn while on the paid tier.

Feature Engineering

Now that we have familiarized ourselves with the data, we select features that look promising for training our model (a sketch of computing several of these follows the list):

1) daysSinceReg

2) totalSongs

3) Number of Thumbs up

4) Number of Thumbs down

5) Number of songs added to playlist

6) Number of friends added

7) Listening time

8) Average songs listened per session

9) Gender

10) Number of artists listened to per user
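
A minimal sketch of how several of these per-user features can be aggregated, assuming the usual Sparkify page names and that length is the song-duration column:

from pyspark.sql import functions as F

# per-user aggregates for a subset of the listed features
features = df.groupBy('userId').agg(
    F.max('daysSinceReg').alias('daysSinceReg'),
    F.count(F.when(F.col('page') == 'NextSong', 1)).alias('totalSongs'),
    F.count(F.when(F.col('page') == 'Thumbs Up', 1)).alias('numThumbsUp'),
    F.count(F.when(F.col('page') == 'Thumbs Down', 1)).alias('numThumbsDown'),
    F.count(F.when(F.col('page') == 'Add to Playlist', 1)).alias('numAddPlaylist'),
    F.count(F.when(F.col('page') == 'Add Friend', 1)).alias('numAddFriend'),
    F.sum(F.when(F.col('page') == 'NextSong', F.col('length'))).alias('listenTime'),
    F.countDistinct('artist').alias('numArtists'),
    F.first('gender').alias('gender'),
    F.max('churn').alias('label'))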

After identifying the columns, we ensure that they are all of numeric data type. The gender and label (churn) columns had to be converted into numeric values using an encoder.
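
One way to do this is with a StringIndexer (a sketch; genderIdx is an illustrative name):

from pyspark.ml.feature import StringIndexer

# map the gender strings to numeric indices and drop the original column
indexer = StringIndexer(inputCol='gender', outputCol='genderIdx')
features = indexer.fit(features).transform(features).drop('gender')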

Modelling

For this project, I used Logistic Regression, Gradient-Boosted Trees, a Support Vector Machine, and Random Forest as the machine learning algorithms.

Now that our feature columns are in numeric format and our algorithms are chosen, here are the steps that were taken:

  • Normalizing and scaling: using VectorAssembler and StandardScaler (see the sketch after this list).
  • Train, test, split: the data is split into three parts, namely the train, test, and validation sets, using a seed of 43.
  • Base model: because the dataset is imbalanced (churners are far fewer than non-churners), we create two naive baseline models, one always predicting churn = 1 (churned) and one always predicting churn = 0 (non-churned), to compare against.
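
A minimal sketch of the assembly, scaling, and split steps; the column selection and split proportions here are illustrative, and features is the per-user table from the earlier sketches:

from pyspark.ml.feature import VectorAssembler, StandardScaler

# assemble the numeric feature columns into a single vector, then standardize it
numeric_cols = [c for c in features.columns if c not in ('userId', 'label')]
assembler = VectorAssembler(inputCols=numeric_cols, outputCol='rawFeatures')
assembled = assembler.transform(features)
scaler = StandardScaler(inputCol='rawFeatures', outputCol='features', withStd=True)
scaled = scaler.fit(assembled).transform(assembled)

# three-way split with the seed used in the project (proportions assumed)
train, rest = scaled.randomSplit([0.6, 0.4], seed=43)
test, validation = rest.randomSplit([0.5, 0.5], seed=43)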
Final Models:

1) Gradient Boosted Tree:

Using 10 iterations and a seed of 42, we use MulticlassClassificationEvaluator() with "f1" as the metric. ParamGridBuilder() uses its default parameters here, while CrossValidator() picks the best model based on the F1 metric. After fitting the model and running it on the validation set, the model had an accuracy of 0.65 and an F1 score of 0.69, taking 849.36 seconds to complete.

from time import time
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# initialize classifier
gbt = GBTClassifier(maxIter=10, seed=42)
# set evaluator
f1_evaluator = MulticlassClassificationEvaluator(metricName='f1')
# build paramGrid (default parameters only)
paramGrid = ParamGridBuilder().build()
crossval_gbt = CrossValidator(estimator=gbt,
                              estimatorParamMaps=paramGrid,
                              evaluator=f1_evaluator,
                              numFolds=3)
start = time()
cvModel_gbt = crossval_gbt.fit(train)
end = time()
print('The training process took {} seconds'.format(end - start))
# evaluate on the validation set
results_gbt = cvModel_gbt.transform(validation)
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
print('Gradient Boosted Trees Metrics:')
print('Accuracy: {}'.format(evaluator.evaluate(results_gbt, {evaluator.metricName: "accuracy"})))
print('F-1 Score: {}'.format(evaluator.evaluate(results_gbt, {evaluator.metricName: "f1"})))

Gradient Boosted Trees Metrics:
Accuracy: 0.6470588235294118
F-1 Score: 0.6932773109243697

2) Logistic Regression:

The logistic regression model had an accuracy of 0.88 and an F1 score of 0.83, taking 378.90 seconds to complete.

from pyspark.ml.classification import LogisticRegression

# initialize classifier
lr = LogisticRegression(maxIter=10)
# set evaluator
f1_evaluator = MulticlassClassificationEvaluator(metricName='f1')
# build paramGrid (default parameters only)
paramGrid = ParamGridBuilder().build()
crossval_lr = CrossValidator(estimator=lr,
                             estimatorParamMaps=paramGrid,
                             evaluator=f1_evaluator,
                             numFolds=3)
start = time()
cvModel_lr = crossval_lr.fit(train)
end = time()
print('The training process took {} seconds'.format(end - start))
# evaluate on the validation set
results_lr = cvModel_lr.transform(validation)
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
print('Logistic Regression Metrics:')
print('Accuracy: {}'.format(evaluator.evaluate(results_lr, {evaluator.metricName: "accuracy"})))
print('F-1 Score: {}'.format(evaluator.evaluate(results_lr, {evaluator.metricName: "f1"})))

Logistic Regression Metrics:
Accuracy: 0.8823529411764706
F-1 Score: 0.8272058823529411

3) Support Vector Machine: the SVM had an accuracy of 0.88 and an F1 score of 0.827.

from pyspark.ml.classification import LinearSVC

# initialize classifier
svm = LinearSVC(maxIter=10)
# set evaluator
f1_evaluator = MulticlassClassificationEvaluator(metricName='f1')
# build paramGrid (default parameters only)
paramGrid = ParamGridBuilder().build()
crossval_svm = CrossValidator(estimator=svm,
                              estimatorParamMaps=paramGrid,
                              evaluator=f1_evaluator,
                              numFolds=3)
start = time()
cvModel_svm = crossval_svm.fit(train)
end = time()
print('The training process took {} seconds'.format(end - start))
# evaluate on the validation set
results_svm = cvModel_svm.transform(validation)
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
print('SVM Metrics:')
print('Accuracy: {}'.format(evaluator.evaluate(results_svm, {evaluator.metricName: "accuracy"})))
print('F-1 Score: {}'.format(evaluator.evaluate(results_svm, {evaluator.metricName: "f1"})))

SVM Metrics:
Accuracy: 0.8823529411764706
F-1 Score: 0.8272058823529411

4) Random Forest: the random forest model had an accuracy of 0.88 and an F1 score of 0.827.

from pyspark.ml.classification import RandomForestClassifier

# initialize classifier
rf = RandomForestClassifier()
# set evaluator
f1_evaluator = MulticlassClassificationEvaluator(metricName='f1')
# build paramGrid (default parameters only)
paramGrid = ParamGridBuilder().build()
crossval_rf = CrossValidator(estimator=rf,
                             estimatorParamMaps=paramGrid,
                             evaluator=f1_evaluator,
                             numFolds=3)
start = time()
cvModel_rf = crossval_rf.fit(train)
end = time()
print('The training process took {} seconds'.format(end - start))
# evaluate on the validation set
results_rf = cvModel_rf.transform(validation)
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
print('Random Forest Metrics:')
print('Accuracy: {}'.format(evaluator.evaluate(results_rf, {evaluator.metricName: "accuracy"})))
print('F-1 Score: {}'.format(evaluator.evaluate(results_rf, {evaluator.metricName: "f1"})))

Random Forest Metrics:
Accuracy: 0.8823529411764706
F-1 Score: 0.8272058823529411

I selected the Random Forest model as the final model because of its accuracy, and this time conducted a grid search to fine-tune it.

In the future, we may instead use the logistic regression model for better time efficiency.

  • Hyperparameter Tuning:

We used cross-validation to tune the hyperparameters of the model and reached 83.91% accuracy and an F1 score of 0.77 on the test set. It is interesting to note that we could not outperform the accuracy and F1 score obtained with the default parameters, probably due to the small size of the dataset.

# grid search over a few random forest hyperparameters
rf = RandomForestClassifier()
rf_paramGrid = ParamGridBuilder() \
    .addGrid(rf.minInfoGain, [0, 1]) \
    .addGrid(rf.numTrees, [20, 50]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()
f1_evaluator = MulticlassClassificationEvaluator(metricName='f1')
crossval_rf = CrossValidator(estimator=rf,
                             estimatorParamMaps=rf_paramGrid,
                             evaluator=f1_evaluator,
                             numFolds=3)
cvModel_rf_best = crossval_rf.fit(train)
# refit with the chosen parameters and evaluate on the held-out test set
rf_best = RandomForestClassifier(numTrees=10, minInfoGain=1, maxDepth=5)
rf_best_model = rf_best.fit(train)
results_final = rf_best_model.transform(test)
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
print('Test set metrics:')
print('Accuracy: {}'.format(evaluator.evaluate(results_final, {evaluator.metricName: "accuracy"})))
print('F-1 Score: {}'.format(evaluator.evaluate(results_final, {evaluator.metricName: "f1"})))

Test set metrics:
Accuracy: 0.8390804597701149
F-1 Score: 0.7656609195402299

Feature Importance in Random Forest:

It is also insightful to visualize which features matter most in predicting churn. Looking at the feature importances, we see that lifetime (days since registration), thumbs up/down, and add-friend events are important predictors of churn. As expected, a combination of behavioral and more static features helps us predict churn.
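
A minimal sketch of extracting these importances from the fitted model, assuming rf_best_model and the numeric_cols list from the earlier sketches:

import pandas as pd

# pair each feature name with its importance score from the random forest
importances = pd.Series(rf_best_model.featureImportances.toArray(), index=numeric_cols)
print(importances.sort_values(ascending=False))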

Conclusion:

We implemented a model to predict customer churn on a music streaming service. After loading the data, we performed cleaning tasks such as removing rows without a user id, and generated additional columns from existing ones, e.g. converting the timestamp to get date, year, month, weekday, and number of days since registration. We then defined churn using the 'Cancellation Confirmation' page and explored its relationship to other variables, after which we selected features we found promising for our model. We trained logistic regression, GBT, SVM, and random forest classification models, and selected the random forest as our final model for predicting the result.

Improvement

The features could be improved considerably by considering more factors and adding more domain knowledge and expertise. Although the volume of data may require tools such as Spark to analyze, more data should yield better results: as the user base grows and the sample size increases, the model has great potential to improve, and its expected performance should increase as well. The classification models could also be further improved in the hyperparameter tuning process with extended parameter grids that search a broader range of parameter combinations.

Feature Importance:

We utilized the feature importance attribute of the Random Forest model and can observe that the length of time a user has been on the service plays a very important role.
