Writing Python Functions for Data Understanding
In this post, I’ll explain two Python functions: one that calculates key statistics for a dataset attribute and another that prepares data for visualization.
To ensure these functions work correctly, I will use unittest to create test cases and verify their functionality.
Part I. Data Understanding Overview
Context
The data mining pipeline involves:
- Data understanding
- Data pre-processing
- Data warehousing
- Data modeling
- Pattern evaluation
The first step in the data mining pipeline is understanding the data.
In this step, we analyze the dataset to understand its objects (rows) and attributes (columns), and we calculate measures of central tendency and dispersion. Attributes can be categorical or numerical, and the type determines which statistics you can calculate.
- Numerical attributes are quantitative and can be measured on a scale.
- Categorical attributes represent categories or groups and are qualitative. Statistics such as the mean or standard deviation do not apply to them; they are summarized with other measures instead (e.g., number of objects, frequency count, mode, proportions, cross-tabulation).
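The categorical statistics mentioned above can be sketched with pandas as follows. The `color` and `size` columns and their values are made up for illustration and are not part of the post's dataset.

```python
import pandas as pd

# Hypothetical toy data, not the post's dataset
df = pd.DataFrame({
    "color": ["red", "blue", "red", "green", "red", "blue"],
    "size": ["S", "M", "M", "S", "S", "M"],
})

num_objects = df["color"].count()                        # number of non-null values
frequency = df["color"].value_counts()                   # count per category
mode = df["color"].mode()[0]                             # most frequent category
proportions = df["color"].value_counts(normalize=True)   # share per category
cross_tab = pd.crosstab(df["color"], df["size"])         # counts per pair of categories

print(num_objects, mode)
print(frequency)
print(proportions)
print(cross_tab)
```

Calling `mean()` or `std()` on a column like `color` would not be meaningful, which is why the function defined later in this post is aimed at numerical attributes.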
Let’s define a function that returns the following statistics for the ith column:
- Number of objects: count()
- The minimum value: min()
- The maximum value: max()
- The mean value: mean()
- The standard deviation value: std()
- The Q1 value: quantile(0.25)
- The median value: median()
- The Q3 value: quantile(0.75)
- The IQR value: Q3 - Q1
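Before wrapping these calls in a function, it may help to see them on a small column whose results are easy to check by hand. The series below is a made-up example, not data from the post.

```python
import pandas as pd

# Hypothetical toy column, chosen so the results are easy to verify by hand
values = pd.Series([1, 2, 3, 4, 5])

stats = {
    "count": values.count(),
    "min": values.min(),
    "max": values.max(),
    "mean": values.mean(),
    "std": values.std(),          # sample standard deviation (ddof=1)
    "Q1": values.quantile(0.25),  # linear interpolation by default
    "median": values.median(),
    "Q3": values.quantile(0.75),
}
stats["IQR"] = stats["Q3"] - stats["Q1"]

for name, value in stats.items():
    print(f"{name}: {value}")
```

For this series, Q1 is 2.0 and Q3 is 4.0, so the IQR is 2.0.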
Actions
- Import the required Python packages & libraries
# import required Python packages and libraries
import argparse
import pandas as pd
import numpy as np
import pickle
from pathlib import Path
- Define the function and return statistics on selected attributes.
# define the calculate function
def calculate(dataFile, col_num):
    """
    Input Parameters:
        dataFile: Path to the dataset file (CSV).
        col_num: Zero-based index of the attribute (column) whose statistics are calculated.
    Default values of 0, infinity, and -infinity are assigned to the variables as required.
    """
    # Initialize the variables (numeric infinities, not the strings "inf"/"-inf")
    numObj, minValue, maxValue, mean, stdev, Q1, median, Q3, IQR = [0, float("inf"), float("-inf"), 0, 0, 0, 0, 0, 0]
    # Load the dataset
    data = pd.read_csv(dataFile)
    # Select the attribute by position
    attre_selection = data.iloc[:, col_num]
    # Get the column name
    column_name = data.columns[col_num]
    # Calculate the statistics
    numObj = attre_selection.count()
    minValue = attre_selection.min()
    maxValue = attre_selection.max()
    mean = attre_selection.mean()
    stdev = attre_selection.std()
    Q1 = attre_selection.quantile(0.25)
    median = attre_selection.median()
    Q3 = attre_selection.quantile(0.75)
    IQR = Q3 - Q1
    # Return the results
    return column_name, numObj, minValue, maxValue, mean, stdev, Q1, median, Q3, IQR
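If you wish to print the results using an f-string, first unpack the values returned by the function:
# Unpack the returned tuple (dataFile and col_num are the arguments described above)
column_name, numObj, minValue, maxValue, mean, stdev, Q1, median, Q3, IQR = calculate(dataFile, col_num)
# Print the results with their respective labels
print(f"Count: {numObj}")
print(f"Min: {minValue}")
print(f"Max: {maxValue}")
print(f"Mean: {mean}")
print(f"Standard Deviation: {stdev}")
print(f"Q1: {Q1}")
print(f"Median: {median}")
print(f"Q3: {Q3}")
print(f"IQR: {IQR}")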
Results
When executing this function, various statistics are calculated and returned:
- column_name: The name of the selected column.
- numObj: The number of non-null values in the column.
- minValue: The minimum value in the column.
- maxValue: The maximum value in the column.
- mean: The mean (average) value of the column.
- stdev: The standard deviation, which measures how much the values vary around the mean.
- Q1: The first quartile (25th percentile).
- median: The 50th percentile, which is the middle value of the column.
- Q3: The third quartile (75th percentile).
- IQR: The interquartile range (Q3 - Q1), which measures the spread of the middle 50% of the data.
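One detail about `stdev` is worth keeping in mind: pandas' `std()` computes the sample standard deviation (`ddof=1`) by default, while NumPy's `np.std()` computes the population standard deviation (`ddof=0`). The short sketch below illustrates the difference on made-up values.

```python
import numpy as np
import pandas as pd

values = pd.Series([2, 4, 4, 4, 5, 5, 7, 9])  # hypothetical toy data

sample_std = values.std()                # pandas default: ddof=1 (divide by n - 1)
population_std = values.std(ddof=0)      # divide by n instead
numpy_std = np.std(values.to_numpy())    # NumPy default: ddof=0

print(sample_std, population_std, numpy_std)
```

If expected values for a test were produced with one convention and the function uses the other, the two numbers will disagree slightly, so it helps to fix the convention up front.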
Tests
- Run tests using unittest
The unittest library is a built-in Python library for writing and running tests on your code. It provides a framework to create unit tests, which are small and focused tests designed to check if individual pieces of your code (such as functions and methods) work correctly.
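As a minimal, self-contained illustration of that framework, the sketch below tests a small helper against a CSV file it writes itself. The helper `column_iqr` and the class `TestColumnIqr` are hypothetical names introduced only for this example.

```python
import os
import tempfile
import unittest

import pandas as pd


def column_iqr(csv_path, col_num):
    """Hypothetical helper: IQR of the column at position col_num."""
    column = pd.read_csv(csv_path).iloc[:, col_num]
    return column.quantile(0.75) - column.quantile(0.25)


class TestColumnIqr(unittest.TestCase):
    def setUp(self):
        # Write a small CSV so the test does not depend on external files
        handle, self.path = tempfile.mkstemp(suffix=".csv")
        os.close(handle)
        pd.DataFrame({"a": [1, 2, 3, 4, 5]}).to_csv(self.path, index=False)

    def tearDown(self):
        # Remove the temporary file after each test
        os.remove(self.path)

    def test_iqr(self):
        self.assertAlmostEqual(column_iqr(self.path, 0), 2.0, places=3)


suite = unittest.TestLoader().loadTestsFromTestCase(TestColumnIqr)
result = unittest.TextTestRunner().run(suite)
```

`setUp` runs before each test method and `tearDown` after it, which keeps every test independent of the others.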
# import required Python packages and libraries
import unittest
# define the test cases/assertions
class TestKnn(unittest.TestCase):
    def setUp(self):
        self.loc = "/Data_Mining/pipeline_data/dataset.csv"
        # The pickle file holds the column index and the expected statistics
        with open('/Data_Mining/pipeline_data/testing', 'rb') as file:
            self.data = pickle.load(file)

    def test0(self):
        """
        Test the statistics returned by calculate()
        """
        self.column = self.data[0]
        result = calculate(self.loc, self.column)
        # result[0] is the column name, so it is compared exactly
        self.assertEqual(result[0], self.data[1][0])
        # The numeric statistics are compared to 3 decimal places
        self.assertAlmostEqual(result[1], self.data[1][1], places=3)
        self.assertAlmostEqual(result[2], self.data[1][2], places=3)
        self.assertAlmostEqual(result[3], self.data[1][3], places=3)
        self.assertAlmostEqual(result[4], self.data[1][4], places=3)
        self.assertAlmostEqual(result[5], self.data[1][5], places=3)
        self.assertAlmostEqual(result[6], self.data[1][6], places=3)
        self.assertAlmostEqual(result[7], self.data[1][7], places=3)
        self.assertAlmostEqual(result[8], self.data[1][8], places=3)

# Load the tests from the test case class and run them
tests_to_run = unittest.TestLoader().loadTestsFromTestCase(TestKnn)
unittest.TextTestRunner().run(tests_to_run)
The test run reported 0 errors and 0 failures, indicating that the calculate function returns the expected values for the tested column.
<unittest.runner.TextTestResult run=1 errors=0 failures=0>
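The test reads the pickle file as a two-item sequence: the column index at position 0 and the expected values at position 1. A fixture with that layout could be produced roughly as shown below. This is inferred from how the test indexes `self.data`, not taken from the original fixture, and the file location and values are placeholders.

```python
import os
import pickle
import tempfile

# Placeholder expected values, in the order the test compares them:
# column name, count, min, max, mean, std, Q1, median, Q3
col_num = 0
expected = ["a", 5, 1, 5, 3.0, 1.5811, 2.0, 3.0, 4.0]

fixture_path = os.path.join(tempfile.mkdtemp(), "testing")  # hypothetical location
with open(fixture_path, "wb") as file:
    pickle.dump([col_num, expected], file)

# Read it back the same way setUp() does
with open(fixture_path, "rb") as file:
    data = pickle.load(file)

print(data[0], data[1])
```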
Part II. Data Visualization Overview
Data visualization is a powerful aid to data understanding. Common methods include boxplots, histograms, scatterplots, quantile plots, and heatmaps. Let’s define a function that prepares the data for a scatter plot.
Context
A function that extracts the data needed for plotting keeps the gathering and preparation steps in one place, so they can be reused without rewriting the extraction code each time. The function below returns everything required to draw the plot.
Actions
- Import Python packages and libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
- Define the function to return:
- x: data for the x-axis
- y: data for the y-axis
- title: the title of the plot
- x_label: label for the x-axis
- y_label: label for the y-axis
def func():
    '''
    Output: x, y, title, x_label, y_label
    '''
    # Initialize variables
    x = []
    y = []
    title = ''
    x_label = ''
    y_label = ''
    # Load the dataset
    data = pd.read_csv('/Data_Mining/pipeline_data/dataset.csv')
    # Extract x and y values
    x = data['CO'].tolist()
    y = data['AFDP'].tolist()
    # Define titles and labels
    title = 'CO vs AFDP'
    x_label = 'CO'
    y_label = 'AFDP'
    return x, y, title, x_label, y_label
Results
The function returns x, y, title, x_label, y_label values and makes them available for plotting.
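If the same preparation is needed for other column pairs, one possible generalization is to pass the data and column names as parameters. The sketch below is an assumption about how that could look, not the post's function; `scatter_inputs` is a hypothetical name and the toy values are made up.

```python
import pandas as pd


def scatter_inputs(data, x_col, y_col):
    """Hypothetical generalization of func(): return plot inputs for any two columns."""
    x = data[x_col].tolist()
    y = data[y_col].tolist()
    title = f"{x_col} vs {y_col}"
    return x, y, title, x_col, y_col


# Toy frame standing in for the real dataset; the values are made up
toy = pd.DataFrame({"CO": [0.5, 1.2, 3.1], "AFDP": [2.0, 2.4, 3.9]})
x, y, title, x_label, y_label = scatter_inputs(toy, "CO", "AFDP")
print(title, x, y)
```

Taking a DataFrame rather than a hard-coded file path also makes the function easier to test, because a small in-memory frame can stand in for the dataset.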
Tests
Running this cell displays the saved scatter plot image that our plot should match.
from IPython.display import Image, display
# Display the image with a custom-size
display(Image(filename='/Data_Mining/pipeline_data/scatter_plot.png', width=300, height=200))
We can then plot the output of our function and compare it with the image above to check that the function works as expected.
# Testing the func() function
x, y, title, x_label, y_label = func()
plt.scatter(x, y)
plt.title(title)
plt.xlabel(x_label)
plt.ylabel(y_label)
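Outside a notebook, the same plot can be written to a file and inspected next to the saved image. The sketch below uses made-up values in place of the output of `func()` and a hypothetical output file name; it selects the non-interactive `Agg` backend so that no display is required.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

# Made-up values standing in for the output of func()
x, y = [0.5, 1.2, 3.1], [2.0, 2.4, 3.9]
title, x_label, y_label = "CO vs AFDP", "CO", "AFDP"

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title(title)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

output_path = os.path.join(tempfile.mkdtemp(), "scatter_check.png")  # hypothetical file name
fig.savefig(output_path)
plt.close(fig)
print(output_path)
```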