Wavelet image hash in Python

For several weekends, I had fun playing with the Kaggle Avito Duplicate Ads Detection problem. This machine learning problem includes more than 10 million images in addition to the structured data set. In this competition, many players use image hashes instead of the actual images to speed up model creation.

What I found interesting is that most implementations of image hashing use the standard Discrete Cosine Transformation (DCT). I worked with images many years ago and remember that the Discrete Wavelet Transformation (DWT) might give better results for images. I was unable to find any Python implementation of DWT-based image hashing, so I implemented one and pushed it to the imagehash library. The change is available in the master branch on GitHub and in the new version of the package. In this blog post, I will briefly describe how it works.

1. Imagehash Python library

The simplest and most effective library that I found was the imagehash library by Johannes Buchner. Several image hashes were implemented in the library: aHash, pHash, dHash. All three approaches first convert an image to grayscale and scale it down to a small size (8×8 in the simplest case). The library then computes a binary 1 or 0 for each of 64 positions, and these 64 bits form the output of the algorithm. The bit computation methods differ:

  1. aHash – average hash: for each pixel, output 1 if the pixel is greater than or equal to the average and 0 otherwise.
  2. pHash – perceptual hash: does the same as aHash, but first applies a Discrete Cosine Transformation and works in the frequency domain.
  3. dHash – difference (gradient) hash: compares each pixel with its neighbor and outputs 1 if the brightness increases and 0 otherwise.
  4. wHash – wavelet hash, which I added to the library a couple of days ago. Like pHash, it works in the frequency domain, but it uses DWT instead of DCT.
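The bit rules for aHash and dHash can be sketched with NumPy on an image that has already been converted to grayscale and downscaled. This is a minimal illustration, not the imagehash code: the function names are hypothetical, the resizing step is left out, and the dHash function assumes the common adjacent-pixel comparison, so it expects one extra column (8 rows by 9 columns for an 8×8 hash).

```python
import numpy as np

def average_hash_bits(pixels):
    """aHash rule: 1 where a pixel is >= the mean of all pixels, else 0."""
    pixels = np.asarray(pixels, dtype=float)
    return pixels >= pixels.mean()

def difference_hash_bits(pixels):
    """dHash rule (adjacent-pixel variant): 1 where a pixel is brighter than its left neighbour."""
    pixels = np.asarray(pixels, dtype=float)
    return pixels[:, 1:] > pixels[:, :-1]

# A tiny 2x3 "image" used only to show the rules
img = np.array([[10, 20, 30],
                [40, 50, 60]])
print(average_hash_bits(img).astype(int))     # mean is 35, so only the bottom row is 1
print(difference_hash_bits(img).astype(int))  # every pixel is brighter than its left neighbour
```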

You can find a more detailed description of the hashes in this blog post.

The code below shows how to use the library.

import PIL
from PIL import Image
import imagehash

hash1 = imagehash.phash(Image.open('test1.jpg'))
hash2 = imagehash.phash(Image.open('test2.jpg'))

hash1 == hash2
> False

hash1 - hash2
> 44


The two images from the code example are definitely not equal: 44 of the 64 bits differ. Similar images typically differ by no more than 6-8 bits.
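In imagehash, subtracting two hashes (as in `hash1 - hash2` above) returns the number of bit positions in which they differ, i.e. the Hamming distance. A minimal NumPy sketch of that comparison follows; both function names are hypothetical, and the `max_bits=8` threshold in `looks_similar` is only the rough rule of thumb from the paragraph above, not a library default.

```python
import numpy as np

def hamming_distance(bits1, bits2):
    """Number of positions where two boolean hash arrays of the same shape differ."""
    bits1 = np.asarray(bits1, dtype=bool)
    bits2 = np.asarray(bits2, dtype=bool)
    if bits1.shape != bits2.shape:
        raise ValueError("hashes must have the same shape")
    return int(np.count_nonzero(bits1 != bits2))

def looks_similar(bits1, bits2, max_bits=8):
    """Treat two images as near-duplicates if their hashes differ in at most max_bits bits."""
    return hamming_distance(bits1, bits2) <= max_bits

h_a = np.zeros((8, 8), dtype=bool)
h_b = h_a.copy()
h_b[0, :3] = True                  # flip three bits
print(hamming_distance(h_a, h_b))  # 3
print(looks_similar(h_a, h_b))     # True
```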

UPDATE: LIN China found a typo. The hash values and the difference have been corrected.

2. Calculate image hash

For regular photos, frequency-based methods like pHash usually give better results because the frequency domain is more stable under image transformations such as:

  • JPG compression
  • color scheme changes or applying image filters
  • size scaling
  • and even some minor image editing: cropping part of an image, adding a watermark, adding text or modifying an image.

For example, let’s take a look at an image and a transformed version of the same image. We will use the very popular Lenna image, which appears in many image processing research papers. I remember this picture pretty well from my student days, when I did image processing research more than 10 years ago.

Lenna.png. Original image. Size 512×512.

Let’s apply some basic transformations to the image and compare the hashes. First, we resize the image from 512×512 to 400×400 pixels. Then we change the color scheme, and finally we compress it to JPEG.

Lenna1.jpg. Color scheme and image size changed, JPG compressed. Size 400×400

import PIL
from PIL import Image
import imagehash

lenna = PIL.Image.open('lenna.png')
lenna1 = PIL.Image.open('lenna1.jpg')

h = imagehash.phash(lenna)
h1 = imagehash.phash(lenna1)

h - h1
> 0

Ha… not bad! No difference in the image hashes even after compression, resizing and color changing.

Let’s apply more transformations to the lenna1.jpg image (not the original one):

  • take only the central part of the picture
  • add text
  • compress again

Lenna2.jpg. More image transformations. Size 317×360

I shared all three images: lenna.png, lenna1.jpg, lenna2.jpg.

lenna2 = PIL.Image.open('lenna2.jpg')
h2 = imagehash.phash(lenna2)

h - h2
> 20

(h - h2)/len(h.hash)**2
> 0.3125

All right. Now we can see that the hash difference is 20 bits, or 31.25% of the hash bits. The normalized metric is more useful because hash sizes vary between hashing methods.
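Dividing by the total number of bits makes hashes of different sizes comparable; for a square hash, `len(h.hash)**2` is exactly that bit count. The sketch below (hypothetical helper name) uses the array size directly and shows why a raw bit count alone can mislead when the hash size changes.

```python
import numpy as np

def normalized_distance(bits1, bits2):
    """Fraction of differing bits between two boolean hash arrays of the same shape."""
    bits1 = np.asarray(bits1, dtype=bool)
    bits2 = np.asarray(bits2, dtype=bool)
    return np.count_nonzero(bits1 != bits2) / bits1.size

# 20 differing bits in an 8x8 hash, as in the lenna2 example
small_a = np.zeros((8, 8), dtype=bool)
small_b = small_a.copy()
small_b.flat[:20] = True
print(normalized_distance(small_a, small_b))  # 0.3125

# The same 20 differing bits in a 16x16 hash are a much smaller fraction
big_a = np.zeros((16, 16), dtype=bool)
big_b = big_a.copy()
big_b.flat[:20] = True
print(normalized_distance(big_a, big_b))      # 0.078125
```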

aHash gives different results. Even the simple transformation in lenna1.jpg shows a 1.6% hash difference, and the more aggressive lenna2.jpg gives a 29.7% difference.

a = imagehash.average_hash(lenna)
a1 = imagehash.average_hash(lenna1)
a2 = imagehash.average_hash(lenna2)

a - a1
> 1

(a - a1)/len(a.hash)**2
> 0.015625

(a - a2)/len(a.hash)**2
> 0.296875

3. Wavelet hash

The Discrete Wavelet Transformation (DWT) is another form of frequency representation. The popular DCT and Fourier transformations use a set of sin/cos functions as a basis: sin(x), sin(2x), sin(3x), etc. In contrast, DWT uses one single function as a basis, but in different forms: scaled and shifted. The basis function can be changed, which is why we have the Haar wavelet, the Daubechies-4 wavelet, etc. This scaling gives us a great “time-frequency representation”, in which the low-frequency part looks similar to the original signal.
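To make the “scaled and shifted” idea concrete, here is one level of a 2D Haar transform written directly in NumPy. It is an illustrative sketch, not pywt: the function name is hypothetical, and the labels and signs of the three detail bands vary between libraries. The LL band is a half-resolution copy of the image, while the detail bands hold the differences inside each 2×2 block.

```python
import numpy as np

def haar_dwt2(x):
    """One level of a 2D Haar transform with orthonormal scaling.

    Expects an array with even height and width; returns (LL, (LH, HL, HH))."""
    x = np.asarray(x, dtype=float)
    a = x[0::2, 0::2]  # top-left pixel of every 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # low frequencies: 2 x the block mean, a half-size image
    lh = (a + b - c - d) / 2  # top row minus bottom row of each block
    hl = (a - b + c - d) / 2  # left column minus right column of each block
    hh = (a - b - c + d) / 2  # diagonal differences
    return ll, (lh, hl, hh)

img = np.array([[1, 1, 5, 5],
                [1, 1, 5, 5],
                [9, 9, 3, 3],
                [9, 9, 3, 3]])
ll, details = haar_dwt2(img)
print(ll)                                       # values 2, 10, 18, 6: twice each block mean
print(all(np.allclose(d, 0) for d in details))  # True: every 2x2 block is flat
```

Because the scaling is orthonormal, the total energy (sum of squares) of the four bands equals that of the input.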

There is a great Python library for wavelets – pywt. I used this library to implement the whash() method for the imagehash library. By default, whash() computes an 8×8 hash using the Haar transformation. In addition, the method removes the lowest Haar frequency LL(max). The lowest frequency consists of only one data point/pixel; this point represents the overall brightness (the average) of the image and isn’t very useful for hashing.
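The steps described above can be sketched without pywt for the Haar case. This is not the imagehash implementation, only a minimal NumPy illustration under stated assumptions: the input is already a grayscale square array whose side is a power of 2 (resizing is not shown), the function names are hypothetical, and the bits come from comparing the final LL band with its median (an assumed threshold).

```python
import numpy as np

def haar_ll(x):
    """LL band of one orthonormal 2D Haar level (2 x the mean of each 2x2 block)."""
    return (x[0::2, 0::2] + x[0::2, 1::2] + x[1::2, 0::2] + x[1::2, 1::2]) / 2

def wavelet_hash_bits(pixels, hash_size=8, remove_max_haar_ll=True):
    """Haar-only wavelet hash sketch for a square grayscale array (values 0-255)."""
    x = np.asarray(pixels, dtype=float) / 255.0
    side = x.shape[0]
    if x.shape != (side, side) or side & (side - 1):
        raise ValueError("expected a square array whose side is a power of 2")
    if hash_size & (hash_size - 1) or not 1 <= hash_size <= side:
        raise ValueError("hash_size must be a power of 2 no larger than the image side")
    if remove_max_haar_ll:
        # The deepest Haar LL band is a single value proportional to the image mean,
        # so zeroing it and reconstructing is the same as subtracting the mean.
        x = x - x.mean()
    # Keep only the low-frequency band until it is hash_size x hash_size
    while x.shape[0] > hash_size:
        x = haar_ll(x)
    # Assumed threshold: 1 where a coefficient is above the median of the band
    return x > np.median(x)

# A 64x64 horizontal gradient: brightness grows from left to right
gradient = np.tile(np.arange(64) * 4, (64, 1))
bits = wavelet_hash_bits(gradient)
print(bits.astype(int)[0])  # [0 0 0 0 1 1 1 1]: the brighter right half is 1
```

In this Haar-only sketch the mean subtraction shifts every LL value and the median by the same amount, so it does not change the resulting bits; it is kept only to mirror the step described above.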

wHash Python code is below:

w = imagehash.whash(lenna)
w1 = imagehash.whash(lenna1)
w2 = imagehash.whash(lenna2)

(w - w1)/len(w.hash)**2
> 0.03125

(w - w2)/len(w.hash)**2
> 0.28125

4. Validation

To put these results in perspective, let’s compare the original image with a completely different one. For unrelated images, the expected hash difference is about 50%. Here is another standard image for comparison – barbara.jpg. Let’s calculate the hash difference between Lenna and Barbara using all three hashes. The code is listed below:

barb = PIL.Image.open('barbara.jpg')

w_b = imagehash.whash(barb)
h_b = imagehash.phash(barb)
a_b = imagehash.average_hash(barb)

(a - a_b)/len(a.hash)**2
> 0.5

(h - h_b)/len(h.hash)**2
> 0.53125

(w - w_b)/len(w.hash)**2
> 0.4375

Table with all results:

                     aHash    pHash    wHash
lenna vs. lenna1     1.6%     0%       3.1%
lenna vs. lenna2     29.7%    31.3%    28.1%
lenna vs. barbara    50%      53.1%    43.8%

In the new whash() method, we can play with different parameters. The most important one is the hash size. It is 8 by default, but you can change it to any power of 2 that is less than the input image size (the smaller of the image dimensions). You can also keep the lowest frequency by setting the parameter remove_max_haar_ll to False. In addition, you can change the initial scaling of the image from 64 (which is 8×8) to any power of 2 less than the image size.
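The power-of-2 constraints can be illustrated with two small helpers: one picks the largest power of 2 that fits into the smaller image side, and the other counts how many Haar halvings take an image of that size down to the hash size. Both are hypothetical sketches; the exact defaults in imagehash may differ (here, as an assumption, the default scale is also never allowed to fall below the hash size).

```python
def default_image_scale(width, height, hash_size=8):
    """Largest power of 2 that fits in the smaller image side, but at least hash_size."""
    natural = 1 << (min(width, height).bit_length() - 1)
    return max(natural, hash_size)

def dwt_levels(image_scale, hash_size=8):
    """Number of Haar levels that shrink image_scale x image_scale to hash_size x hash_size."""
    for name, value in (("image_scale", image_scale), ("hash_size", hash_size)):
        if value <= 0 or value & (value - 1):
            raise ValueError(f"{name} must be a power of 2")
    if hash_size > image_scale:
        raise ValueError("hash_size must not exceed image_scale")
    return image_scale.bit_length() - hash_size.bit_length()

# lenna2.jpg is 317x360: the largest power of 2 that fits is 256
scale = default_image_scale(317, 360)
print(scale, dwt_levels(scale))       # 256 5  (256 / 2**5 = 8)
print(dwt_levels(512, hash_size=16))  # 5      (512 / 2**5 = 16)
```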

The most interesting parameter is mode, the wavelet family. By default the library uses the Haar wavelet, but you can change it to any wavelet supported by the pywt library, such as ‘db4’. See the library page.

import pywt
pywt.families()
> ['haar', 'db', 'sym', 'coif', 'bior', 'rbio', 'dmey']

5. Known issues

I ran into a problem when processing a large number of small images: it looks like pywt has a memory leak. I created an issue on GitHub and will try to contact the pywt creators about it.

To mitigate the issue, I split the images into directories of ~50K images each and re-ran the processing for each directory separately.


It is hard to say which of the methods provides better results. It depends on your application, and you should focus on your application or machine learning model metrics like precision/recall or AUC. For my Kaggle score, whash() brought +0.04% to the AUC metric on top of my existing ~92.9% result.

It doesn’t look like a huge difference. However, remember that in the modeling code we achieved this with a one-letter change from phash() to whash(). It is nice to have more advanced analytical tools, and I hope this method will be a good addition to your analytical toolbox. In addition, I believe that wHash has great potential for tuning through the method parameters.

Please share your experience in using the library. Any comments, suggestions, code improvements and fixes are highly appreciated.

How to check hypotheses with bootstrap and Apache Spark?

There is a feature I really like in Apache Spark: it can process data that doesn’t fit in memory on my local machine, even without a cluster. Good news for those who process data sets bigger than the memory they currently have. From time to time, I run into this issue when I work with hypothesis testing.

For hypothesis testing I usually use statistical bootstrapping techniques. This method requires very little statistical knowledge and is very easy to understand. It is also very simple to implement. There are no normal distributions or Student’s t-distributions from your statistics courses, only some basic coding skills. Good news for those who don’t like statistics. Spark and bootstrapping together are a very powerful combination that can help you check hypotheses at a large scale.

1. Bootstrap methods

The most common application of bootstrapping is calculating confidence intervals, and you can use these confidence intervals as part of the hypothesis checking process. There is a very simple idea behind bootstrapping: resample your data set of size N hundreds or even thousands of times with replacement (this is important), and calculate the estimated metric for each of the hundreds or thousands of resamples. This process gives you a histogram, which is an empirical distribution of your metric. Then, you can use this empirical distribution for hypothesis testing.

The beauty of this method is the empirical distribution histogram. In a classical statistical approach, you need to approximate the distribution of your data by a normal distribution and calculate z-scores or t-scores based on theoretical distributions. With the empirical distribution from the first step, it is easy to calculate the 2.5th and 97.5th percentiles, and these give you the actual confidence interval. That’s it! A confidence interval with almost no math.

2. Choosing the right hypothesis

Choosing the right hypothesis is the only tricky part of this analytical process. This is a question you ask the data, and you cannot automate that. Hypothesis testing is a part of the analytical process and isn’t common practice for machine learning experts. In machine learning, you ask an algorithm to build a model/structure, which is sometimes called a hypothesis, and you look for the best hypothesis relating your data and labels.

In the analytics process, knowing the correlation is not enough: you should know the hypothesis from the get-go, and the question is whether the hypothesis is correct and what your level of confidence is.

Once you have a hypothesis, it is easy to check it with the bootstrapping approach. For example, let’s check the hypothesis that the average of some feature A in your dataset is equal to 30.0. We start with a null hypothesis H0, which we try to reject, and an alternative hypothesis H1:

H0: mean(A) == 30.0

H1: mean(A) != 30.0

If we fail to reject H0, we will take this hypothesis as our working assumption. That’s what we need. If we reject it, we should come up with a better hypothesis (for example, mean(A) == 40).

3. Checking hypotheses

To check the hypothesis, we can simply calculate a 95% confidence interval for dataset A by resampling. If the interval does not contain 30.0, then H0 is rejected.

This confidence interval runs from the 2.5th to the 97.5th percentile, so 95% of the bootstrap estimates fall inside it. In the sorted array of bootstrap means, we find the 2.5th and 97.5th percentiles: p1 and p2. If p1 <= 30.0 <= p2, then we weren’t able to reject H0, so we keep H0 as our working assumption.
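The procedure from sections 1-3 can be tried on a small in-memory sample before moving to Spark. Below is a plain-Python sketch (standard library only; the function names and the synthetic data are hypothetical): it resamples with replacement, uses the 2.5th and 97.5th percentiles of the bootstrap means as the 95% interval, and rejects H0 when the hypothesized mean falls outside that interval.

```python
import random
import statistics

def bootstrap_ci(data, n_resamples=1000, left=0.025, right=0.975, seed=None):
    """Percentile bootstrap interval for the mean of `data`."""
    rng = random.Random(seed)
    n = len(data)
    # Mean of each resample of size n drawn with replacement, sorted ascending
    means = sorted(statistics.fmean(rng.choices(data, k=n)) for _ in range(n_resamples))
    return means[int(n_resamples * left)], means[int(n_resamples * right)]

def rejects_h0_mean(data, h0_mean, **kwargs):
    """Two-sided check: reject H0 when h0_mean lies outside the bootstrap interval."""
    p1, p2 = bootstrap_ci(data, **kwargs)
    return not (p1 <= h0_mean <= p2)

rng = random.Random(0)
sample = [rng.gauss(30.0, 5.0) for _ in range(500)]  # synthetic data with true mean 30
p1, p2 = bootstrap_ci(sample, seed=1)
print(p1, p2)
print(rejects_h0_mean(sample, 40.0, seed=1))  # True: 40 is far outside the interval
```

With the default N = 1000, the percentiles are read at indices 25 and 975 of the sorted bootstrap means, the same indexing the Spark code below uses.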

4. Apache Spark code

The implementation of bootstrapping in this particular case is straightforward.

import scala.util.Sorting.quickSort
import org.apache.spark.rdd.RDD

def getConfInterval(input: RDD[Double], N: Int, left: Double, right: Double)
            : (Double, Double) = {
    // Simulate by sampling with replacement and calculating the average of each subsample
    val hist = Array.fill(N){0.0}
    for (i <- 0 to N-1) {
        hist(i) = input.sample(withReplacement = true, fraction = 1.0).mean
    }
    // Sort the averages and calculate quantiles
    quickSort(hist)
    val left_quantile  = hist((N*left).toInt)
    val right_quantile = hist((N*right).toInt)
    (left_quantile, right_quantile)
}

Because I did not find any good open datasets for a large-scale hypothesis testing problem, let’s use the skewdata.csv dataset from the book “Statistics: An Introduction Using R”. You can find this dataset in this archive. It is not perfect but will work in a pinch.

val dataWithHeader = sc.textFile("zipped/skewdata.csv")
val header = dataWithHeader.first
val data = dataWithHeader.filter( _ != header ).map( _.toDouble )

val (left_qt, right_qt) = getConfInterval(data, 1000, 0.025, 0.975)
val H0_mean = 30.0

if (left_qt <= H0_mean && H0_mean <= right_qt) {
    println("We failed to reject H0. It seems like H0 is correct.")
} else {
    println("We rejected H0")
}

We have to understand the difference between “failed to reject H0” and “proved H0”. Failing to reject a hypothesis means that the data are consistent with it, and you can use this information in your decision-making process, but it is not an actual proof that the hypothesis is correct.

5. Equal means code example

Another type of hypothesis checks whether the means of two datasets are different. This leads us to the usual design-of-experiments question: if you apply some change to your web system (a user interface change, for example), would your click rate change in a positive direction?

Let’s create a hypothesis:

H0: mean(A) == mean(B)

H1: mean(A) > mean(B)

It is not easy to test this pair of hypotheses directly. Let’s rephrase them in terms of the difference of means:

H0′: mean(A-B) == 0

H1: mean(A-B) > 0

Now we can try to reject H0′.

def getConfIntervalTwoMeans(input1: org.apache.spark.rdd.RDD[Double],
                    input2: org.apache.spark.rdd.RDD[Double],
                    N: Int, left: Double, right:Double)
            : (Double, Double) = {
    // Simulate the difference of averages: mean(A) - mean(B)
    val hist = Array.fill(N){0.0}
    for (i <- 0 to N-1) {
        val mean1 = input1.sample(withReplacement = true, fraction = 1.0).mean
        val mean2 = input2.sample(withReplacement = true, fraction = 1.0).mean
        hist(i) = mean1 - mean2
    }

    // Sort the differences and calculate quantiles
    quickSort(hist)
    val left_quantile  = hist((N*left).toInt)
    val right_quantile = hist((N*right).toInt)
    (left_quantile, right_quantile)
}

Because this is a one-sided (one-tailed) test, we replace the 2.5% and 97.5% percentiles with a single 5% percentile on the left side: if even this lower bound of mean(A-B) is above zero, we reject H0′ in favor of H1. Here is the code:

// Let's try to check the same dataset with itself. Ha-ha.
val (left_qt, right_qt) = getConfIntervalTwoMeans(data, data, 1000, 0.05, 0.95)

// One-tailed test: reject H0' only if the 5% quantile of mean(A-B) is above 0
if (left_qt > 0) {
    println("We rejected H0")
} else {
    println("We failed to reject H0. It seems like H0 is correct.")
}
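The one-tailed decision can also be checked on synthetic data without Spark. This plain-Python sketch (hypothetical names and data, standard library only) bootstraps mean(A) - mean(B), takes its 5% quantile, and rejects H0′ only when that lower bound is above zero; comparing a dataset with itself should therefore not reject H0′.

```python
import random
import statistics

def lower_bound_of_mean_diff(a, b, n_resamples=1000, alpha=0.05, seed=None):
    """alpha-quantile of the bootstrap distribution of mean(a) - mean(b)."""
    rng = random.Random(seed)
    diffs = sorted(
        statistics.fmean(rng.choices(a, k=len(a))) - statistics.fmean(rng.choices(b, k=len(b)))
        for _ in range(n_resamples)
    )
    return diffs[int(n_resamples * alpha)]

def rejects_equal_means(a, b, **kwargs):
    """One-tailed test of H0': mean(A-B) == 0 against H1: mean(A-B) > 0."""
    return lower_bound_of_mean_diff(a, b, **kwargs) > 0

rng = random.Random(42)
a = [rng.gauss(35.0, 5.0) for _ in range(1000)]  # synthetic: A is shifted up by 5
b = [rng.gauss(30.0, 5.0) for _ in range(1000)]
print(rejects_equal_means(a, b, seed=7))  # True
print(rejects_equal_means(a, a, seed=7))  # False: a dataset compared with itself
```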


Bootstrapping methods are very simple to understand and implement. They are intuitive, and you don’t need any deep knowledge of statistics. Apache Spark can help you implement these methods at a large scale.

As I mentioned previously, it is not easy to find a good large open dataset for hypothesis testing. Please share it with our community if you have one or come across one.

My code is shared in this Scala file.

How to export data-frame from Apache Spark

Apache Spark is a great tool for working with large amounts of data, like terabytes and petabytes, in a cluster. It’s also very useful on a local machine when gigabytes of data do not fit into memory. Normally we use Spark for preparing data and very basic analytic tasks. However, it lacks advanced analytical features and even visualization. Therefore, you have to reduce the amount of data to fit your computer’s memory capacity. It turns out that Apache Spark still lacks the ability to export data in a simple format like CSV.

1. spark-csv library

I was really surprised when I realized that Spark does not have a CSV export feature out of the box. It turns out that the CSV library is an external project. This is a must-have library for Spark, and I find it funny that it appears to be more of a marketing plug for Databricks than an Apache Spark project.

Another surprise is that this library does not create a single file. It creates several files based on the data frame partitioning, so for one data frame you get several CSV files. I understand that this is good for optimization in a distributed environment, but you don’t need it when extracting data for R or Python scripts.
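If you would rather keep Spark’s parallel write, the part files can also be merged afterwards. A minimal Python sketch, assuming the part files are named `part-*`, sort in the order they should be concatenated, and (when `has_header` is true) each start with the same header line; the function name and the paths in the usage comment are hypothetical.

```python
import glob
import os
import shutil

def merge_part_files(part_dir, output_path, has_header=True):
    """Concatenate the part-* files in part_dir into one file, keeping one header."""
    parts = sorted(glob.glob(os.path.join(part_dir, "part-*")))
    if not parts:
        raise FileNotFoundError(f"no part files found in {part_dir}")
    with open(output_path, "w", encoding="utf-8", newline="") as out:
        for i, part in enumerate(parts):
            with open(part, encoding="utf-8", newline="") as src:
                if has_header and i > 0:
                    src.readline()  # drop the header repeated in later parts
                shutil.copyfileobj(src, out)

# Hypothetical usage: merge_part_files("myfile.csv", "myfile-merged.csv")
```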

2. Export from data-frame to CSV

Let’s take a closer look at how this library works and how to export a data frame to CSV.

df.write.format("com.databricks.spark.csv").option("header", "true").save("myfile.csv")

You should include this library in your Spark environment. From spark-shell, just add the --packages parameter:

1) for scala 2.10:
        bin/spark-shell --packages com.databricks:spark-csv_2.10:1.3.0
2) for scala 2.11:
        bin/spark-shell --packages com.databricks:spark-csv_2.11:1.3.0

This code creates a directory myfile.csv with several CSV files and metadata files. If you need a single CSV file, you have to explicitly create one single partition.

df.repartition(1).write.format("com.databricks.spark.csv").option("header", "true").save("myfile.csv")

We can export the data to a temporary directory, move the CSV file to the correct place, and remove the directory with all the remaining files. Let’s automate this process:

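import java.io.File
import org.apache.spark.sql.DataFrame
def saveDfToCsv(df: DataFrame, tsvOutput: String,
                sep: String = ",", header: Boolean = false): Unit = {
    val tmpParquetDir = "Posts.tmp.parquet"
    // Write a single partition so that exactly one part file is produced
    df.repartition(1).write.
        format("com.databricks.spark.csv").
        option("header", header.toString).
        option("delimiter", sep).
        save(tmpParquetDir)
    val dir = new File(tmpParquetDir)
    val tmpTsvFile = tmpParquetDir + File.separatorChar + "part-00000"
    (new File(tmpTsvFile)).renameTo(new File(tsvOutput))
    // Remove the remaining metadata files and the temporary directory
    dir.listFiles.foreach( f => f.delete )
    dir.delete
}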


Apache Spark has many great aspects, but at this time it cannot be the be-all answer. Usually, you have to pair Spark with analytical tools like R or Python. However, improvements are constantly being made.

How Much Memory Does A Data Scientist Need?

Recently, I discovered an interesting blog post, Big RAM is eating big data – Size of datasets used for analytics, by Szilard Pafka. He says that “Big RAM is eating big data”. This phrase means that memory size grows much faster than the data sets that a typical data scientist processes. So, data scientists do not need as much memory as the industry offers them. Would you agree?

I do not agree. This result does not match my intuition. During my research I found an infrastructure bias in the data from this blogpost. I’ll show that the growth of the datasets is approximately the same as the memory growth in Amazon AWS rented machines and the Apple MacBook Pro laptops during the last 10 years.

1. The blog post results

According to the “Big RAM is eating big data” blog post, the amount of memory in Amazon AWS machines grows faster (50% per year) than the median dataset size (20% per year) that people use for analytics. This result is based on the KDnuggets survey about data sizes, “Poll Results: Where is Big Data? For most, Largest Dataset Analyzed is in laptop-size GB range”. You might find the most recent survey dataset here on GitHub.

Let’s take a look at the data and results more closely. The cumulative distribution of dataset sizes for a few selected years is below.

I did not find the code for the post, so I reproduced this research in R.

Below is my R code to create this graph from the dataset file.


library(dplyr)
library(ggplot2)

file <- "dataset-sizes.cv"
data <- read.csv(file, sep="\t")
data.slice <- data %>%
        filter(year == 2006 | year == 2009 | year == 2012 | year == 2015)
data.slice.cum_freq <- data.slice %>%
        group_by(year, sizeGB) %>%
        summarise(value = sum(freq)) %>%
        mutate(user_prop = value/sum(value), cum_freq = cumsum(value)/sum(value)) 

ggplot(data.slice.cum_freq, aes(x=log10(sizeGB), y=cum_freq, color=factor(year))) + 
        geom_line(aes(group = factor(year)))

He mentioned that the cumulative distribution function looks linear in the 0.1-0.9 range (10 megabytes to 10 terabytes). By fitting a linear model to this range, you can estimate the difference between these years.

My R code:

data.slice.reg <- data.slice.cum_freq %>%
        filter(log10(sizeGB) >= -2) %>%
        filter(log10(sizeGB) <= 4)

ggplot(data.slice.reg, aes(x=log10(sizeGB), y=cum_freq, color=factor(year))) + 
        geom_line(aes(group = factor(year)))

model <- lm(log10(sizeGB) ~ cum_freq + year, data=data.slice.reg, na.action=na.exclude)
summary(model)

From the model summary, you can find the coefficient of the year variable, which is 0.08821 in my code (0.075 in the blog post). This coefficient is on the log10(sizeGB) scale. After converting from log10(GB) back to GB, we get 10^0.088 = 1.22, which gives us 22%, or roughly 20%, annual growth in datasets.
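Converting a per-year slope fitted on log10(size) into an annual growth rate is a one-liner. The sketch below (hypothetical helper name) applies it to both coefficients mentioned above.

```python
def annual_growth_from_log10_slope(slope):
    """Annual growth rate implied by a per-year slope of log10(size)."""
    return 10 ** slope - 1

for slope in (0.08821, 0.075):
    print(f"{slope}: {annual_growth_from_log10_slope(slope):.2%}")
# 0.08821: 22.52%
# 0.075: 18.85%
```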

This 20% growth is what he compares to the AWS maximum memory instance size for the same year ranges:

year    type           RAM (GB)
2007    m1.xlarge      15
2009    m2.4xlarge     68
2012    hs1.8xlarge    117
2014    r3.8xlarge     244
2016*   x1             ~2,000 (2 TB)

A change from 15 GB in 2007 to 244 GB in 2014 gives us approximately 50% annual AWS memory growth, which is much higher than the dataset growth and, according to the blog post, shows that data scientists do not need that much memory.

2. An intuition about memory size

So, we got the same result as in the blog post. However, I can’t say that I agree with this study result. My intuition tells me that more memory gives me more luxury in data processing and analytics. The ability to work with a large amount of data could simplify the analytics process. Due to the memory constraints, I feel this squeeze constantly.

Another aspect of the memory issue is the data preparation step. Today you need two sets of skills: preparing “big data” (usually on-disk processing using Unix grep, awk, Python, Apache Spark in standalone mode, etc.) and in-memory analytics (R, Python SciPy). These two things are very different. Good data scientists have both skills. However, if you have a large amount of memory, you don’t need the first skill because you can prepare data in R or Python directly. This is especially important for text analytics, where the amount of input data is huge by default. So, data processing becomes simpler with a large amount of memory in your machine.

I can’t imagine saying “Okay, I don’t need any more memory or CPU cores”. Additionally, I can add “…and please stop parallelizing my nice sequential code!”.

3. AWS memory growth

It looks like the maximum amount of memory in a rented AWS instance is not the best proxy for estimating the amount of memory that data scientists use. There are three reasons for that:

  1. High performance computing (HPC) machines are a relatively new product, introduced around 2010, and the AWS HPC product creates a strong bias in the correlation between analytics memory and AWS memory. The research jumps from regular machines between 2006 and 2010 to HPC ones from 2010 to 2015, which produces the 50% growth. However, in my humble opinion, the real improvement is smaller (perhaps closer to the 20% of the median data size).
  2. The price of AWS HPC machines is much higher than many companies can afford ($2-3K/month). A couple of months of using this kind of machine costs more than a brand new shiny MacBook Pro with 16 GB of RAM and a 1 TB SSD.
  3. It is not always easy or efficient to use remote AWS machines. Not a big deal. However, I believe that many data scientists would prefer to use their local machines, especially Apple fans :).

In my mind, HPC machines create a bias in this research and we should estimate memory usage only by regular AWS machines not including HPC and memory optimized machines. Here is the AWS history for regular machines:

year    type           RAM (GB)
2006    m1.small       1.75
2007    m1.xlarge      15
2009    m2.4xlarge     68
2015    m4.10xlarge    160

From this table I’d exclude 2006 and m1.small because it was a limited beta, and m1.small is roughly the m1.xlarge machine “sliced” into 8 parts. The blogger did the same: he started from 2007.

Side note: as luck would have it, my AWS experience started in that same year, 2007. At the time, it was an amazing experience to rent a machine in one minute, as opposed to days or even weeks with hosting companies. During this time frame, I mostly worked with regular AWS machines. HPC machines were specialized and overpriced for my purposes.

So, if we start the AWS regular machine history from 2007 with m1.xlarge, the AWS memory growth is about 35% annually: 15 GB × 1.35^8 years ≈ 160 GB.

This result is much closer to the growth of the datasets: the difference is 20% vs. 35%. Consequently, it cannot be taken as strong evidence that RAM is unimportant.

Let’s have more fun…

4. Apple MacBook Pro memory growth

I think that many people analyse data in their local machines and laptops. I think that most people are not ready to switch from their shiny laptops with a cozy local environment to a remote AWS machine for analytics. At least it is not easy for me and I’ll find a way to process a relatively large amount of data in my laptop (I suppose that a cluster is not needed).

Let’s try to use Apple MacBook Pro as a proxy for estimating memory growth. In the table (data is based on wikipedia) below is the MacBook Pro memory history:

year    type                                   RAM (GB)
2006    1st generation                         1
2007    1st generation (late 2006 release)     2
2008    2nd generation                         4
2012    2nd generation (mid 2012 release)      8
2015    3rd generation (Retina)                16

Surprisingly, the MacBook Pro data gives us the same result as the AWS regular machines, about 35% growth: 1 GB × 1.36^9 ≈ 16 GB. It appears that we removed (or at least dramatically reduced) the infrastructure bias.
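All the growth rates in this post follow the compound annual growth formula (end / start)^(1 / years) - 1. A short sketch (hypothetical helper name) recomputes them from the three tables above.

```python
def cagr(start, end, years):
    """Compound annual growth rate between two values that are `years` apart."""
    return (end / start) ** (1 / years) - 1

series = [
    ("AWS largest instance, 2007-2014 (15 -> 244 GB)", 15, 244, 7),
    ("AWS regular instance, 2007-2015 (15 -> 160 GB)", 15, 160, 8),
    ("MacBook Pro, 2006-2015 (1 -> 16 GB)", 1, 16, 9),
]
for name, start, end, years in series:
    print(f"{name}: {cagr(start, end, years):.0%}")
# roughly 49%, 34% and 36%
```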


This blog post shows that the maximum memory in MacBook Pro laptops and regular AWS machines are unbiased proxies for estimating the amount of memory that people and data scientists use.

Memory is huge. It gives us the ability to analyze data more efficiently. We are limited only by the growth of analytical methods and memory size. Given the opportunity, we can consume all the “affordable” memory and then some as data scientists are memory hogs, in my humble opinion (and biased as well 🙂 ).


An update: Szilard Pafka pointed me to his code on GitHub.

Where to find terabyte-size dataset for machine learning

In the previous blog posts, we played with a large multi-gigabyte dataset. This 34 GB dataset is based on stackoverflow.com data. A couple of days ago I found another great large dataset: a two-terabyte snapshot of the Reddit website. This dataset is perfect for text mining and NLP experimentation.

1. Two terabytes data set

The full dataset contains two terabytes of data in JSON format. Thanks to Stuck_In_the_Matrix, who created this dataset! The compressed version is 250 GB. You can find this dataset here on Reddit. You need a torrent client to download the compressed data.

Additionally, you can find a 32 GB subset of this data on the Kaggle website in SQLite format here. You can also play with the data online through R or Python in the Kaggle competition.

2. Easy to use 16 gigabytes subset

To simplify working with this data, I created a subset of it in plain text TSV format (tab-separated values) here in my Dropbox folder (updated; the old archive, compatible only with Mac OS, is here). The file contains a copy of the Kaggle subset. The file size is 16 GB uncompressed (yes, it is 2 times smaller than the Kaggle file because of the plain text format without indexes) and 6.6 GB in the archive.

SQLite code for converting the Kaggle file to a plain text:

 sqlite> .open database.sqlite
 sqlite> .headers off
 sqlite> .mode tabs
 sqlite> .out reddit-May2015.tsv
 sqlite> SELECT created_utc,ups,subreddit_id,link_id,name,score_hidden,replace(replace(author_flair_css_class, X'09', ' '), X'0A', ' ') AS author_flair_css_class,replace(replace(author_flair_text, X'09', ' '), X'0A', ' ') AS author_flair_text,subreddit,id,removal_reason,gilded,downs,archived,author,score,retrieved_on, replace(replace(body,X'09',' '), X'0A', ' ') AS body, distinguished,edited,controversiality,parent_id FROM May2015;
 sqlite> .exit

Note that I replace all tabs (X'09') and newlines (X'0A') with spaces in all text columns. Please let me know if you know how to combine the two character replacements into one operation.
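One way to do both replacements in a single operation is to move the cleanup out of SQL. The sketch below uses Python’s built-in sqlite3 module and one `str.translate` table that maps both tab and newline to a space. The helper name is hypothetical, the column list in the usage comment is shortened, and NULL values are written as empty fields (an assumption).

```python
import sqlite3

# One translation table that maps both tab and newline to a space
CLEAN = str.maketrans({"\t": " ", "\n": " "})

def tsv_lines(con, query):
    """Yield one TSV line per result row, replacing tabs/newlines inside values."""
    for row in con.execute(query):
        cells = ["" if value is None else str(value).translate(CLEAN) for value in row]
        yield "\t".join(cells) + "\n"

# Hypothetical usage with the Kaggle file (column list shortened for brevity):
# con = sqlite3.connect("database.sqlite")
# with open("reddit-May2015.tsv", "w", encoding="utf-8") as out:
#     out.writelines(tsv_lines(con, "SELECT created_utc, subreddit, body FROM May2015"))
```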

3. Read data in Spark

import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

val fileName = "reddit-May2015.tsv"
val textFile = sc.textFile(fileName)
val rdd = textFile.map(_.split("\t")).filter( _.length == 22 ).map { p =>
            Row(p(0), p(1), p(2), p(3), p(4), p(5), p(6), p(7), p(8), p(9),
                p(10), p(11), p(12), p(13), p(14), p(15), p(16), p(17), p(18), p(19),
                p(20), p(21))
}
val schemaString = "created_utc,ups,subreddit_id,link_id,name,score_hidden,author_flair_css_class,author_flair_text,subreddit,id,removal_reason,gilded,downs,archived,author,score,retrieved_on,body,distinguished,edited,controversiality,parent_id"
val schema = StructType(
      schemaString.split(",").map(fieldName => StructField(fieldName, StringType, true)))
val df = sqlContext.createDataFrame(rdd, schema)


Today it is not easy to find great and interesting datasets for testing, training and research. So, let’s collect some interesting datasets. Please share your newly found datasets with the community.


I looked into the licensing of this dataset. The dataset publisher Stuck_In_the_Matrix just published the dataset and provided description and links to the torrent directly in the Reddit website. Please note that Reddit sponsors the Kaggle competition with this dataset. It appears that we may play with the dataset for non-business related purposes.


Beginners Guide: Apache Spark Python – Machine Learning Scenario With A Large Input Dataset

In the previous post, “Beginners Guide: Apache Spark Machine Learning Scenario With A Large Input Dataset”, we discussed the process of creating a predictive model with 34 gigabytes of input data using Apache Spark. I received a request for a Python solution instead of Scala. This is exactly what I will do in this post.

1. Python and Scala difference

The Python solution looks similar to the previous Scala solution because, under the hood, you have the same Spark library and engine. Because of this, I don’t anticipate any significant performance change. As there aren’t many differences between the Python and Scala versions, I will highlight only the major ones, and you can refer back to the last post for the code in its entirety.

2. Sources

The complete source code of this program can be found here. The Scala version from the previous post is here. A small 128MB testing dataset is here.

The entire 34GB dataset is available at https://archive.org/details/stackexchange; look for the file Posts.xml in the stackoverflow.com folder. A copy of the 34GB Posts.xml file is here (8GB compressed). This data is licensed under the Creative Commons license (cc-by-sa).

3. Python code

In the Python version of the code (source file), I create the Label column directly, without the intermediate sqlfunc/myudf function. Otherwise, you would have to ship the code of this function to the Spark environment through an intermediate Python file (the sc.addPyFile() method). For the same reason, I do not use XML libraries.

import re
import pyspark.sql

# Extract Id, Label and Text from each <row .../> line; Text keeps the title even when Body is missing
postsRDD = postsXml.map(lambda s: pyspark.sql.Row(
        Id = re.search('Id=".*?"', s).group(0)[4:-1].encode('utf-8'),
        Label = 1.0 if re.search('Tags=".*?"', s) is not None
             and re.search('Tags=".*?"', s).group(0)[6:-1].encode('utf-8').find(targetTag) >= 0 else 0.0,
        Text = (re.search('Title=".*?"', s).group(0)[7:-1] if re.search('Title=".*?"', s) is not None else "")
             + " "
             + (re.search('Body=".*?"', s).group(0)[6:-1] if re.search('Body=".*?"', s) is not None else "")))

postsLabeled = sqlContext.createDataFrame(postsRDD)

One issue with the Python version of the code is that we don’t decode XML escape sequences such as &lt;. Let’s keep these symbols for now.
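If you do want decoded text in the Python version, a minimal sketch is to pass each extracted value through Python 3’s standard html.unescape, which decodes named entities such as &lt; and numeric references such as &#xA; (the newline escape used in Posts.xml). This is plain Python rather than Spark code, and decode_value is a hypothetical helper name, not part of the original source file.

```python
import html

def decode_value(raw):
    """Decode XML/HTML escapes (e.g. &lt; &gt; &amp; &#xA;) in an extracted attribute value."""
    return html.unescape(raw)

print(decode_value("&lt;p&gt;When I try to build it&lt;/p&gt;&#xA;"))
```

Decoding happens in a single pass, so a double-escaped value such as &amp;lt; becomes &lt;, not <.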

The Python code needs a couple more temporary variables in the data preparation step (negTrainTmp1 and posTrainTmp1).

positiveTrain = positive.sample(False, 0.9)
negativeTrain = negative.sample(False, 0.9)
training = positiveTrain.unionAll(negativeTrain)

negTrainTmp1 = negativeTrain.withColumnRenamed("Label", "Flag")
negativeTrainTmp = negTrainTmp1.select(negTrainTmp1.Id, negTrainTmp1.Flag)

negativeTest = negative.join( negativeTrainTmp, negative.Id == negativeTrainTmp.Id, "LeftOuter").\
                        filter("Flag is null").\
                        select(negative.Id, negative.Text, negative.Label)

posTrainTmp1 = positiveTrain.withColumnRenamed("Label", "Flag")
positiveTrainTmp = posTrainTmp1.select(posTrainTmp1.Id, posTrainTmp1.Flag)

positiveTest = positive.join( positiveTrainTmp, positive.Id == positiveTrainTmp.Id, "LeftOuter").\
                        filter("Flag is null").\
                        select(positive.Id, positive.Text, positive.Label)
testing = negativeTest.unionAll(positiveTest)

There are also small changes in the model validation step:

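testTitle = "Easiest way to merge a release into one JAR file"
testBody = "Is there a tool or script which easily merges a bunch of JAR files into one JAR file?"
testText = testTitle + " " + testBody
testDF = sqlContext.createDataFrame([ ("0", testText, 1.0)], ["Id", "Text", "Label"])
result = model.transform(testDF)
prediction = result.collect()[0][7]  # column 7 is the prediction column
print("Prediction: ", prediction)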

That’s all the changes that we need.


Thank you for all the great feedback on the previous post “Beginners Guide: Apache Spark Machine Learning Scenario With A Large Input Dataset”. The reception helped me see where the needs and demands are in this field. I welcome all suggestions, so keep the feedback coming and I’ll try to address as many as I humanly can.

Beginners Guide: Apache Spark Machine Learning Scenario With A Large Input Dataset

What if you want to create a machine learning model but realize that your input dataset doesn’t fit in your computer’s memory? Usually you would use distributed computing tools like Hadoop and Apache Spark for that computation, in a cluster with many machines. However, Apache Spark is able to process your data on a local machine in standalone mode and even build models when the input data set is larger than the amount of memory your computer has. In this blog post, I’ll show you an end-to-end scenario with Apache Spark where we create a binary classification model using a 34.6 gigabyte input dataset. Run this scenario on your laptop (yes, yours, with its 4-8 gigabytes of memory and 50+ gigabytes of disk space) to test this.

Choose dataset

1. Input data and expected results

In the previous post we discussed “How To Find Simple And Interesting Multi-Gigabytes Data Set”. The Posts.xml file from this dataset will be used in the current post. The file size is 34.6 gigabytes. This xml file contains the stackoverflow.com posts data as xml attributes:

  1. Title – post title
  2. Body – post text
  3. Tags – list of tags for post
  4. 10+ more xml-attributes that we won’t use.

The full dataset with the stackoverflow.com Posts.xml file is available at https://archive.org/details/stackexchange. Additionally, I created a smaller version of this file with only 10 items/posts in it. This file contains a small subset of the original dataset. This data is licensed under the Creative Commons license (cc-by-sa).

As you might expect, this small file is not the best choice for model training. This file is only good for experimenting with your data preparation code. However, the end-to-end Spark scenario from this article works with this small file as well. Please download the file from here.

Our goal is to create a predictive model which predicts post Tags based on Body and Title. To simplify the task and reduce the amount of code, we are going to concatenate Title and Body and use that as a single text column.

It is easy to imagine how this model could work on the stackoverflow.com web site: the user types a question and the web site automatically suggests tags.

Assume that we need as many correct tags as possible and that the user would remove the unnecessary tags. Because of this assumption we are choosing recall as a high priority target for our model.
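To make the recall-first choice concrete, here is a small plain-Python sketch (not Spark code) of how recall and precision are computed for a single tag. The function name and the toy labels are hypothetical; they only show that a model which suggests a tag generously can reach full recall at the cost of precision.

```python
def recall_precision(actual, predicted):
    """Compute (recall, precision) for binary 0/1 labels."""
    tp = sum(1 for a, p in zip(actual, predicted) if a == 1 and p == 1)
    fn = sum(1 for a, p in zip(actual, predicted) if a == 1 and p == 0)
    fp = sum(1 for a, p in zip(actual, predicted) if a == 0 and p == 1)
    recall = tp / (tp + fn) if tp + fn else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    return recall, precision

# Every true "java" post is suggested (high recall), plus two extra suggestions.
actual = [1, 1, 1, 0, 0, 0]
predicted = [1, 1, 1, 1, 1, 0]
print(recall_precision(actual, predicted))  # (1.0, 0.6)
```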

2. Binary and multi-label classification

The problem of stackoverflow tag prediction is a multi-label classification problem because the model should predict many classes, which are not mutually exclusive. The same text might be classified as both “Java” and “Multithreading”. Note that multi-label classification is a generalization of a different problem, multi-class classification, which predicts only one class from a set of classes.

To keep our first Apache Spark problem simple and reduce the amount of code, let’s simplify the task. Instead of training a multi-label classifier, let’s train a simple binary classifier for a given tag. For instance, for the tag “Java” we create one classifier that predicts whether a post is about the Java language.

By using this simple approach, we can create classifiers for almost all frequent labels (Java, C++, Python, multithreading, etc.). This approach is simple and good for learning. However, it is not perfect in practice: by splitting the predictive model into separate classifiers, you ignore the correlations between classes. Another drawback is that training many classifiers might be computationally expensive.
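A minimal sketch of this “one binary classifier per tag” transformation (often called binary relevance), assuming each post’s tags are already parsed into a Python set. The posts and tag list are made up for illustration; in practice each label vector would feed a separate classifier.

```python
def binary_labels(posts_tags, target_tag):
    """Turn multi-label tag sets into 0.0/1.0 labels for one binary classifier."""
    return [1.0 if target_tag in tags else 0.0 for tags in posts_tags]

posts_tags = [{"java", "multithreading"}, {"python"}, {"java"}, {"c++", "multithreading"}]
frequent_tags = ["java", "python", "multithreading"]

# One label vector (and, in practice, one trained classifier) per frequent tag.
per_tag_labels = {tag: binary_labels(posts_tags, tag) for tag in frequent_tags}
for tag, labels in per_tag_labels.items():
    print(tag, labels)
```

Note how the first post appears as a positive example in two independent problems; the separate classifiers never see that “java” and “multithreading” co-occur.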

3. Setup and Run Apache Spark in a standalone mode

If you don’t have Apache Spark on your machine, you can download it from the Spark web page http://spark.apache.org/. Please use version 1.5.1. Direct link to a pre-built version: http://d3kbcqa49mib13.cloudfront.net/spark-1.5.1-bin-hadoop2.6.tgz

You are ready to run Spark in standalone mode if Java is installed on your computer. If not, install Java.

For Unix systems and Macs, uncompress the file and copy it to any directory. This is now your Spark directory.

Run spark master:


Run spark slave:


Run Spark shell:


The Spark shell runs your Scala commands in interactive mode.

Windows users can find the instruction here: http://nishutayaltech.blogspot.in/2015/04/how-to-run-apache-spark-on-windows7-in.html

If you are working in cluster mode in a Hadoop environment, I’m assuming you already know how to run the Spark shell.

4. Importing libraries

For this end-to-end scenario we are going to use Scala, the primary language for Apache Spark.

// General purpose library
import scala.xml._

// Spark data manipulation libraries
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

// Spark machine learning libraries
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.Pipeline

5. Parsing XML

We need to extract Id, Title, Body and Tags from the input XML file and create a single data frame with Id, Tags and Text columns. First, let’s remove the XML header and footer. I assume that the input file is located in the same directory where you run the spark shell command.

val fileName = "Posts.small.xml"
val textFile = sc.textFile(fileName)
val postsXml = textFile.map(_.trim).
                    filter(!_.startsWith("<?xml version=")).
                    filter(_ != "<posts>").
                    filter(_ != "</posts>")

Spark has good functions for parsing JSON and CSV formats. For XML, we need to write several additional lines of code to create a data frame by specifying the schema programmatically.

Note that Scala’s XML library automatically converts all escaped XML codes like “&lt;a&gt;” into actual tags “<a>”. We are also going to concatenate the title and body, and remove all tags and newline characters from the body as well as all duplicated spaces.

val postsRDD = postsXml.map { s =>
            val xml = XML.loadString(s)

            val id = (xml \ "@Id").text
            val tags = (xml \ "@Tags").text

            val title = (xml \ "@Title").text
            val body = (xml \ "@Body").text
            // Replace every tag with a space; [^>]+ stops at the first '>', so <code>trans</code> keeps "trans"
            val bodyPlain = ("<[^>]+>".r).replaceAllIn(body, " ")
            val text = (title + " " + bodyPlain).replaceAll("\n", " ").replaceAll("( )+", " ")

            Row(id, tags, text)
        }
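To check the parsing logic outside Spark, here is a minimal plain-Python sketch of the same per-row transformation, assuming each line is one self-contained <row .../> element as in Posts.xml. Like Scala’s XML.loadString, the standard xml.etree.ElementTree parser decodes escapes such as &lt;p&gt; automatically. Tags are stripped with <[^>]+>, which stops at the first “>”; a greedy <\S+> would also delete the word trans from <code>trans</code>. parse_post is a hypothetical name.

```python
import re
import xml.etree.ElementTree as ET

def parse_post(line):
    """Turn one <row .../> line of Posts.xml into an (id, tags, text) tuple."""
    row = ET.fromstring(line)  # the parser decodes &lt; &gt; &#xA; inside attributes
    post_id = row.get("Id", "")
    tags = row.get("Tags", "")
    title = row.get("Title", "")
    body = row.get("Body", "")
    body_plain = re.sub(r"<[^>]+>", " ", body)  # replace every HTML tag with a space
    # Concatenate title and body, then collapse newlines and repeated spaces.
    text = " ".join((title + " " + body_plain).split())
    return post_id, tags, text

sample = ('<row Id="4" Tags="&lt;c#&gt;&lt;winforms&gt;" Title="Decimal or double?" '
          'Body="&lt;p&gt;Use &lt;code&gt;trans&lt;/code&gt; here.&lt;/p&gt;&#xA;&lt;p&gt;Thanks&lt;/p&gt;" />')
print(parse_post(sample))
# ('4', '<c#><winforms>', 'Decimal or double? Use trans here. Thanks')
```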

To create a data frame, a schema should be applied to the RDD.

val schemaString = "Id Tags Text"
val schema = StructType(
      schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))

val postsDf = sqlContext.createDataFrame(postsRDD, schema)

Now you can take a look at your data frame.


6. Preparing training and testing datasets

The next step is creating binary labels for a binary classifier. For these code examples, we are using “java” as the label that we would like to predict with a binary classifier. All rows with the “java” label should be marked as “1” and rows without “java” as “0”. Let’s identify our target tag “java” and create binary labels based on this tag.

val targetTag = "java"
val myudf: (String => Double) = (str: String) => {if (str.contains(targetTag)) 1.0 else 0.0}
val sqlfunc = udf(myudf)
val postsLabeled = postsDf.withColumn("Label", sqlfunc(col("Tags")) )
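One caveat: str.contains(targetTag) is a substring test. Posts.xml stores tags as Tags="&lt;java&gt;&lt;multithreading&gt;", which decodes to <java><multithreading>, so a post tagged <javascript> also contains “java” and is labeled 1.0. A minimal plain-Python sketch compares the substring test with an exact tag match; the exact-match regex assumes tags always look like <tag1><tag2>, and both function names are hypothetical.

```python
import re

def label_substring(tags, target_tag):
    """Mirrors str.contains(targetTag): any tag that merely contains the text matches."""
    return 1.0 if target_tag in tags else 0.0

def label_exact(tags, target_tag):
    """Matches only the exact tag, assuming Tags looks like "<java><multithreading>"."""
    return 1.0 if target_tag in re.findall(r"<([^>]+)>", tags) else 0.0

for tags in ["<java><multithreading>", "<javascript><jquery>", "<python>"]:
    print(tags, label_substring(tags, "java"), label_exact(tags, "java"))
```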

The dataset can be split into negative and positive subsets by using the new label.

val positive = postsLabeled.filter('Label > 0.0)
val negative = postsLabeled.filter('Label < 1.0)

We are going to use 90% of our data for the model training and 10% as a testing dataset. Let’s create a training dataset by sampling the positive and negative datasets separately.

val positiveTrain = positive.sample(false, 0.9)
val negativeTrain = negative.sample(false, 0.9)
val training = positiveTrain.unionAll(negativeTrain)
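DataFrame.sample(false, 0.9) keeps each row independently with probability 0.9 (Bernoulli sampling), so each training subset is approximately, not exactly, 90% of its class. A plain-Python sketch of the same per-class sampling; the class sizes and seeds are made up for illustration.

```python
import random

def bernoulli_sample(rows, fraction, seed):
    """Keep each row independently with probability `fraction` (sampling without replacement)."""
    rng = random.Random(seed)
    return [r for r in rows if rng.random() < fraction]

positive = [("p%d" % i, 1.0) for i in range(1000)]
negative = [("n%d" % i, 0.0) for i in range(9000)]

# Sample each class separately so both keep roughly the same 90% share.
positive_train = bernoulli_sample(positive, 0.9, seed=1)
negative_train = bernoulli_sample(negative, 0.9, seed=2)
training = positive_train + negative_train

print(len(positive_train) / len(positive), len(negative_train) / len(negative))
```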

The testing dataset should include all rows that are not included in the training datasets. And again, positive and negative examples are handled separately.

val negativeTrainTmp = negativeTrain.withColumnRenamed("Label", "Flag").select('Id, 'Flag)
val negativeTest = negative.join( negativeTrainTmp, negative("Id") === negativeTrainTmp("Id"), "LeftOuter").
                            filter("Flag is null").select(negative("Id"), 'Tags, 'Text, 'Label)
val positiveTrainTmp = positiveTrain.withColumnRenamed("Label", "Flag").select('Id, 'Flag)
val positiveTest = positive.join( positiveTrainTmp, positive("Id") === positiveTrainTmp("Id"), "LeftOuter").
                            filter("Flag is null").select(positive("Id"), 'Tags, 'Text, 'Label)
val testing = negativeTest.unionAll(positiveTest)
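The left outer join followed by filter("Flag is null") acts as an anti-join: it keeps the rows whose Id has no match in the training sample. A minimal plain-Python sketch with hypothetical dict rows, assuming Id is unique (as it is in Posts.xml). The check must be “is None” rather than a falsy test, because a legitimate Flag of 0.0 is falsy in Python.

```python
def left_outer_join_flag(all_rows, train_rows):
    """Left outer join on Id: attach the training row's Label as Flag, or None when absent."""
    flag_by_id = {r["Id"]: r["Label"] for r in train_rows}
    return [dict(r, Flag=flag_by_id.get(r["Id"])) for r in all_rows]

def rows_not_in_training(all_rows, train_rows):
    """Keep only rows whose Flag is None (SQL: Flag is null), then drop the Flag column."""
    joined = left_outer_join_flag(all_rows, train_rows)
    return [{k: v for k, v in r.items() if k != "Flag"} for r in joined if r["Flag"] is None]

negative = [{"Id": str(i), "Label": 0.0} for i in range(5)]
negative_train = [negative[0], negative[2], negative[3]]
negative_test = rows_not_in_training(negative, negative_train)
print([r["Id"] for r in negative_test])  # ['1', '4']
```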

7. Training a model

Let’s identify the training parameters:

  1. Number of features
  2. Regularization parameter
  3. Maximum number of iterations (epochs) of the optimizer

The Spark API creates a model from the data frame columns and the training parameters:

val numFeatures = 64000
val numEpochs = 30
val regParam = 0.02

val tokenizer = new Tokenizer().setInputCol("Text").setOutputCol("Words")
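// Hash the Words column into a Features vector with numFeatures buckets
val hashingTF = new org.apache.spark.ml.feature.HashingTF().setNumFeatures(numFeatures).
                    setInputCol(tokenizer.getOutputCol).setOutputCol("Features")
// Logistic regression reads Features and Label, and writes Score and Prediction columns
val lr = new LogisticRegression().setMaxIter(numEpochs).setRegParam(regParam).
                    setFeaturesCol("Features").setLabelCol("Label").
                    setRawPredictionCol("Score").setPredictionCol("Prediction")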
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

val model = pipeline.fit(training)
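HashingTF relies on the hashing trick: each word is mapped to one of numFeatures buckets by a hash function, and the vector stores term counts per bucket, so no vocabulary has to be built over 34 GB of text. Here is a plain-Python sketch of the idea. It uses zlib.crc32 as a stand-in hash (my assumption, so the bucket indices will not match Spark’s), and the tokenizer lowercases and splits on whitespace, similar in spirit to Spark’s Tokenizer.

```python
import zlib
from collections import Counter

def tokenize(text):
    """Lowercase and split on whitespace."""
    return text.lower().split()

def hashing_tf(words, num_features):
    """Sparse term-frequency vector {bucket index: count} built with the hashing trick."""
    return dict(Counter(zlib.crc32(w.encode("utf-8")) % num_features for w in words))

vector = hashing_tf(tokenize("Merge JAR files into one JAR file"), num_features=64000)
print(sum(vector.values()))  # 7 tokens in total
```

Two different words can land in the same bucket (a collision); a large numFeatures such as 64000 keeps that rare for typical post lengths.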

8. Testing a model

This is our final code for the binary “Java” classifier which returns a prediction (0.0 or 1.0):

val testTitle = "Easiest way to merge a release into one JAR file"
val testBody = """Is there a tool or script which easily merges a bunch of 
                   JAR files into one JAR file? A bonus would be to easily set the main-file manifest 
                   and make it executable. I would like to run it with something like:
                  As far as I can tell, it has no dependencies which indicates that it shouldn't be an easy 
                  single-file tool, but the downloaded ZIP file contains a lot of libraries."""
val testText = testTitle + " " + testBody
val testDF = sqlContext.createDataFrame(Seq( (99.0, testText))).toDF("Label", "Text")
val result = model.transform(testDF)
val prediction = result.collect()(0)(6).asInstanceOf[Double]
print("Prediction: "+ prediction)

Let’s evaluate the quality of the model on the testing dataset.

val testingResult = model.transform(testing)
val testingResultScores = testingResult.select("Prediction", "Label").rdd.
                                    map(r => (r(0).asInstanceOf[Double], r(1).asInstanceOf[Double]))
val bc = new BinaryClassificationMetrics(testingResultScores)
val roc = bc.areaUnderROC
print("Area under the ROC:" + roc)
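BinaryClassificationMetrics builds the ROC curve from whatever scores it receives. The code above passes the hard Prediction column (0.0 or 1.0), which gives a single-threshold curve whose area equals (TPR + TNR) / 2. A plain-Python sketch of AUC (Mann-Whitney formulation, ties counted as one half) shows the difference on toy numbers; the probabilities are hypothetical. Passing the positive-class probability instead of Prediction would likely give Spark a fuller curve.

```python
def area_under_roc(scores, labels):
    """AUC as the probability that a random positive outscores a random negative (ties count 1/2)."""
    pos = [s for s, y in zip(scores, labels) if y == 1.0]
    neg = [s for s, y in zip(scores, labels) if y == 0.0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

labels = [1.0, 1.0, 0.0, 0.0, 0.0, 1.0]
predictions = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0]  # hard 0/1 output, like the Prediction column
scores = [0.9, 0.4, 0.2, 0.6, 0.1, 0.7]       # hypothetical positive-class probabilities

print(area_under_roc(predictions, labels))  # about 0.667 = (TPR + TNR) / 2 = (2/3 + 2/3) / 2
print(area_under_roc(scores, labels))       # about 0.889 = 8/9
```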

If you use the small dataset, the quality of your model will probably not be the best. The area under the ROC value will be very low (close to 50%), which indicates a poor-quality model. With the entire Posts.xml dataset, the quality is not so bad: the area under the ROC is 0.64. You can probably improve this result by playing with different transformations such as TF-IDF and normalization. Not in this blog post.


Apache Spark could be a great option for data processing and for machine learning scenarios if your dataset is larger than your computer memory can hold. It might not be easy to use Spark in a cluster mode within the Hadoop Yarn environment. However, in a local (or standalone) mode, Spark is as simple as any other analytical tool.

Please let me know if you encountered any problems or have further questions. I would really like to hear your feedback.

The complete source code of this program can be found here.

How To Find Simple And Interesting Multi-Gigabytes Data Set

Many folks are very excited about big data. They like to play with, explore, work on and study this frontier. Most likely these folks either work with or would like to play with large amounts of data (hundreds of gigabytes or even terabytes). But here’s the thing: it’s not easy to find a multi-gigabyte dataset. Usually, these kinds of datasets are needed for experimenting with new data processing frameworks such as Apache Spark or data streaming tools like Apache Kafka. In this blog post I will describe and provide a link to a simple and powerful multi-gigabyte stackoverflow data set.

1. Datasets for machine learning

Lots of sources exist for machine learning problems. Kaggle is the best source for these problems and they offer lots of datasets presented with examples of code. Most of these data sets are clean and ready to use in your machine learning experiments.

In a real data scientist’s life, you most likely do not have the luxury of clean data, and the size of the input data creates an additional big problem. University courses, as well as online courses, offer a limited view of data science and machine learning because they teach students to apply statistical and machine learning methods to a small amount of clean data. In reality, a data scientist spends most of their time getting data and cleaning it up. According to Hal Varian (Google’s chief economist), “the sexy job in the next ten years will be statisticians” (and I assume data scientists). However, they perform “clean up” work most of the time.

In order to experiment with new data processing or data streaming tools, you need a large (larger than your computer can hold in memory) and uncleaned dataset.

Large and uncleaned datasets allow you to practice real data processing and learn analytical skills. It turns out that such datasets are not easy to find.

2. Datasets for processing

Kdnuggets and Quora have pretty good lists of open repositories:

  1. http://www.kdnuggets.com/datasets/index.html
  2. https://www.quora.com/What-kinds-of-large-datasets-open-to-the-public-do-you-analyze-the-mostly

Most of the datasets on these lists are very small, and for the most part you need knowledge of a specific business domain, such as physics or healthcare. However, for learning and experimentation purposes, it would be nice to have a dataset from a well-known business domain that everybody is familiar with.

Social network data is the best because people understand these datasets and have intuition about the data, which is important in the analytic process. You might use a social network API to extract your own data sets. Unfortunately, such a data set is not the best for sharing your analytical results with other people. It would be great to find a common social network dataset with an open license. And I’ve found one!

3. Stackoverflow open dataset

The Stackoverflow data set is the only open social dataset that I was able to find. Stackoverflow.com is a question and answer web site about programming. It is especially useful when you have to write code in a language you are not familiar with. This well-known approach is called stackoverflow-driven development, or SDD. I believe everybody in the high-tech industry is familiar with stackoverflow, and many of them have an account on this web site.

Stack Exchange (the company that owns stackoverflow.com) publishes the Stack Exchange dataset under an open Creative Commons license. You can find the freshest dataset on this page:

https://archive.org/details/stackexchange

The dataset contains all stackexchange data including stackoverflow and the overall size of the archive is 27 gigabytes. The size of the uncompressed data is more than 1 terabyte.

4. How to download and extract the dataset?

However, this dataset is not easy to get. First, you need to download the archive of the entire dataset. Please note that the download speed is very slow. They recommend using a BitTorrent client to download the archive, but it often has some issues. Without BitTorrent, I made 3 attempts and spent 2 days downloading this archive. Next, you need to unzip the large archive. Finally, you need to unzip the subset of data that you need (like stackoverflow-Posts or travel.stackexchange) using the 7z compressor. If you don’t have the 7z compressor, you need to find and install it on your machine.

After you download the archive from https://archive.org/details/stackexchange, extract all stackoverflow-related archives and uncompress each of them (all archives whose names start with stackoverflow.com):

  • stackoverflow.com-Posts.7z
  • stackoverflow.com-PostHistory.7z
  • stackoverflow.com-Comments.7z
  • stackoverflow.com-Badges.7z
  • stackoverflow.com-PostLinks.7z
  • stackoverflow.com-Tags.7z
  • stackoverflow.com-Users.7z
  • stackoverflow.com-Votes.7z

As a result you will see a set of xml files with the same names.

5. How to use the dataset?

Let’s experiment with the dataset. The most interesting file is Posts.xml. This file contains 34 GB of uncompressed data; approximately 70% of it is Body text, the text of the questions and answers from the web site. This amount of data most likely does not fit in your memory. We might use an on-disk data manipulation or machine learning technology. This is a good chance to use Apache Spark and MLlib, or your own custom solution.
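For a custom single-machine solution, a minimal sketch in Python can stream Posts.xml row by row with the standard library’s ElementTree.iterparse and clear processed rows from the root, so memory use stays small regardless of file size. The tiny in-memory sample stands in for the real file; with the real data you would pass the path to Posts.xml instead. iter_posts is a hypothetical name.

```python
import io
import xml.etree.ElementTree as ET

def iter_posts(source):
    """Stream (Id, Title) pairs from a Posts.xml path or file object, one <row> at a time."""
    context = ET.iterparse(source, events=("start", "end"))
    _, root = next(context)  # first event: the start of the <posts> root element
    for event, elem in context:
        if event == "end" and elem.tag == "row":
            yield elem.get("Id"), elem.get("Title", "")
            root.clear()  # drop processed rows so memory stays flat

# A tiny in-memory stand-in for the real 34 GB Posts.xml file.
SAMPLE_BYTES = (b'<?xml version="1.0" encoding="utf-8"?>\n<posts>\n'
                b'  <row Id="1" PostTypeId="1" Title="First question" Body="&lt;p&gt;a&lt;/p&gt;" />\n'
                b'  <row Id="2" PostTypeId="2" Body="&lt;p&gt;an answer&lt;/p&gt;" />\n'
                b'</posts>\n')
print(list(iter_posts(io.BytesIO(SAMPLE_BYTES))))  # [('1', 'First question'), ('2', '')]
```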

Let’s take a look at how this stackoverflow question looks in the file.

Stackoverflow example

In the file, this post is represented by a single row. Note that because the Body text is HTML stored inside an XML attribute, the opening and closing p tags (<p> and </p>) are written as &lt;p&gt; and &lt;/p&gt; respectively.

Body=“&lt;p&gt;I want to use a track-bar to change a form’s opacity.&lt;/p&gt; &lt;p&gt;This is my code:&lt;/p&gt; &lt;pre&gt;&lt;code&gt;decimal trans = trackBar1.Value / 5000; this.Opacity = trans; &lt;/code&gt;&lt;/pre&gt; &lt;p&gt;When I try to build it, I get this error:&lt;/p&gt; &lt;blockquote&gt; &lt;p&gt;Cannot implicitly convert type ‘decimal’ to ‘double’.&lt;/p&gt; &lt;/blockquote&gt; &lt;p&gt;I tried making &lt;code&gt;trans&lt;/code&gt; a &lt;code&gt;double&lt;/code&gt;, but then the control doesn’t work. This code has worked fine for me in VB.NET in the past. &lt;/p&gt; ”
LastEditorDisplayName=“Rich B”
Title=“When setting a form’s opacity should I use a decimal or double?”

I’ll provide Apache Spark code examples with this data set in the next blog post. My scenario will include two parts: data preparation (data manipulation) and machine learning. Both parts will use the multi-gigabyte dataset as input.


The Stackoverflow dataset (https://archive.org/details/stackexchange) is probably the simplest and most interesting open multi-gigabyte dataset you can find for machine learning, data processing and data streaming scenarios. Please share if you have any information about other simple, open big dataset resources. This should help the community a lot.

Can Apache Spark process 100 terabytes of data in interactive mode?

Apache Spark brings a lot of innovation to the in-memory data processing area. With this framework, you can load data into cluster memory and work with it extremely fast in interactive mode (interactive mode is another important Spark feature, by the way). One year ago (10/10/2014), Databricks announced that Apache Spark was able to sort 100 terabytes of data in 23 minutes.

Here is an interesting question: what is the limit for the amount of data you can process interactively in a cluster? What if you had 100 terabytes of memory in your cluster? Memory is so fast, you would think! Intuition tells you that you can use this memory to interactively process 100 terabytes of input data, or at least half of that. However, as usual in the distributed systems world, our intuition is wrong!

Interactive Apache Spark

1. Response time

What would the response time be for a simple data processing scenario, and for a more complicated one? Are we still in interactive mode? We’d like to think so, but unfortunately we are not. In a real-world scenario, I saw that the response time for a simple query (a “where” clause with sum() and count()) over 8 terabytes of data was 20-40 seconds. For more complicated and more realistic scenarios (a couple of “group by”s plus a couple of “join”s), the response time was 3-5 minutes. This is definitely not what I call interactive mode!

In my daily work, I do analytics where response time is critical. I consider a response within 3-10 seconds, or perhaps even up to 15 seconds, to be interactive mode. Beyond that, I would consider it batch mode. Several seconds or 3-5 minutes instead of 15-60 minutes might look like an incredible result compared to MapReduce-like on-disk processing. However, this is not interactive.

2. Where does interactivity end?

The maximum amount of data I was able to process in interactive mode with only a few seconds of delay was about 1 terabyte. At this size, efficiency was still good. However, beyond 1 TB, I noticed that the response time grew dramatically.

My guess is that in order to improve efficiency (5-10 terabytes with only several seconds of delay), we would need to upgrade our hardware (I’d like to try a cluster of the most powerful EC2 machines, i2.8xlarge, with about 244 gigabytes of RAM each) and tune software settings (Apache Spark driver settings, the in-memory columnar format, and probably YARN settings).

Even with software and hardware upgrades, it is clear to me that the interactive mode limit doesn’t come anywhere close to 100 terabytes.

3. Read data to memory first

As discussed above, each iteration of data processing takes many seconds or even several minutes. However, this is not the complete story. If you do ad hoc analytics or create machine learning models, your initial data set will most likely be stored in the cluster’s HDFS storage. This means that before the in-memory iterations, you have to read the data from disk, which takes much longer. As usual, performance depends on your hardware and software settings. Most likely it will take 15-30 minutes for a 5-8 terabyte data set. Even 1 terabyte might take 5 minutes or so.


Before jumping into Apache Spark in-memory processing, it is worth making a plan for your analytical scenarios and estimating response times, especially if your data size is more than 1 terabyte.
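The read times mentioned above can be sanity-checked with simple arithmetic: reading 5 TB in 15 minutes implies an aggregate throughput of about 5.6 GB/s across the cluster. Here is a small Python sketch for this kind of estimate. The node count and per-node disk throughput in the example are hypothetical, and the model deliberately ignores CPU, network and deserialization costs, so treat its output as a lower bound.

```python
def implied_throughput_gb_per_s(size_tb, minutes):
    """Aggregate read throughput (GB/s) needed to load size_tb terabytes in the given minutes."""
    return size_tb * 1000.0 / (minutes * 60.0)

def estimate_read_minutes(size_tb, nodes, mb_per_s_per_node):
    """Minutes to read size_tb terabytes if every node sustains mb_per_s_per_node MB/s from disk."""
    total_mb = size_tb * 1000.0 * 1000.0
    return total_mb / (nodes * mb_per_s_per_node) / 60.0

print(round(implied_throughput_gb_per_s(5, 15), 2))  # 5.56
print(round(estimate_read_minutes(1, 20, 200), 1))   # 4.2 (hypothetical 20 nodes at 200 MB/s)
```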

Please share your experience: what is the maximum amount of data you were able to work with in interactive mode?

What No One Tells You About Real-Time Machine Learning

During this year, I heard and read a lot about real-time machine learning. People usually bring up this appealing business scenario when discussing credit card fraud detection systems. They say that they can continuously update a credit card fraud detection model in real time (see “What is Apache Spark?”, “…real-time use cases…” and “Real time machine learning”). It looks fantastic, but it does not look realistic to me. One important detail is missing in this scenario: a continuous flow of transactional data is not enough for model retraining. Instead, you need a continuous flow of labeled (pre-marked as Fraud/Not-Fraud) transactional data.

Machine learning process

Creating labeled data is probably the slowest and most expensive step in most machine learning systems. Machine learning algorithms learn to detect fraudulent transactions from people: the labeled data is essentially a record of human judgments. Let’s see how it works for the fraud detection scenario.

1. Creating model

For training credit card fraud detection models, you need a lot of example transactions, and each transaction should be labeled as Fraud or Not-Fraud. These labels have to be as accurate as possible! This is our labeled data set. This data set is the input for supervised machine learning algorithms. Based on the labeled data, the algorithm trains the fraud detection model. The model is usually represented as a binary classifier with True (Fraud) or False (Not-Fraud) classes.

The labeled data set plays a central role in this process. It is very easy to change the parameters of our algorithm such as the feature normalization method or loss function. We can change the algorithm itself from logistic regression to SVM or random forest for example. However, you cannot change the labeled data set. This information is predefined and your model should predict the labels that you already have.

2. How long does the data labeling process take?

How can we label the freshest transactions? If customers report fraudulent transactions or stolen credit cards, we can immediately mark those transactions as “Fraud”. What should we do with the rest of the transactions? We can assume that non-reported transactions are “Not Fraud”. But how long should we wait to be sure that they are not fraud? The last time my friend lost a credit card, she said, “I won’t report the missing credit card yet. Tomorrow I’ll go to the shop I last visited and ask them if they found my credit card.” Fortunately, the store found and returned her credit card. I’m not an expert in the credit card fraud field (I’m only a good card user), but from my experience, we should wait at least a couple of days before marking transactions as “Not Fraud”.

In contrast, if somebody reports a fraudulent transaction, we can immediately label it as “Fraud”. A person who reports fraud probably notices the fraudulent transaction only several hours or a couple of days after the loss, but this is the best we can do.

As a result, our “freshest” labeled data set will be limited to a few “Fraud” transactions with a delay of several hours or days, and a lot of “Not Fraud” transactions with a 2-3 day delay.
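The labeling rule described above fits in a small function. Here is a minimal Python sketch, assuming a 3-day waiting period before a transaction can be labeled “Not Fraud” (the post suggests 2-3 days) and returning None for transactions that are still too fresh to label. The function name and arguments are hypothetical.

```python
from datetime import datetime, timedelta

def label_transaction(tx_time, reported_fraud, now, wait=timedelta(days=3)):
    """Return "Fraud", "Not Fraud", or None when the transaction cannot be labeled yet."""
    if reported_fraud:
        return "Fraud"      # a fraud report can be labeled immediately
    if now - tx_time >= wait:
        return "Not Fraud"  # no report after the waiting period
    return None             # too fresh: the absence of a report proves nothing yet

now = datetime(2015, 11, 10, 12, 0)
print(label_transaction(now - timedelta(hours=1), True, now))   # Fraud
print(label_transaction(now - timedelta(hours=1), False, now))  # None
print(label_transaction(now - timedelta(days=4), False, now))   # Not Fraud
```

Applied to a stream of transactions, only the “Fraud” branch produces fresh labels, which is exactly the imbalance described above.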

3. Let’s try to speed up the labeling process

Our goal is to obtain the “freshest” labeled data possible. In fact, we only have “fresh Fraud” labels. For “Not Fraud” labels, we have to wait a few days. It might look like a good idea to build a model using only the “fresh Fraud” labeled data. However, we should understand that this labeled data set is biased, which might lead to a lot of issues with the models.

Let’s imagine a new big shopping center opened yesterday, and we got a single fraud report regarding a single transaction from this store. Our labeled data set will contain only one transaction from this shop, with a “Fraud” label. All other transactions from the shop are not labeled yet. The algorithm might decide that this shop is a strong fraud predictor, and all transactions from this shop will be immediately misclassified as “Fraud”, “in real time”. The advantages of real time give us real-time problems.
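A toy calculation shows the bias. With only fresh labels, the new shop’s empirical fraud rate is 1 out of 1; once its ordinary transactions receive “Not Fraud” labels, the rate drops sharply. The shop names and counts below are made up for illustration, and a per-shop fraud rate is only a stand-in for what a real model would pick up from a shop feature.

```python
from collections import defaultdict

def fraud_rate_by_shop(labeled):
    """Share of labeled transactions per shop that carry the "Fraud" label."""
    counts = defaultdict(lambda: [0, 0])  # shop -> [fraud count, total count]
    for shop, label in labeled:
        counts[shop][1] += 1
        if label == "Fraud":
            counts[shop][0] += 1
    return {shop: fraud / total for shop, (fraud, total) in counts.items()}

# "Real-time" labels: the new mall has exactly one labeled transaction, a fresh fraud report.
fresh = [("old_shop", "Fraud")] + [("old_shop", "Not Fraud")] * 3 + [("new_mall", "Fraud")]
# A few days later, 99 ordinary transactions from the new mall get "Not Fraud" labels.
later = fresh + [("new_mall", "Not Fraud")] * 99

print(fraud_rate_by_shop(fresh)["new_mall"])  # 1.0
print(fraud_rate_by_shop(later)["new_mall"])  # 0.01
```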


As we can see, the credit card fraud detection business scenario does not look like the best scenario for real-time supervised machine learning. I was also unable to imagine a good scenario from other business domains. I’d love to see good scenarios of real-time machine learning. Please share any information or ideas with the community.