Apache Spark Tutorial: ML with PySpark

Apache Spark tutorial introduces you to big data processing, analysis and ML with PySpark.
Jul 2017  · 34 min read

Apache Spark and Python for Big Data and Machine Learning

Apache Spark is known as a fast, easy-to-use and general engine for big data processing that has built-in modules for streaming, SQL, Machine Learning (ML) and graph processing. This technology is an in-demand skill for data engineers, but also data scientists can benefit from learning Spark when doing Exploratory Data Analysis (EDA), feature extraction and, of course, ML.

In this tutorial, you’ll interface Spark with Python through PySpark, the Spark Python API that exposes the Spark programming model to Python. More concretely, you’ll focus on:

If you’d rather use Spark with R, check out DataCamp’s free Introduction to Spark in R with sparklyr. You can also download the PySpark SQL cheat sheet.

Installing Apache Spark

Installing Spark and getting it to work can be a challenge. In this section, you’ll cover some steps that will show you how to get it installed on your pc.

The first thing that you want to do is check whether you meet the prerequisites. Spark is written in the Scala programming language and runs in a Java Virtual Machine (JVM) environment. That’s why you need to check whether you have a Java Development Kit (JDK) installed: the JDK provides you with one or more implementations of the JVM. Preferably, you want to pick the latest one, which, at the time of writing, is JDK 8.
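
If you would like to run this prerequisite check from Python, a minimal sketch could look like the one below. It only looks for a java executable on your PATH and prints whatever version report it gives; the helper name java_available is made up for this illustration, and finding java on the PATH does not by itself prove that a full JDK is installed.

```python
import shutil
import subprocess


def java_available():
    """Return True if a `java` executable can be found on the PATH."""
    return shutil.which("java") is not None


if java_available():
    # `java -version` writes its report to stderr rather than stdout
    report = subprocess.run(["java", "-version"], capture_output=True, text=True)
    print((report.stderr or report.stdout).strip())
else:
    print("No `java` executable found on the PATH; install a JDK first.")
```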

Next, you’re ready to download Spark!

Downloading pyspark with pip

Then, you can download and install PySpark with the help of pip. This is fairly easy and much like installing any other package. You just run the usual command and the heavy lifting gets done for you:

$ pip install pyspark

Alternatively, you can also go to the Spark download page. Keep the default options in the first three steps and you’ll find a downloadable link in step 4. Click on that link to download it. For this tutorial, you’ll download the 2.2.0 Spark Release and the “Pre-built for Apache Hadoop 2.7 and later” package type.

Note that the download can take some time to finish!

Downloading Spark with Homebrew

You can also install Spark with Homebrew, a free and open-source package manager. This is especially handy if you’re working with macOS.

Simply run the following commands to search for Spark, to get more information and to finally install it on your personal computer:

# Search for spark
$ brew search spark

# Get more information on apache-spark
$ brew info apache-spark

# Install apache-spark
$ brew install apache-spark

Download and Set Up Spark

Next, make sure that you untar the directory that appears in your Downloads folder. This can happen automatically for you, by double clicking the spark-2.2.0-bin-hadoop2.7.tgz archive or by opening up your Terminal and running the following command:

$ tar xvf spark-2.2.0-bin-hadoop2.7.tgz

Next, move the untarred folder to /usr/local/spark by running the following line:

$ mv spark-2.2.0-bin-hadoop2.7 /usr/local/spark

Note that if you get an error that says that permission is denied to move this folder to the new location, you should add sudo in front of this command. The line above will then become $ sudo mv spark-2.2.0-bin-hadoop2.7 /usr/local/spark. You’ll be prompted for your password, which is usually the one that you also use to unlock your pc when you start it up :)

Now that you’re all set to go, open the README file in the file path /usr/local/spark. You can do this by executing

$ cd /usr/local/spark

This brings you to the folder that you need to be in. Then, you can start inspecting the folder and reading the README file that is included in it.

First, use $ ls to get a list of the files and folders that are in this spark folder. You’ll see that there’s a README.md file in there. You can open it by executing one of the following commands:

# Open and edit the file
$ nano README.md

# Just read the file
$ cat README.md

Tip: use the Tab key on your keyboard to autocomplete as you’re typing the file name :) This will save you some time.

You’ll see that this README provides you with some general information about Spark, online documentation, building Spark, the Interactive Scala and Python shells, example programs and much more.

The thing that could interest you most here is the section on how to build Spark, but note that this will only be relevant if you haven’t downloaded a pre-built version. For this tutorial, however, you downloaded a pre-built version. If you opened the README with nano, you can press CTRL + X to exit it, which brings you back to the spark folder.

In case you selected a version that hasn’t been built yet, make sure you run the command that is listed in the README file. At the time of writing, this is the following:

$ build/mvn -DskipTests clean package

Note that this command can take a while to run.

PySpark Basics: RDDs

Now that you’ve successfully installed Spark and PySpark, let’s first start off by exploring the interactive Spark Shell and by nailing down some of the basics that you will need when you want to get started. In the rest of this tutorial, however, you’ll work with PySpark in a Jupyter notebook.

Spark Applications Versus Spark Shell

The interactive shell is an example of a Read-Eval(uate)-Print-Loop (REPL) environment. That means that whatever you type in is read, evaluated and printed out to you so that you can continue your analysis. This might remind you of IPython, which is a powerful interactive Python shell that you might know from working with Jupyter. If you want to know more, consider reading DataCamp’s IPython or Jupyter blog post.

This means that you can use the shell, which is available for Python as well as Scala, for all interactive work that you need to do.

Besides this shell, you can also write and deploy Spark applications. In the shell, in contrast to when you write a Spark application, the SparkSession has already been created for you so that you can just start working and not waste valuable time on creating one.

Now you might wonder: what is the SparkSession?

Well, it’s the main entry point for Spark functionality: it represents the connection to a Spark cluster and you can use it to create RDDs and to broadcast variables on that cluster. When you’re working with Spark, everything starts and ends with this SparkSession. Note that before Spark 2.0.0, the three main connection objects were SparkContext, SQLContext and HiveContext.

You’ll see more on this later on. For now, let’s just focus on the shell.

The Python Spark Shell

From within the spark folder located at /usr/local/spark, you can run

$ ./bin/pyspark

At first, you’ll see some text appearing. And then, you’ll see “Spark” appearing, just like this:

Python 2.7.13 (v2.7.13:a06454b1afa1, Dec 17 2016, 12:39:47) 
[GCC 4.2.1 (Apple Inc. build 5666) (dot 3)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
Using Spark's default log4j profile: org/apache/spark/
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
17/07/26 11:41:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
17/07/26 11:41:47 WARN ObjectStore: Failed to get database global_temp, returning NoSuchObjectException
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 2.2.0

Using Python version 2.7.13 (v2.7.13:a06454b1afa1, Dec 17 2016 12:39:47)
SparkSession available as 'spark'.

When you see this, you know that you’re ready to start experimenting within the interactive shell!

Tip: if you prefer using the IPython shell instead of the Spark shell, you can do this by setting the following environment variable:

export PYSPARK_DRIVER_PYTHON="/usr/local/ipython/bin/ipython"

Creating RDDs

Now, let’s start small and make an RDD, which is the most basic building block of Spark. An RDD simply represents data but it’s not one object, a collection of records, a result set or a data set. That is because it’s intended for data that resides on multiple computers: a single RDD could be spread over thousands of Java Virtual Machines (JVMs), because Spark automatically partitions the data under the hood to get this parallelism. Of course, you can adjust the parallelism to get more partitions. That’s why an RDD is actually a collection of partitions.

You can easily create a simple RDD by using the parallelize() function and by simply passing some data (an iterable, like a list, or a collection) to it:

>>> rdd1 = spark.sparkContext.parallelize([('a',7),('a',2),('b',2)])
>>> rdd2 = spark.sparkContext.parallelize([("a",["x","y","z"]), ("b",["p", "r"])])
>>> rdd3 = spark.sparkContext.parallelize(range(100))

Note that the SparkSession object has the SparkContext object, which you can access with spark.sparkContext. For backwards compatibility reasons, it’s also still possible to call the SparkContext with sc, as in rdd1 = sc.parallelize([('a',7),('a',2),('b',2)]).
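
To build some intuition for what “a collection of partitions” means, here is a plain Python sketch that cuts the same data as rdd3 into contiguous slices. It needs no Spark installation, the helper split_into_partitions is made up for this illustration, and it is a simplified stand-in rather than Spark’s actual partitioning code.

```python
def split_into_partitions(data, num_partitions):
    """Cut `data` into `num_partitions` contiguous, roughly equal slices."""
    items = list(data)
    n = len(items)
    return [
        items[i * n // num_partitions:(i + 1) * n // num_partitions]
        for i in range(num_partitions)
    ]


# Four hypothetical partitions for the same data as rdd3
partitions = split_into_partitions(range(100), 4)

# Print how many elements ended up in each partition
print([len(p) for p in partitions])
```

In PySpark itself, parallelize() accepts an optional numSlices argument to control the number of partitions, and getNumPartitions() tells you how many partitions an RDD has.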

RDD Operations

Now that you have created the RDDs, you can operate on the distributed data in rdd1 and rdd2 in parallel. You have two types of operations: transformations and actions.

To get an intuitive feel for the difference between these two, consider that some of the most common transformations are map(), filter(), flatMap(), sample(), randomSplit(), coalesce() and repartition(), while some of the most common actions are reduce(), collect(), first(), take(), count() and saveAsHadoopFile().

Transformations are lazy operations on an RDD that create one or many new RDDs, while actions produce non-RDD values: they return a result set, a number, a file, …
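
The difference between a lazy transformation and an action can be mimicked in plain Python, where map() also returns a lazy iterator. The sketch below is only an analogy and does not use Spark; the names double and evaluated are made up for this illustration.

```python
evaluated = []


def double(x):
    """Record each call so you can see when the work really happens."""
    evaluated.append(x)
    return x * 2


# "Transformation": builds a lazy iterator, nothing is computed yet
lazy_result = map(double, [1, 2, 3])
calls_before_action = list(evaluated)

# "Action": consuming the iterator finally triggers the computation
result = list(lazy_result)

print(calls_before_action)
print(result)
```

The list of recorded calls stays empty until the iterator is consumed, which is the same idea as Spark postponing all work until you call an action.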

You can, for example, aggregate all the elements of rdd1 using the following, simple lambda function and return the results to the driver program:

>>> rdd1.reduce(lambda a,b: a+b)

Executing this line of code will give you the following result: ('a', 7, 'a', 2, 'b', 2). Note that reduce() is an action. An example of a transformation is flatMapValues(), which you run on key-value pair RDDs, such as rdd2. In this case, you pass each value in the key-value pair RDD rdd2 through a flatMap function, which is the lambda function defined below, without changing the keys. You perform an action after that by collecting the results with collect().

>>> rdd2.flatMapValues(lambda x: x).collect()
[('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
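
If you wonder why these two operations give those results, you can reproduce them with ordinary Python lists. This is a sketch without Spark: functools.reduce stands in for the RDD reduce() and a list comprehension stands in for flatMapValues().

```python
from functools import reduce

# Same data as rdd1 and rdd2, held in ordinary lists
pairs = [('a', 7), ('a', 2), ('b', 2)]
keyed_lists = [("a", ["x", "y", "z"]), ("b", ["p", "r"])]

# a + b on tuples concatenates them, which is why reduce() yields one long tuple
combined = reduce(lambda a, b: a + b, pairs)

# Emit one (key, value) pair per element of each value list, keys unchanged
flattened = [(key, value) for key, values in keyed_lists for value in values]

print(combined)
print(flattened)
```

Keep in mind that Spark expects the function that you pass to reduce() to be commutative and associative. Tuple concatenation is not commutative, so on data that is spread over several partitions the order of the elements in the result may differ.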

The Data

Now that you have covered some basics with the interactive shell, it’s time to get started with some real data. For this tutorial, you’ll make use of the California Housing data set. Note, of course, that this is actually ‘small’ data and that using Spark in this context might be overkill. This tutorial is for educational purposes only and is meant to give you an idea of how you can use PySpark to build a machine learning model.

Loading and Exploring your Data

Even though you know a bit more about your data, you should take the time to explore it more thoroughly. Before you do this, however, you will set up your Jupyter Notebook with Spark and take some first steps towards defining the SparkContext.

PySpark in Jupyter Notebook

For this part of the tutorial, you won’t use the interactive shell but you’ll build your own application. You’ll do this in a Jupyter Notebook. You already have all the things that you need installed, so you don’t need to do much to get PySpark to work in Jupyter.

You can just launch the notebook application the same way you always do, by running $ jupyter notebook. Then, you make a new notebook and you simply import the findspark library and use the init() function. In this case, you’re going to supply the path /usr/local/spark to init() because you’re certain that this is the path where you installed Spark.

# Import findspark
import findspark

# Initialize and provide path
findspark.init("/usr/local/spark")

# Or use this alternative
# findspark.init()

Tip: if you have no idea whether your path is set correctly or where you have installed Spark on your pc, you can always use findspark.find() to automatically detect the location of where Spark is installed.

If you’re looking for alternative ways to work with Spark in Jupyter, consult our Apache Spark in Python: Beginner’s Guide.

Now that you have got all of that settled, you can finally start by creating your first Spark program!

Creating your First Spark Program

The first thing that you want to do is create and initialize a SparkSession, which also gives you access to the SparkContext. Remember that you didn’t have to do this before because the interactive Spark shell automatically created and initialized it for you! Here, you’ll need to do a little bit more work yourself :)

Import the SparkSession class from pyspark.sql and build a SparkSession with the builder attribute. Afterwards, you can set the master URL to connect to, set the application name, add some additional configuration like the executor memory and then, lastly, use getOrCreate() to either get the current Spark session or to create one if there is none running.

# Import SparkSession
from pyspark.sql import SparkSession

# Build the SparkSession
spark = SparkSession.builder \
   .master("local") \
   .appName("Linear Regression Model") \
   .config("spark.executor.memory", "1gb") \
   .getOrCreate()

# Get the SparkContext from the session
sc = spark.sparkContext

Note that if you get an error where there’s a FileNotFoundError similar to this one: “No such file or directory: ‘/User/YourName/Downloads/spark-2.1.0-bin-hadoop2.7/./bin/spark-submit’”, you know that you have to (re)set your Spark PATH. Go to your home directory by executing $ cd and then edit the .bash_profile file by running $ nano .bash_profile.

Add something like the following to the bottom of the file

export SPARK_HOME="/usr/local/spark"

Use CTRL + X to exit the file, but make sure to save your adjustments by also entering Y to confirm the changes. Next, don’t forget to set the changes in motion by running $ source .bash_profile.
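
To double check from Python that the variable is picked up, you could use a small sketch like the following. The helper spark_home_status is made up for this illustration; it only verifies that SPARK_HOME is set and that a bin/spark-submit file exists below it, which is the file that the error message above complains about.

```python
import os


def spark_home_status(env=None):
    """Describe whether SPARK_HOME points at a usable Spark folder."""
    env = os.environ if env is None else env
    spark_home = env.get("SPARK_HOME")
    if not spark_home:
        return "SPARK_HOME is not set"
    submit = os.path.join(spark_home, "bin", "spark-submit")
    if os.path.isfile(submit):
        return "spark-submit found at " + submit
    return "SPARK_HOME is set, but " + submit + " does not exist"


print(spark_home_status())
```

Note that a notebook server that was already running before you edited .bash_profile will not see the new value until you restart it from a fresh terminal.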

Tip: you can also set additional environment variables if you want. You probably don’t need them, but it’s definitely good to know that you can set them if desired. Consider the following examples:

# Set a fixed value for the hash seed secret
export PYTHONHASHSEED=0

# Set an alternate Python executable
export PYSPARK_PYTHON=/usr/local/ipython/bin/ipython

# Augment the default search path for shared libraries
export LD_LIBRARY_PATH=/usr/local/ipython/bin/ipython

# Augment the default search path for private libraries 
export PYTHONPATH=$SPARK_HOME/python/lib/py4j-*-src.zip:$PYTHONPATH:$SPARK_HOME/python/

Note also that now you have initialized a default SparkSession. However, in most cases, you’ll want to configure this further. You’ll see that this will be really needed when you start working with big data. If you want to know more about it, check this page.

Loading in your Data

This tutorial makes use of the California Housing data set. It appeared in a 1997 paper titled Sparse Spatial Autoregressions, written by R. Kelley Pace and Ronald Barry and published in the Statistics and Probability Letters journal. The researchers built this data set by using the 1990 California census data.

The data contains one row per census block group. A block group is the smallest geographical unit for which the U.S. Census Bureau publishes sample data (a block group typically has a population of 600 to 3,000 people). In this sample a block group on average includes 1425.5 individuals living in a geographically compact area. You’ll gather this information from this web page or by reading the paper which was mentioned above and which you can find here.

These spatial data contain 20,640 observations on housing prices with 9 economic variables:

  • Longitude refers to the angular distance of a geographic place east or west of the prime meridian for each block group;
  • Latitude refers to the angular distance of a geographic place north or south of the earth’s equator for each block group;
  • Housing median age is the median age of the houses in a block group. Note that the median is the value that lies at the midpoint of a frequency distribution of observed values;
  • Total rooms is the total number of rooms in the houses per block group;
  • Total bedrooms is the total number of bedrooms in the houses per block group;
  • Population is the number of inhabitants of a block group;
  • Households refers to units of houses and their occupants per block group;
  • Median income is used to register the median income of people that belong to a block group; and
  • Median house value is the dependent variable and refers to the median house value per block group.

What’s more, you also learn that all the block groups that have zero entries for the independent and dependent variables have been excluded from the data.

The Median house value is the dependent variable and will be assigned the role of the target variable in your ML model.

You can download the data here. Look for the folder, download and untar it so that you can access the data folders.

Next, you’ll use the textFile() method to read the data from the folder that you downloaded it to into RDDs. This method takes a URI for the file, which is in this case a local path on your machine, and reads it as a collection of lines. For convenience, you’ll not only read in the .data file, but also the .domain file that contains the header. This will allow you to double check the order of the variables.

# Load in the data
rdd = sc.textFile('/Users/yourName/Downloads/CaliforniaHousing/')

# Load in the header
header = sc.textFile('/Users/yourName/Downloads/CaliforniaHousing/cal_housing.domain')

Data Exploration

You already gathered a lot of information by just looking at the web page where you found the data set, but it’s always better to get hands-on and inspect your data with the help of Spark with Python, in this case.

Important to understand here is that, because Spark’s execution is “lazy”, nothing has been executed yet. Your data hasn’t actually been read in. The rdd and header variables are actually just concepts in your mind. You have to push Spark to work for you, so let’s use the collect() method to look at the header:

header.collect()

The collect() method brings the entire RDD to a single machine, and you’ll get to see the following result:

[u'longitude: continuous.', u'latitude: continuous.', u'housingMedianAge: continuous. ', u'totalRooms: continuous. ', u'totalBedrooms: continuous. ', u'population: continuous. ', u'households: continuous. ', u'medianIncome: continuous. ', u'medianHouseValue: continuous. ']

Tip: be careful when using collect()! Running this line of code can possibly cause the driver to run out of memory. That’s why the following approach with the take() method is a safer approach if you want to just print a few elements of the RDD. In general, it’s a good principle to limit your result set whenever possible, just like when you’re using SQL.

You learn that the order of the variables is the same as the one that you saw above in the presentation of the data set, and you also learn that all columns should have continuous values. Let’s force Spark to do some more work and take a look at the California housing data to confirm this.

Call the take() method on your RDD:

rdd.take(2)

By executing the previous line of code, you take the first 2 elements of the RDD. The result is as you expected: because you read in the files with the textFile() function, the lines are just all read in together. The entries are separated by a single comma and the rows themselves are also separated by a comma:

[u'-122.230000,37.880000,41.000000,880.000000,129.000000,322.000000,126.000000,8.325200,452600.000000', u'-122.220000,37.860000,21.000000,7099.000000,1106.000000,2401.000000,1138.000000,8.301400,358500.000000']

You definitely need to solve this. You need to make sure that the entries within each row of your data become separate elements. To do this, you’ll use the map() function, to which you pass a lambda function that splits each line at the commas. Then, check your result by running the same line with the take() method, just like you did before.

Remember that lambda functions are anonymous functions which are created at runtime.

# Split lines on commas
rdd = rdd.map(lambda line: line.split(","))

# Inspect the first 2 lines
rdd.take(2)

You’ll get the following result:

[[u'-122.230000', u'37.880000', u'41.000000', u'880.000000', u'129.000000', u'322.000000', u'126.000000', u'8.325200', u'452600.000000'], [u'-122.220000', u'37.860000', u'21.000000', u'7099.000000', u'1106.000000', u'2401.000000', u'1138.000000', u'8.301400', u'358500.000000']]
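
To see how the header and a data line fit together, the following plain Python sketch pairs the variable names from the header with the values of the first line and converts them to floats. It works on the strings shown above and needs no Spark; the names header_lines, raw_line and record are made up for this illustration.

```python
# Header lines as shown in the header output earlier in this tutorial
header_lines = [
    'longitude: continuous.', 'latitude: continuous.',
    'housingMedianAge: continuous. ', 'totalRooms: continuous. ',
    'totalBedrooms: continuous. ', 'population: continuous. ',
    'households: continuous. ', 'medianIncome: continuous. ',
    'medianHouseValue: continuous. ',
]

# First raw line of the data file, before it is split
raw_line = ('-122.230000,37.880000,41.000000,880.000000,129.000000,'
            '322.000000,126.000000,8.325200,452600.000000')

# Keep only the variable name that precedes the colon
names = [line.split(":")[0] for line in header_lines]

# Split on commas, convert each field to float and pair it with its name
record = dict(zip(names, (float(field) for field in raw_line.split(","))))

print(record["medianHouseValue"])
```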

Alternatively, you can also use the following functions to inspect your data:

# Inspect the first line
rdd.first()

# Take top elements
rdd.top(2)

If you’re used to working with Pandas or data frames in R, you’ll have probably also expected to see a header, but there is none. To make your life easier, you will move on from the RDD and convert it to a DataFrame. DataFrames are preferred over RDDs whenever you can use them. Especially when you’re working with Python, the performance of DataFrames is better than that of RDDs.

But what is the difference between the two?

You can use RDDs when you want to perform low-level transformations and actions on your unstructured data. This means that you don’t care about imposing a schema while processing or accessing the attributes by name or column. Tying in to what was said before about performance, by using RDDs, you don’t necessarily want the performance benefits that DataFrames can offer for (semi-) structured data. Use RDDs when you want to manipulate the data with functional programming constructs rather than domain specific expressions.

To recapitulate, you’ll switch to DataFrames now to use high-level expressions, to perform SQL queries to explore your data further and to gain columnar access.

So let’s do this.

The first step is to make a SchemaRDD or an RDD of Row objects with a schema. This is normal, because just like a DataFrame, you eventually want to come to a situation where you have rows and columns. Each entry is linked to a row and a certain column and columns have data types.

You’ll use the map() function again and another lambda function in which you’ll map each entry to a field in a Row. To make this more visual, consider this first line:

[u'-122.230000', u'37.880000', u'41.000000', u'880.000000', u'129.000000', u'322.000000', u'126.000000', u'8.325200', u'452600.000000']

The lambda function says that you’re going to construct a row in a SchemaRDD and that the element at index 0 will have the name “longitude”, and so on.

With this SchemaRDD in place, you can easily convert the RDD to a DataFrame with the toDF() method.

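# Import the necessary modules
from pyspark.sql import Row

# Map each split line to a Row, then convert the RDD to a DataFrame
df = rdd.map(lambda line: Row(longitude=line[0], latitude=line[1],
                              housingMedianAge=line[2], totalRooms=line[3],
                              totalBedRooms=line[4], population=line[5],
                              households=line[6], medianIncome=line[7],
                              medianHouseValue=line[8])).toDF()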

Now that you have your DataFrame df, you can inspect it with the methods that you have also used before, namely first() and take(), but also with head() and show():

# Show the top 20 rows
df.show()

You’ll immediately see that this looks much different from the RDD that you were working with before:

Tip: use df.columns to return the columns of your DataFrame.

The data seems all nicely ordered into columns, but what about the data types? By reading in your data, Spark will try to infer a schema, but has this been successful here? Use either df.dtypes or df.printSchema() to get to know more about the data types that are contained within your DataFrame.

# Print the data types of all `df` columns
# df.dtypes

# Print the schema of `df`
df.printSchema()

Because you don’t execute the first line of code, you will only get back the following result:

 |-- households: string (nullable = true)
 |-- housingMedianAge: string (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)
 |-- medianHouseValue: string (nullable = true)
 |-- medianIncome: string (nullable = true)
 |-- population: string (nullable = true)
 |-- totalBedRooms: string (nullable = true)
 |-- totalRooms: string (nullable = true)

All columns are still of data type string… That’s disappointing!

If you want to continue with this DataFrame, you’ll need to rectify this situation and assign “better” or more accurate data types to all columns. Your performance will also benefit from this. Intuitively, you could go for a solution like the following, where you declare that each column of the DataFrame df should be cast to a FloatType():

from pyspark.sql.types import *

df = df.withColumn("longitude", df["longitude"].cast(FloatType())) \
   .withColumn("latitude", df["latitude"].cast(FloatType())) \
   .withColumn("housingMedianAge",df["housingMedianAge"].cast(FloatType())) \
   .withColumn("totalRooms", df["totalRooms"].cast(FloatType())) \
   .withColumn("totalBedRooms", df["totalBedRooms"].cast(FloatType())) \
   .withColumn("population", df["population"].cast(FloatType())) \
   .withColumn("households", df["households"].cast(FloatType())) \
   .withColumn("medianIncome", df["medianIncome"].cast(FloatType())) \
   .withColumn("medianHouseValue", df["medianHouseValue"].cast(FloatType()))

But these repeated calls are quite obscure, error-prone and don’t really look nice. Why don’t you write a function that can do all of this for you in a cleaner way?

The following user-defined function (a plain Python helper rather than a Spark SQL UDF) takes a DataFrame, column names, and the new data type that you want the columns to have. For every column name, you take the column and cast it to the new data type. Then, you return the DataFrame:

# Import all from `sql.types`
from pyspark.sql.types import *

# Write a custom function to convert the data type of DataFrame columns
def convertColumn(df, names, newType):
  for name in names: 
     df = df.withColumn(name, df[name].cast(newType))
  return df 

# Assign all column names to `columns`
columns = ['households', 'housingMedianAge', 'latitude', 'longitude', 'medianHouseValue', 'medianIncome', 'population', 'totalBedRooms', 'totalRooms']

# Convert the `df` columns to `FloatType()`
df = convertColumn(df, columns, FloatType())

That already looks much better! You can quickly inspect the data types of df with the printSchema() method, just like you have done before.

Now that you’ve got that all sorted out, it’s time to really get started on the data exploration. You have seen that columnar access and SQL queries were two advantages of using DataFrames. Well, now it’s time to dig a little bit further into that. Let’s start small and just select two columns from df of which you only want to see 10 rows:

df.select('population','totalBedRooms').show(10)

This query gives you the following result:

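+----------+-------------+
|population|totalBedRooms|
+----------+-------------+
|     322.0|        129.0|
|    2401.0|       1106.0|
|     496.0|        190.0|
|     558.0|        235.0|
|     565.0|        280.0|
|     413.0|        213.0|
|    1094.0|        489.0|
|    1157.0|        687.0|
|    1206.0|        665.0|
|    1551.0|        707.0|
+----------+-------------+
only showing top 10 rows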

You can also make your queries more complex, as you see in the following example:

df.groupBy("housingMedianAge").count().sort("housingMedianAge", ascending=False).show()

Which gives you the following result:

+----------------+-----+
|housingMedianAge|count|
+----------------+-----+
|            52.0| 1273|
|            51.0|   48|
|            50.0|  136|
|            49.0|  134|
|            48.0|  177|
|            47.0|  198|
|            46.0|  245|
|            45.0|  294|
|            44.0|  356|
|            43.0|  353|
|            42.0|  368|
|            41.0|  296|
|            40.0|  304|
|            39.0|  369|
|            38.0|  394|
|            37.0|  537|
|            36.0|  862|
|            35.0|  824|
|            34.0|  689|
|            33.0|  615|
+----------------+-----+
only showing top 20 rows
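
The output above pairs each housing median age with a count of block groups, ordered from the highest age down. To make the logic of such an aggregation concrete, here is a plain Python sketch on a handful of made-up values. It uses collections.Counter instead of Spark, and the ages list is purely hypothetical.

```python
from collections import Counter

# Hypothetical housing median ages for a handful of block groups
ages = [52.0, 41.0, 52.0, 21.0, 41.0, 52.0]

# Count how many block groups share each age
counts = Counter(ages)

# Sort the (age, count) pairs by age, highest age first
rows = sorted(counts.items(), key=lambda item: item[0], reverse=True)

for age, count in rows:
    print(age, count)
```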

Besides querying, you can also choose to describe your data and get some summary statistics. This will most definitely help you later!

df.describe().show()

PySpark Machine Learning

Look at the minimum and maximum values of all the (numerical) attributes. You see that multiple attributes have a wide range of values: you will need to normalize your dataset.
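
To get a feel for what such a wide range means and what rescaling does to it, consider the sketch below. It takes the population values of the first 10 rows shown earlier, computes a few summary statistics with Python’s statistics module and then standardizes the values by subtracting the mean and dividing by the standard deviation. This is a generic illustration on a tiny sample, not the exact computation that Spark performs.

```python
import statistics

# Population values of the first 10 rows shown earlier in this tutorial
population = [322.0, 2401.0, 496.0, 558.0, 565.0, 413.0,
              1094.0, 1157.0, 1206.0, 1551.0]

# A few summary statistics for this small sample
summary = {
    "count": len(population),
    "mean": statistics.mean(population),
    "stddev": statistics.stdev(population),  # sample standard deviation
    "min": min(population),
    "max": max(population),
}

# Standardize: subtract the mean, then divide by the standard deviation
standardized = [(x - summary["mean"]) / summary["stddev"] for x in population]

print(summary)
print([round(z, 2) for z in standardized])
```

After this step the values are centered around zero with a standard deviation of one, so attributes that were measured on very different scales become comparable.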

Data Preprocessing

With all this information that you gathered from your small exploratory data analysis, you know enough to preprocess your data to feed it to the model.

  • You don’t need to worry about missing values; all zero values have been excluded from the data set.
  • You should probably standardize your data, as you have seen that the range of minimum and maximum values is quite big.
  • There are possibly some additional attributes that you could add, such as a feature that registers the number of bedrooms per room or the rooms per household.
  • Your dependent variable is also quite big; to make your life easier, you’ll have to adjust the values slightly.

Preprocessing the Target Values

First, let’s start with the medianHouseValue, your dependent variable. To facilitate your working with the target values, you will express the house values in units of 100,000. That means that a target such as 452600.000000 should become 4.526:

# Import all from `sql.functions` 
from pyspark.sql.functions import *

# Adjust the values of `medianHouseValue`
df = df.withColumn("medianHouseValue", col("medianHouseValue")/100000)

# Show the first 2 lines of `df`
df.take(2)

You can clearly see that the values have been adjusted correctly when you look at the result of the take() method:

[Row(households=126.0, housingMedianAge=41.0, latitude=37.880001068115234, longitude=-122.2300033569336, medianHouseValue=4.526, medianIncome=8.325200080871582, population=322.0, totalBedRooms=129.0, totalRooms=880.0), Row(households=1138.0, housingMedianAge=21.0, latitude=37.86000061035156, longitude=-122.22000122070312, medianHouseValue=3.585, medianIncome=8.301400184631348, population=2401.0, totalBedRooms=1106.0, totalRooms=7099.0)]
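
The arithmetic behind this adjustment is easy to check by hand. The following plain Python sketch applies the same division to the median house values of the first two rows, without Spark:

```python
# Median house values of the first two rows, in dollars
median_house_values = [452600.0, 358500.0]

# Express each value in units of 100,000
scaled = [value / 100000 for value in median_house_values]

print(scaled)
```

Remember that anything you compute on the adjusted column, such as a prediction, is also expressed in units of 100,000.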

Feature Engineering

Now that you have adjusted the values in medianHouseValue, you can also add the additional variables that you read about above. You’re going to add the following columns to the data set:

  • Rooms per household, which refers to the number of rooms in households per block group;
  • Population per household, which basically gives you an indication of how many people live in households per block group; and
  • Bedrooms per room, which will give you an idea about how many rooms are bedrooms per block group.

As you’re working with DataFrames, you can best use the select() method to select the columns that you’re going to be working with, namely totalRooms, households, and population. Additionally, you have to indicate that you’re working with columns by adding the col() function to your code. Otherwise, you won’t be able to do element-wise operations like the division that you have in mind for these three variables:

# Import all from `sql.functions` if you haven't yet
from pyspark.sql.functions import *

# Divide `totalRooms` by `households`
roomsPerHousehold = df.select(col("totalRooms")/col("households"))

# Divide `population` by `households`
populationPerHousehold = df.select(col("population")/col("households"))

# Divide `totalBedRooms` by `totalRooms`
bedroomsPerRoom = df.select(col("totalBedRooms")/col("totalRooms"))

# Add the new columns to `df`
df = df.withColumn("roomsPerHousehold", col("totalRooms")/col("households")) \
   .withColumn("populationPerHousehold", col("population")/col("households")) \
   .withColumn("bedroomsPerRoom", col("totalBedRooms")/col("totalRooms"))

# Inspect the result
df.first()

You see that, for the first row, there are about 6.98 rooms per household, the households in the block group consist of about 2.56 people and the share of bedrooms is quite low at about 0.15 bedrooms per room:

Row(households=126.0, housingMedianAge=41.0, latitude=37.880001068115234, longitude=-122.2300033569336, medianHouseValue=4.526, medianIncome=8.325200080871582, population=322.0, totalBedRooms=129.0, totalRooms=880.0, roomsPerHousehold=6.984126984126984, populationPerHousehold=2.5555555555555554, bedroomsPerRoom=0.14659090909090908)
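
You can verify these three ratios yourself with the values of the first block group. The sketch below does the divisions in plain Python, without Spark; the variable names are made up for this illustration.

```python
# Values of the first block group, taken from the Row shown above
total_rooms = 880.0
total_bedrooms = 129.0
population = 322.0
households = 126.0

# The three engineered features are simple ratios of existing columns
rooms_per_household = total_rooms / households
population_per_household = population / households
bedrooms_per_room = total_bedrooms / total_rooms

print(round(rooms_per_household, 2))
print(round(population_per_household, 2))
print(round(bedrooms_per_room, 2))
```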

Next, you’ll also re-order the columns. This anticipates an issue that you might run into when you standardize the values in your data set: since you don’t necessarily want to standardize your target values, you’ll want to make sure to isolate those in your data set.

You do this by using the select() method and passing the column names in the order that is more appropriate. Here, the target variable medianHouseValue is put first, so that it won’t be affected by the standardization.

Note also that this is the time to leave out variables that you might not want to consider in your analysis. In this case, let’s leave out variables such as longitude, latitude, housingMedianAge and totalRooms.

# Re-order and select columns
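df = df.select("medianHouseValue",
               "totalBedRooms",
               "population",
               "households",
               "medianIncome",
               "roomsPerHousehold",
               "populationPerHousehold",
               "bedroomsPerRoom")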


Now that you have re-ordered the data, you’re ready to normalize the data. Or almost, at least. There is just one more step that you need to go through: separating the features from the target variable. In essence, this boils down to isolating the first column in your DataFrame from the rest of the columns.

In this case, you’ll use the map() function that you use with RDDs to perform this action. You also see that you make use of the DenseVector() function. A dense vector is a local vector that is backed by a double array that represents its entry values. In other words, it's used to store arrays of values for use in PySpark.

Next, you go back to making a DataFrame out of the input_data and you re-label the columns by passing a list as a second argument. This list consists of the column names "label" and "features":

# Import `DenseVector`
from pyspark.ml.linalg import DenseVector

# Define the `input_data` 
input_data = df.rdd.map(lambda x: (x[0], DenseVector(x[1:])))

# Replace `df` with the new DataFrame
df = spark.createDataFrame(input_data, ["label", "features"])
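To see what the function passed to map() does to a single row, here is a minimal plain-Python sketch. It is an illustration only: a tuple stands in for a Spark Row, a list stands in for the DenseVector, and the helper name split_label_features is hypothetical.

```python
# One row after the re-ordering step: the label first, then the seven features
# (values as they appear for the first block group in this tutorial)
row = (4.526, 129.0, 322.0, 126.0, 8.3252, 6.9841, 2.5556, 0.1466)

def split_label_features(x):
    """Return (label, features), mirroring the function passed to map()."""
    return (x[0], list(x[1:]))

label, features = split_label_features(row)
print(label)          # 4.526
print(len(features))  # 7
```

The first element becomes the "label" column and everything after it is packed into the single "features" column.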

Next, you can finally scale the data. You can use Spark ML to do this: this library makes machine learning on big data scalable and easy. You’ll find tools such as ML algorithms and everything you need to build practical ML pipelines. In this case, you don’t need to do that much preprocessing, so a pipeline might be overkill, but if you want to look into it, consider reading the Spark ML Pipelines documentation.

The input column holds the features, and the output column with the rescaled features, which will be included in scaled_df, will be named "features_scaled":

# Import `StandardScaler` 
from pyspark.ml.feature import StandardScaler

# Initialize the `standardScaler`
standardScaler = StandardScaler(inputCol="features", outputCol="features_scaled")

# Fit the DataFrame to the scaler
scaler = standardScaler.fit(df)

# Transform the data in `df` with the scaler
scaled_df = scaler.transform(df)

# Inspect the result
scaled_df.take(2)

Let’s take a look at your DataFrame and the result. You see that, indeed, a third column features_scaled was added to your DataFrame, which you can use to compare with features:

[Row(label=4.526, features=DenseVector([129.0, 322.0, 126.0, 8.3252, 6.9841, 2.5556, 0.1466]), features_scaled=DenseVector([0.3062, 0.2843, 0.3296, 4.3821, 2.8228, 0.2461, 2.5264])), Row(label=3.585, features=DenseVector([1106.0, 2401.0, 1138.0, 8.3014, 6.2381, 2.1098, 0.1558]), features_scaled=DenseVector([2.6255, 2.1202, 2.9765, 4.3696, 2.5213, 0.2031, 2.6851]))]

Note that these lines of code are very similar to what you would be doing in Scikit-Learn.
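The output above also hints at what the scaler did. The scaled values stay positive and proportional to the raw ones, which is consistent with StandardScaler's default behavior of dividing each feature by its sample standard deviation without centering on the mean. The sketch below checks this on the first feature of the two rows shown above, then applies the same idea to a made-up toy list. The helper scale_to_unit_std is hypothetical and is not Spark's implementation.

```python
from statistics import stdev

# First feature (totalBedRooms) of the two rows shown above, raw and scaled
raw = [129.0, 1106.0]
scaled = [0.3062, 2.6255]

# If scaling is a plain division by one standard deviation, raw / scaled
# should give (almost) the same number for both rows
implied_std = [r / s for r, s in zip(raw, scaled)]
print([round(v) for v in implied_std])  # [421, 421]

def scale_to_unit_std(values):
    """Divide each value by the sample standard deviation (no centering)."""
    s = stdev(values)
    return [v / s for v in values]

# Toy data, made up for illustration
toy = [2.0, 4.0, 6.0]
print(scale_to_unit_std(toy))  # [1.0, 2.0, 3.0]
```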

Building a Machine Learning Model with Spark ML

With all the preprocessing done, it’s finally time to start building your Linear Regression model! Just like always, you first need to split the data into training and test sets. Luckily, this is no issue with the randomSplit() method:

# Split the data into train and test sets
train_data, test_data = scaled_df.randomSplit([.8,.2],seed=1234)

You pass in a list with two weights that set the approximate share of rows that you want your training and test sets to have, and a seed, which is needed for reproducibility reasons. If you want to know more about this, consider DataCamp’s Python Machine Learning Tutorial.
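Because rows are assigned at random, the resulting sets are close to, but not exactly, an 80/20 split. The plain-Python sketch below illustrates the idea with one independent random draw per row. It is a simplified analogy with a hypothetical random_split helper, not the algorithm Spark uses internally.

```python
import random

def random_split(rows, train_weight, seed):
    """Assign each row to train or test with an independent random draw."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_weight else test).append(row)
    return train, test

rows = list(range(1000))
train, test = random_split(rows, 0.8, seed=1234)

# Roughly 800 and 200 rows, but usually not exactly
print(len(train), len(test))
```

Running the function again with the same seed gives the same split, which is the reason for passing a seed in the first place.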

Then, without further ado, you can make your model!

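Note that the argument elasticNetParam corresponds to α, the elastic net mixing parameter (a value of 0 gives an L2 penalty, a value of 1 gives an L1 penalty), and that regParam, the regularization parameter, corresponds to λ. See the Spark ML documentation on linear methods for more information.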

# Import `LinearRegression`
from pyspark.ml.regression import LinearRegression

# Initialize `lr`
lr = LinearRegression(labelCol="label", maxIter=10, regParam=0.3, elasticNetParam=0.8)

# Fit the data to the model
linearModel = lr.fit(train_data)
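To get a feel for how regParam and elasticNetParam interact, the sketch below computes the elastic net penalty λ(α‖w‖₁ + (1 − α)/2 · ‖w‖₂²) for a given coefficient vector. The function name and the example coefficients are hypothetical; only the two parameter values come from the code above.

```python
def elastic_net_penalty(weights, reg_param, elastic_net_param):
    """Penalty = lambda * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)."""
    l1 = sum(abs(w) for w in weights)
    l2_squared = sum(w * w for w in weights)
    return reg_param * (elastic_net_param * l1
                        + (1.0 - elastic_net_param) / 2.0 * l2_squared)

# Hypothetical coefficients, with the regParam and elasticNetParam used above
penalty = elastic_net_penalty([1.0, -2.0], reg_param=0.3, elastic_net_param=0.8)
print(penalty)  # approximately 0.87
```

With elasticNetParam=0.8, most of the penalty comes from the L1 term, which tends to push some coefficients to exactly zero.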

With your model in place, you can generate predictions for your test data: use the transform() method to predict the labels for your test_data. Then, you can use RDD operations to extract the predictions as well as the true labels from the DataFrame and zip these two values together in a list called predictionAndLabel.

Lastly, you can then inspect the predicted and real values by simply accessing the list with square brackets []:

# Generate predictions
predicted = linearModel.transform(test_data)

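# Extract the predictions and the "known" correct labels
predictions = predicted.select("prediction").rdd.map(lambda x: x[0])
labels = predicted.select("label").rdd.map(lambda x: x[0])

# Zip `predictions` and `labels` into a list
predictionAndLabel = predictions.zip(labels).collect()

# Print out first 5 instances of `predictionAndLabel` 
predictionAndLabel[:5]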

You’ll see the following predicted and real values (in that order):

[(1.4491508524918457, 0.14999), (1.5705029404692372, 0.14999), (2.148727956912464, 0.14999), (1.5831547768979277, 0.344), (1.5182107797955968, 0.398)]
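A quick way to read these pairs is to look at the gap between each prediction and its label. The plain-Python sketch below does that for the five pairs shown above; the variable names are illustrative.

```python
# The five (prediction, label) pairs shown above
prediction_and_label = [
    (1.4491508524918457, 0.14999),
    (1.5705029404692372, 0.14999),
    (2.148727956912464, 0.14999),
    (1.5831547768979277, 0.344),
    (1.5182107797955968, 0.398),
]

# Prediction minus label: positive values mean the model overestimates
errors = [prediction - label for prediction, label in prediction_and_label]
print([round(e, 2) for e in errors])  # [1.3, 1.42, 2.0, 1.24, 1.12]
```

For these five rows, which all have very low median house values, every prediction is well above the known label.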

Evaluating the Model

Looking at predicted values is one thing, but looking at some metrics gives you a better idea of how good your model actually is. You can first start by printing out the coefficients and the intercept of the model:

# Coefficients for the model
linearModel.coefficients

# Intercept for the model
linearModel.intercept

Which gives you the following result:

# The coefficients

# The intercept

Next, you can also use the summary attribute to pull up the rootMeanSquaredError and the r2:

# Get the RMSE
linearModel.summary.rootMeanSquaredError

# Get the R2
linearModel.summary.r2

  • The RMSE measures the typical size of the error between predicted values and observed (known) values, in the same units as the target. The smaller the RMSE value, the closer the predicted and observed values are.

  • The R2 (“R squared”), or the coefficient of determination, is a measure that shows how close the data are to the fitted regression line. For a model scored on the data it was fitted to, as is the case here, this score falls between 0 and 100% (or 0 and 1 in this case), where 0% indicates that the model explains none of the variability of the response data around its mean, and 100% indicates the opposite: it explains all the variability. That means that, in general, the higher the R-squared, the better the model fits your data.
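Both metrics are simple to compute by hand, which can help when interpreting them. The following is a minimal plain-Python sketch of the two formulas on made-up toy data; the function names are illustrative and this is not how Spark computes its summary internally.

```python
from math import sqrt

def rmse(predicted, observed):
    """Root of the mean squared difference between predictions and observations."""
    n = len(observed)
    return sqrt(sum((p - o) ** 2 for p, o in zip(predicted, observed)) / n)

def r_squared(predicted, observed):
    """1 - (residual sum of squares / total sum of squares around the mean)."""
    mean_observed = sum(observed) / len(observed)
    ss_res = sum((o - p) ** 2 for p, o in zip(predicted, observed))
    ss_tot = sum((o - mean_observed) ** 2 for o in observed)
    return 1.0 - ss_res / ss_tot

# Toy data (made up for illustration, not taken from the housing data)
observed = [1.0, 2.0, 3.0, 4.0]
perfect = [1.0, 2.0, 3.0, 4.0]
mean_only = [2.5, 2.5, 2.5, 2.5]

print(rmse(perfect, observed), r_squared(perfect, observed))  # 0.0 1.0
print(r_squared(mean_only, observed))                         # 0.0
```

A model that predicts every value perfectly has an RMSE of 0 and an R2 of 1, while a model that always predicts the mean of the observed values has an R2 of 0.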

You'll get back the following result:


# R2

There are definitely some improvements needed to your model! If you want to continue with this model, you can play around with the parameters that you passed to your model, the variables that you included in your original DataFrame, and so on. But this is where the tutorial ends for now!

Before you Go…

Before you go, make sure to stop the SparkSession with the following line of code:

spark.stop()


Taking Big Data Further

Congrats! You have made it to the end of this tutorial, where you learned how to make a linear regression model with the help of Spark ML.

If you are interested in learning more about PySpark, consider taking DataCamp’s Introduction to PySpark course.

