3.10 AutoML
Automated workflow for hyperparameter tuning and finding the optimal model
In this tutorial, we try a widely used technique that makes AI/ML less tedious and boosts the efficiency of your ML workflow.
If you have worked through Section 3.6, you may have been impressed but also frustrated by all the parameter tuning and the many back-and-forth iterations needed to figure out which configuration is optimal for your problem. This is widely seen as a major cause of low productivity in AI/ML. Since most of that tuning and iteration is simple and repetitive, can we automate it? The answer is yes, and that is the technique introduced here: AutoML.
There are many AutoML solutions available, e.g., AutoKeras, auto-sklearn, H2O, and Auto-WEKA. Here we focus on PyCaret, which is popular in both academia and industry and very easy to use.
In this tutorial, we run everything inside the PyCaret Docker image. In a terminal, call docker to pull the PyCaret image and start a Jupyter notebook:
docker pull pycaret/full
docker run -it -p 8888:8888 -e GRANT_SUDO=yes pycaret/full
Installation on an M1 Mac can be tricky, especially for the lightgbm library. If you run into problems, try installing both pycaret and lightgbm.
You will then be able to edit a notebook with the following cells:
!pip install pycaret
Requirement already satisfied: pycaret in /Users/marinedenolle/opt/miniconda3/envs/mlgeo/lib/python3.9/site-packages (3.3.2)
First, get the data ready
As usual, data collection is the first step. To better demonstrate the benefit of AutoML, we use the same data as in Section 3.6 (Random Forest).
!pip install wget
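import os
import wget
# Download once and save as temps.csv (otherwise wget saves repeats as 'temps (n).csv')
if not os.path.exists("temps.csv"):
    wget.download("https://docs.google.com/uc?export=download&id=1pko9oRmCllAxipZoa3aoztGZfPAD2iwj", out="temps.csv")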
Display the data columns
List the columns, then settle on the target variable and the input variables.
# Pandas is used for data manipulation
import pandas as pd
# Read in the data and list its columns
features = pd.read_csv('temps.csv')
features.columns
Index(['year', 'month', 'day', 'week', 'temp_2', 'temp_1', 'average', 'actual',
'forecast_noaa', 'forecast_acc', 'forecast_under', 'friend'],
dtype='object')
temp_2: maximum temperature two days before today.
temp_1: maximum temperature yesterday.
average: historical average temperature.
actual: actual measured temperature today.
forecast_noaa: temperature forecast by NOAA.
friend: a friend's forecast (a random number within ±20 of the average temperature).
We will use actual as the label and all the other variables as features.
Check the data shape
features.shape
(348, 12)
# One-hot encode the categorical 'week' column using pandas get_dummies
features = pd.get_dummies(features)
# Display the first 5 rows of columns 5 onward (the last 13 columns)
features.iloc[:,5:].head(5)
| | average | actual | forecast_noaa | forecast_acc | forecast_under | friend | week_Fri | week_Mon | week_Sat | week_Sun | week_Thurs | week_Tues | week_Wed |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 45.6 | 45 | 43 | 50 | 44 | 29 | True | False | False | False | False | False | False |
| 1 | 45.7 | 44 | 41 | 50 | 44 | 61 | False | False | True | False | False | False | False |
| 2 | 45.8 | 41 | 43 | 46 | 47 | 56 | False | False | False | True | False | False | False |
| 3 | 45.9 | 40 | 44 | 48 | 46 | 53 | False | True | False | False | False | False | False |
| 4 | 46.0 | 44 | 46 | 46 | 46 | 41 | False | False | False | False | False | True | False |
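What pd.get_dummies does to the categorical week column can be sketched in plain Python. This is a minimal illustration, not pandas' implementation, and the helper name one_hot is hypothetical. Like get_dummies, it creates one column per distinct category, named prefix_category and ordered alphabetically, holding True/False values as in the table above.

```python
def one_hot(values, prefix):
    """Return {column_name: [bool, ...]}, one column per distinct value."""
    # Like get_dummies, order the new columns alphabetically by category
    categories = sorted(set(values))
    return {f"{prefix}_{cat}": [v == cat for v in values] for cat in categories}


# First five entries of the original 'week' column (rows 0-4 in the table above)
week = ["Fri", "Sat", "Sun", "Mon", "Tues"]
encoded = one_hot(week, "week")
for name, column in encoded.items():
    print(name, column)
```

Only the five categories present in this small sample get a column; on the full data all seven weekdays appear, giving week_Fri through week_Wed. Each row has exactly one True, so the encoding can replace the original week column without losing information.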
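Split training and testing
Since we already did all the quality checks in Section 3.6, we do not repeat them here and go directly to the AutoML experiment. First, split the data into training and testing subsets: the first 300 rows (days) are used for modeling, and the remaining 48 are held out as unseen data for predictions.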
train_df = features[:300]
test_df = features[300:]
print('Data for Modeling: ' + str(train_df.shape))
print('Unseen Data For Predictions: ' + str(test_df.shape))
Data for Modeling: (300, 18)
Unseen Data For Predictions: (48, 18)
train_df
| | year | month | day | temp_2 | temp_1 | average | actual | forecast_noaa | forecast_acc | forecast_under | friend | week_Fri | week_Mon | week_Sat | week_Sun | week_Thurs | week_Tues | week_Wed |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2016 | 1 | 1 | 45 | 45 | 45.6 | 45 | 43 | 50 | 44 | 29 | True | False | False | False | False | False | False |
| 1 | 2016 | 1 | 2 | 44 | 45 | 45.7 | 44 | 41 | 50 | 44 | 61 | False | False | True | False | False | False | False |
| 2 | 2016 | 1 | 3 | 45 | 44 | 45.8 | 41 | 43 | 46 | 47 | 56 | False | False | False | True | False | False | False |
| 3 | 2016 | 1 | 4 | 44 | 41 | 45.9 | 40 | 44 | 48 | 46 | 53 | False | True | False | False | False | False | False |
| 4 | 2016 | 1 | 5 | 41 | 40 | 46.0 | 44 | 46 | 46 | 46 | 41 | False | False | False | False | False | True | False |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 295 | 2016 | 11 | 9 | 63 | 71 | 52.4 | 65 | 48 | 56 | 52 | 42 | False | False | False | False | False | False | True |
| 296 | 2016 | 11 | 10 | 71 | 65 | 52.2 | 64 | 52 | 54 | 51 | 38 | False | False | False | False | True | False | False |
| 297 | 2016 | 11 | 11 | 65 | 64 | 51.9 | 63 | 50 | 53 | 52 | 55 | True | False | False | False | False | False | False |
| 298 | 2016 | 11 | 12 | 64 | 63 | 51.7 | 59 | 50 | 52 | 52 | 63 | False | False | True | False | False | False | False |
| 299 | 2016 | 11 | 13 | 63 | 59 | 51.4 | 55 | 48 | 56 | 50 | 64 | False | False | False | True | False | False | False |
300 rows × 18 columns
Run PyCaret (no hassle)
Let's get straight to the point. PyCaret should automatically recognize the columns, assign appropriate data types to them, and tell you if something is wrong.
First, check that PyCaret parsed the data columns correctly and that the inferred data types match their values. Older PyCaret releases (2.x) asked you to confirm this by pressing Enter in a popup text field; PyCaret 3.x, used here, simply prints the summary table below.
from pycaret.regression import *
exp_reg101 = setup(data = train_df,
target = 'actual',
# imputation_type='iterative',
fold_shuffle=True,
session_id=123)
| | Description | Value |
|---|---|---|
| 0 | Session id | 123 |
| 1 | Target | actual |
| 2 | Target type | Regression |
| 3 | Original data shape | (300, 18) |
| 4 | Transformed data shape | (300, 18) |
| 5 | Transformed train set shape | (210, 18) |
| 6 | Transformed test set shape | (90, 18) |
| 7 | Numeric features | 10 |
| 8 | Preprocess | True |
| 9 | Imputation type | simple |
| 10 | Numeric imputation | mean |
| 11 | Categorical imputation | mode |
| 12 | Fold Generator | KFold |
| 13 | Fold Number | 10 |
| 14 | CPU Jobs | -1 |
| 15 | Use GPU | False |
| 16 | Log Experiment | False |
| 17 | Experiment Name | reg-default-name |
| 18 | USI | 569c |
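The setup table reports a 210/90 split of the 300 modeling rows (PyCaret's default train_size is 0.7, and 0.7 × 300 = 210) and 10-fold KFold cross-validation on the 210 training rows, shuffled because fold_shuffle=True. The plain-Python sketch below mimics shuffled K-fold splitting to show what each fold contains. It is an illustration under these assumptions, not scikit-learn's KFold, and the helper name kfold_indices is hypothetical.

```python
import random


def kfold_indices(n_samples, n_splits, shuffle=True, seed=123):
    """Yield (train_idx, val_idx) pairs for K-fold cross-validation."""
    indices = list(range(n_samples))
    if shuffle:
        # Seeded for reproducibility (the role session_id plays in setup)
        random.Random(seed).shuffle(indices)
    # Fold sizes as even as possible; any remainder goes to the first folds
    fold_sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
                  for i in range(n_splits)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


n_train = round(0.7 * 300)  # 210 rows, as in the setup table
folds = list(kfold_indices(n_train, n_splits=10))
for k, (tr, va) in enumerate(folds):
    print(f"fold {k}: train={len(tr)}, validate={len(va)}")
```

Every training row is used for validation exactly once, so each model's reported score in the next step is an average over ten different 189/21 splits rather than a single lucky or unlucky one.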
Compare Models
Once you have confirmed that the data types are correct, run the comparison with a single line of code (here the RANSAC regressor is excluded):
best = compare_models(exclude = ['ransac'])
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| rf | Random Forest Regressor | 3.8383 | 24.2928 | 4.7893 | 0.7416 | 0.0713 | 0.0590 | 0.0180 |
| ada | AdaBoost Regressor | 4.0621 | 26.1749 | 5.0051 | 0.7203 | 0.0747 | 0.0628 | 0.0090 |
| et | Extra Trees Regressor | 4.0285 | 27.5968 | 5.0967 | 0.7128 | 0.0759 | 0.0620 | 0.0140 |
| lasso | Lasso Regression | 3.8163 | 27.7310 | 4.9796 | 0.7117 | 0.0743 | 0.0590 | 0.0030 |
| llar | Lasso Least Angle Regression | 3.8166 | 27.7349 | 4.9799 | 0.7117 | 0.0743 | 0.0590 | 0.0040 |
| en | Elastic Net | 3.8289 | 28.0368 | 5.0033 | 0.7079 | 0.0747 | 0.0593 | 0.0040 |
| huber | Huber Regressor | 3.8708 | 28.4140 | 5.0481 | 0.7048 | 0.0754 | 0.0595 | 0.0050 |
| br | Bayesian Ridge | 3.8655 | 28.4961 | 5.0447 | 0.7037 | 0.0754 | 0.0599 | 0.0100 |
| lightgbm | Light Gradient Boosting Machine | 4.0354 | 27.4807 | 5.1129 | 0.7018 | 0.0767 | 0.0624 | 0.0880 |
| gbr | Gradient Boosting Regressor | 4.0385 | 28.0064 | 5.0722 | 0.6964 | 0.0756 | 0.0623 | 0.0090 |
| ridge | Ridge Regression | 3.9747 | 30.2205 | 5.1822 | 0.6869 | 0.0774 | 0.0617 | 0.0040 |
| lr | Linear Regression | 3.9802 | 30.3045 | 5.1893 | 0.6861 | 0.0775 | 0.0618 | 0.2340 |
| knn | K Neighbors Regressor | 4.1305 | 30.5617 | 5.3863 | 0.6781 | 0.0806 | 0.0638 | 0.0070 |
| xgboost | Extreme Gradient Boosting | 4.2272 | 31.1954 | 5.3840 | 0.6609 | 0.0809 | 0.0658 | 0.0100 |
| dt | Decision Tree Regressor | 5.1238 | 49.7619 | 6.8983 | 0.4722 | 0.1038 | 0.0795 | 0.0030 |
| omp | Orthogonal Matching Pursuit | 5.2630 | 52.6945 | 6.7951 | 0.4646 | 0.1032 | 0.0841 | 0.0030 |
| par | Passive Aggressive Regressor | 8.0310 | 116.7174 | 9.3526 | 0.0639 | 0.1423 | 0.1315 | 0.0040 |
| dummy | Dummy Regressor | 8.4914 | 105.2282 | 10.1498 | -0.0453 | 0.1581 | 0.1371 | 0.0030 |
| lar | Least Angle Regression | 27.9575 | 6343.8284 | 35.4660 | -53.1037 | 0.3068 | 0.4584 | 0.0040 |
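Each column of the comparison table is a standard regression metric, averaged over the cross-validation folds. The plain-Python sketch below computes them for one set of predictions using the textbook definitions. These formulas are an assumption about what PyCaret reports; its own implementations (largely from scikit-learn) may differ in details such as how RMSLE treats negative values. As the table's values (e.g. 0.0590) suggest, MAPE is expressed as a fraction, not a percentage. The function name regression_metrics and the toy numbers are hypothetical.

```python
import math


def regression_metrics(y_true, y_pred):
    """Textbook versions of the metrics shown by compare_models."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)           # residual sum of squares
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares
    mae = sum(abs(e) for e in errors) / n
    mse = ss_res / n
    return {
        "MAE": mae,
        "MSE": mse,
        "RMSE": math.sqrt(mse),
        "R2": 1 - ss_res / ss_tot,
        # log1p(v) = log(1 + v); requires values > -1 (fine for these temperatures)
        "RMSLE": math.sqrt(sum((math.log1p(p) - math.log1p(t)) ** 2
                               for t, p in zip(y_true, y_pred)) / n),
        "MAPE": sum(abs(e) / abs(t) for e, t in zip(errors, y_true)) / n,  # fraction
    }


# Toy example: three observed temperatures and three predictions
print(regression_metrics([40, 50, 60], [42, 50, 57]))
```

RMSE and MSE penalize large misses more heavily than MAE, and R2 compares the model against always predicting the mean (the Dummy Regressor row sits near R2 = 0 for that reason).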
Get Best Model
It looks great! PyCaret did all the work under the hood and gives us the best model. Look at the RMSE and R2 columns of the comparison table: Random Forest achieves both the lowest RMSE and the highest R2, so the choice is clear, and you save the time of comparing the models yourself. Each score is the mean over the 10 cross-validation folds listed in the setup table, which makes the ranking more reliable than one based on a single train/test split. Keep in mind, though, that the top models differ only slightly and that all of them were trained with their default hyperparameters.
The next step is to inspect the best model's hyperparameter configuration. Note that compare_models does not tune hyperparameters: PyCaret provides tune_model for that, and finalize_model retrains the chosen model on the entire dataset passed to setup before you use it for predictions.
best
RandomForestRegressor(n_jobs=-1, random_state=123)
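compare_models scored every candidate with its default hyperparameters. Automated tuning, which PyCaret exposes as tune_model, searches over hyperparameter combinations and keeps the one with the best cross-validated score; to our understanding its default strategy is a random search. The plain-Python sketch below illustrates random search only. The grid values, the toy score function (a stand-in for a cross-validated metric such as mean R2), and the helper names are hypothetical, not PyCaret's actual search space.

```python
import itertools
import random


def random_search(param_grid, score_fn, n_iter=10, seed=123):
    """Try n_iter random combinations from param_grid; return (best_params, best_score)."""
    names = list(param_grid)
    combos = list(itertools.product(*(param_grid[name] for name in names)))
    rng = random.Random(seed)
    # Sample distinct combinations; if n_iter covers the grid, this is a full grid search
    sampled = rng.sample(combos, min(n_iter, len(combos)))
    best_params, best_score = None, float("-inf")
    for values in sampled:
        params = dict(zip(names, values))
        score = score_fn(params)  # in practice: a cross-validated metric such as mean R2
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Hypothetical grid and a toy score that peaks at n_estimators=200, max_depth=8
grid = {"n_estimators": [50, 100, 200, 400], "max_depth": [4, 8, 16]}


def toy_score(params):
    return -abs(params["n_estimators"] - 200) / 100 - abs(params["max_depth"] - 8) / 4


print(random_search(grid, toy_score, n_iter=5))
```

Random search evaluates only n_iter of the 12 combinations, trading a guarantee of finding the optimum for a fixed, predictable cost; with n_iter equal to the grid size it reduces to an exhaustive grid search.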
If you suspect the best model is not the most cost-effective one and want to examine more candidates, return several models with top3 = compare_models(exclude = ['ransac'], n_select = 3); top3 is then a list holding the top 3 models.
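Selecting the top n models amounts to sorting the comparison table by a chosen metric: higher is better for R2, lower is better for error metrics such as RMSE. compare_models does this through its sort and n_select parameters; sort defaults to R2 for regression, which is why the table above is ordered by R2. The plain-Python sketch below reproduces that ranking for a few rows copied from the table; the helper top_n is hypothetical.

```python
# A few rows copied from the compare_models table above: model ID -> (RMSE, R2)
scores = {
    "rf": (4.7893, 0.7416),
    "ada": (5.0051, 0.7203),
    "et": (5.0967, 0.7128),
    "lasso": (4.9796, 0.7117),
    "llar": (4.9799, 0.7117),
    "en": (5.0033, 0.7079),
}


def top_n(scores, metric, n=3):
    """Return the n best model IDs; R2 is maximized, RMSE is minimized."""
    col = {"RMSE": 0, "R2": 1}[metric]
    higher_is_better = metric == "R2"
    ranked = sorted(scores, key=lambda m: scores[m][col], reverse=higher_is_better)
    return ranked[:n]


print(top_n(scores, "R2"))    # ['rf', 'ada', 'et']
print(top_n(scores, "RMSE"))  # ['rf', 'lasso', 'llar']
```

Ranking by RMSE instead of R2 swaps ada and et for the two lasso variants, so which models end up in top3 depends on the metric you sort by.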
Model Interpretation
To understand how the best model arrives at its predictions, PyCaret provides a function called interpret_model. It produces a figure showing the influence of each input variable on the predictions. PyCaret computes this with the SHAP library, which it integrates but does not install by default: run pip install shap (or pip install pycaret[analysis]) first, otherwise the call fails with the error below.
interpret_model(best)
ModuleNotFoundError:
'shap' is a soft dependency and not included in the pycaret installation. Please run: `pip install shap` to install.
Alternately, you can install this by running `pip install pycaret[analysis]`
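SHAP attributes a single prediction to the input features using Shapley values from cooperative game theory: each feature receives its average marginal contribution over all coalitions of the other features. The brute-force, plain-Python sketch below computes exact Shapley values for a handful of features. It assumes that an "absent" feature is set to a fixed baseline value, which is a simplification; the SHAP library uses background data and much faster algorithms (including a tree-specific one suited to models like random forests). The toy model, its feature names, and the helper names are hypothetical.

```python
import itertools
import math


def shapley_values(model, x, baseline):
    """Exact Shapley values of model(x) relative to model(baseline)."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                # Features in the coalition take their real value, the rest the baseline
                with_i = [x[j] if (j in subset or j == i) else baseline[j] for j in range(n)]
                without_i = [x[j] if j in subset else baseline[j] for j in range(n)]
                phi[i] += weight * (model(with_i) - model(without_i))
    return phi


def toy_model(f):
    """Hypothetical model on (temp_1, average) with an interaction term."""
    return 0.6 * f[0] + 0.4 * f[1] + 0.01 * f[0] * f[1]


x, base = [64.0, 52.0], [60.0, 55.0]
phi = shapley_values(toy_model, x, base)
print(phi, sum(phi), toy_model(x) - toy_model(base))
```

The attributions always add up to the difference between the prediction and the baseline prediction, which is what makes SHAP plots readable as "how much each variable pushed this forecast up or down". The loop is exponential in the number of features, which is why practical SHAP tools rely on approximations or model-specific shortcuts.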
Evaluate More Metrics
PyCaret's evaluate_model displays an interactive widget for visualizing many other useful diagnostics of the trained model, such as residual and prediction-error plots.
evaluate_model(best)
Troubleshooting
First-time users on an M1 Mac may run into this issue: https://github.com/microsoft/LightGBM/issues/1369. Reinstall pycaret and lightgbm and check whether the problem goes away. If it persists, open a new issue on the GitHub repository's issue page.