3.10 AutoML

Automated workflow for hyperparameter tuning and finding the optimal model

In this tutorial, we will try a widely used technique that makes AI/ML less tedious and boosts the efficiency of your ML workflow.

If you worked through Section 3.6, you may have been impressed, but also frustrated, by all the parameter-tuning effort and the many back-and-forth iterations needed to figure out which configuration works best for your problem. This trial and error is widely seen as a major drag on productivity in AI/ML. Since most of the work in that tuning and iteration is simple and repetitive, a natural question is: can we automate it? The answer is yes, and that is the technique introduced here: AutoML.
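To make the idea concrete, here is a minimal sketch of the loop that AutoML tools automate. For every candidate model and hyperparameter value, it estimates performance with k-fold cross-validation and keeps the best. This is a NumPy-only toy. The two model families (ridge regression and k-nearest-neighbors regression), the hyperparameter grids, the synthetic data, and all function names are illustrative assumptions, not PyCaret internals.

```python
import numpy as np


def ridge_fit_predict(X_train, y_train, X_test, alpha):
    """Closed-form ridge regression with an unpenalized intercept (via centering)."""
    x_mean, y_mean = X_train.mean(axis=0), y_train.mean()
    Xc = X_train - x_mean
    n_features = X_train.shape[1]
    w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(n_features), Xc.T @ (y_train - y_mean))
    return (X_test - x_mean) @ w + y_mean


def knn_fit_predict(X_train, y_train, X_test, k):
    """k-nearest-neighbors regression: average the targets of the k closest training rows."""
    dists = np.linalg.norm(X_test[:, None, :] - X_train[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1)[:, :k]
    return y_train[nearest].mean(axis=1)


# Hypothetical search space: (model name, fit/predict function, hyperparameter values)
SEARCH_SPACE = [
    ("ridge", ridge_fit_predict, [0.01, 1.0, 100.0]),
    ("knn", knn_fit_predict, [1, 5, 15]),
]


def cv_rmse(fit_predict, param, X, y, n_folds=5, seed=0):
    """Mean RMSE over shuffled k-fold cross-validation."""
    idx = np.random.default_rng(seed).permutation(len(y))
    scores = []
    for test_idx in np.array_split(idx, n_folds):
        train_idx = np.setdiff1d(idx, test_idx)
        pred = fit_predict(X[train_idx], y[train_idx], X[test_idx], param)
        scores.append(np.sqrt(np.mean((y[test_idx] - pred) ** 2)))
    return float(np.mean(scores))


def auto_select(X, y):
    """Score every (model, hyperparameter) pair and return a leaderboard sorted by CV RMSE."""
    leaderboard = [
        (name, param, cv_rmse(fn, param, X, y))
        for name, fn, params in SEARCH_SPACE
        for param in params
    ]
    return sorted(leaderboard, key=lambda row: row[2])


# Usage on synthetic data with a linear signal (an assumption, not the tutorial's dataset)
rng = np.random.default_rng(42)
X_demo = rng.normal(size=(200, 3))
y_demo = X_demo @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=200)
for name, param, rmse in auto_select(X_demo, y_demo)[:3]:
    print(f"{name:6s} param={param:<7} CV RMSE={rmse:.3f}")
```

Real AutoML libraries add many more model families, preprocessing steps, and smarter search strategies. The leaderboard idea is the same one you will see in PyCaret's comparison table.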

There are many AutoML solutions available, e.g., AutoKeras, auto-sklearn, H2O, and Auto-WEKA. Here we focus on PyCaret, which is popular in both academia and industry and very easy to use.

In this tutorial, we run everything in the PyCaret Docker image. In a terminal, pull the PyCaret image and start a Jupyter notebook server:

docker pull pycaret/full
docker run -it -p 8888:8888 -e GRANT_SUDO=yes pycaret/full

Installation on M1 Macs can be tricky, especially for the lightgbm library. If you run into problems, try installing both pycaret and lightgbm.

You will then be able to edit a notebook with the following cells:

!pip install pycaret
Requirement already satisfied: pycaret in /Users/marinedenolle/opt/miniconda3/envs/mlgeo/lib/python3.9/site-packages (3.3.2)

First, get the data ready

As usual, data collection is the first step. To better demonstrate the point of AutoML, we use the same data as in Section 3.6 (Random Forest).

!pip install wget
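import os
import wget
# Download the data as temps.csv, the name read below. wget never overwrites an existing
# file (it appends a suffix such as " (6)"), so skip the download if the file already exists.
if not os.path.exists('temps.csv'):
    wget.download("https://docs.google.com/uc?export=download&id=1pko9oRmCllAxipZoa3aoztGZfPAD2iwj", out='temps.csv')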

Display the data columns

Show the columns and settle on the target variable and the input variables used in this chapter.

# Pandas is used for data manipulation
import pandas as pd
# Read in the data and display the column names
features = pd.read_csv('temps.csv')
features.columns
Index(['year', 'month', 'day', 'week', 'temp_2', 'temp_1', 'average', 'actual',
       'forecast_noaa', 'forecast_acc', 'forecast_under', 'friend'],
      dtype='object')
  • temp_2: maximum temperature two days before today.

  • temp_1: maximum temperature yesterday.

  • average: historical average temperature.

  • actual: actual measured temperature today.

  • forecast_noaa: temperature forecast by NOAA.

  • friend: your friend's forecast (a random number within ±20 of the average temperature).

We use the actual column as the label and all the other columns as features.
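In plain pandas, using actual as the label and every other column as a feature takes one line for each side. PyCaret's setup does the equivalent for you when given target='actual'. The values below are the first three rows of the tutorial data, restricted to four columns for brevity.

```python
import pandas as pd

# First three rows of the tutorial data (subset of columns)
df = pd.DataFrame({
    "temp_2": [45, 44, 45],
    "temp_1": [45, 45, 44],
    "average": [45.6, 45.7, 45.8],
    "actual": [45, 44, 41],
})

y = df["actual"]                  # label: the measured temperature to predict
X = df.drop(columns=["actual"])   # features: every remaining column

print(X.columns.tolist())  # ['temp_2', 'temp_1', 'average']
```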

Check the data shape

features.shape
(348, 12)
# One-hot encode the data using pandas get_dummies
features = pd.get_dummies(features)
# Display the first 5 rows of the last 13 columns (from 'average' onward)
features.iloc[:,5:].head(5)
average actual forecast_noaa forecast_acc forecast_under friend week_Fri week_Mon week_Sat week_Sun week_Thurs week_Tues week_Wed
0 45.6 45 43 50 44 29 True False False False False False False
1 45.7 44 41 50 44 61 False False True False False False False
2 45.8 41 43 46 47 56 False False False True False False False
3 45.9 40 44 48 46 53 False True False False False False False
4 46.0 44 46 46 46 41 False False False False False True False
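pd.get_dummies only expands non-numeric columns. In this dataset that is the text column week, which turns into the seven week_* indicator columns in the output above. The numeric columns pass through unchanged. A tiny standalone example with made-up rows shows the mechanics:

```python
import pandas as pd

# Made-up example: one text column (week) and one numeric column (temp_1)
df = pd.DataFrame({"week": ["Fri", "Sat", "Fri"], "temp_1": [45, 45, 44]})

encoded = pd.get_dummies(df)
# Non-encoded columns come first; each distinct 'week' value gets its own indicator column
print(encoded.columns.tolist())      # ['temp_1', 'week_Fri', 'week_Sat']
print(encoded["week_Fri"].tolist())  # [True, False, True]
```

With pandas 2.x the indicators are booleans, as in the output above. Older pandas versions return 0/1 integers instead.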

Split into training and testing sets

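Since we already performed the data quality checks in Section 3.6, we will not repeat them here and will go directly to the AutoML experiment. First, split the data into training and testing subsets. Because the rows are ordered by date, this holds out the last 48 days as unseen data.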

train_df = features[:300]
test_df = features[300:]
print('Data for Modeling: ' + str(train_df.shape))
print('Unseen Data For Predictions: ' + str(test_df.shape))
Data for Modeling: (300, 18)
Unseen Data For Predictions: (48, 18)
train_df
year month day temp_2 temp_1 average actual forecast_noaa forecast_acc forecast_under friend week_Fri week_Mon week_Sat week_Sun week_Thurs week_Tues week_Wed
0 2016 1 1 45 45 45.6 45 43 50 44 29 True False False False False False False
1 2016 1 2 44 45 45.7 44 41 50 44 61 False False True False False False False
2 2016 1 3 45 44 45.8 41 43 46 47 56 False False False True False False False
3 2016 1 4 44 41 45.9 40 44 48 46 53 False True False False False False False
4 2016 1 5 41 40 46.0 44 46 46 46 41 False False False False False True False
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
295 2016 11 9 63 71 52.4 65 48 56 52 42 False False False False False False True
296 2016 11 10 71 65 52.2 64 52 54 51 38 False False False False True False False
297 2016 11 11 65 64 51.9 63 50 53 52 55 True False False False False False False
298 2016 11 12 64 63 51.7 59 50 52 52 63 False False True False False False False
299 2016 11 13 63 59 51.4 55 48 56 50 64 False False False True False False False

300 rows × 18 columns

Run PyCaret (no hassle)

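Let's get straight to the point and let PyCaret tell us if anything goes wrong. It should automatically recognize the columns and assign appropriate data types to them.

The first step is setup(), which infers the type of each column and prints a summary of the experiment. Check that the columns were parsed correctly and that their data types match their values. (In PyCaret 2.x, setup() paused at this point and asked you to confirm by pressing Enter in a text field. PyCaret 3.x does not wait for confirmation.)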

from pycaret.regression import *
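exp_reg101 = setup(data = train_df,
                   target = 'actual',        # column to predict
                   # imputation_type='iterative',
                   fold_shuffle=True,        # shuffle rows before k-fold cross-validation
                   session_id=123)           # random seed, for reproducible results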
  Description Value
0 Session id 123
1 Target actual
2 Target type Regression
3 Original data shape (300, 18)
4 Transformed data shape (300, 18)
5 Transformed train set shape (210, 18)
6 Transformed test set shape (90, 18)
7 Numeric features 10
8 Preprocess True
9 Imputation type simple
10 Numeric imputation mean
11 Categorical imputation mode
12 Fold Generator KFold
13 Fold Number 10
14 CPU Jobs -1
15 Use GPU False
16 Log Experiment False
17 Experiment Name reg-default-name
18 USI 569c
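The summary shows that setup() holds out part of train_df: 210 training rows and 90 test rows, consistent with PyCaret's default train_size of 0.7. It also sets up 10-fold KFold cross-validation on the training rows, shuffled because fold_shuffle=True. The NumPy sketch below reproduces that bookkeeping so you can see where the fold sizes come from. The helper names are hypothetical, and although the seed mirrors session_id=123, the rows it picks will not match PyCaret's.

```python
import numpy as np


def shuffled_holdout(n_rows, train_size=0.7, seed=123):
    """Randomly split row indices into a training part and a hold-out part."""
    idx = np.random.default_rng(seed).permutation(n_rows)
    n_train = int(round(n_rows * train_size))
    return idx[:n_train], idx[n_train:]


def shuffled_kfold(indices, n_folds=10, seed=123):
    """Yield (train, validation) index arrays for shuffled k-fold cross-validation."""
    shuffled = np.random.default_rng(seed).permutation(indices)
    folds = np.array_split(shuffled, n_folds)
    for i, val in enumerate(folds):
        train = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train, val


train_idx, test_idx = shuffled_holdout(300)   # 300 rows in train_df
print(len(train_idx), len(test_idx))          # 210 90
for train, val in shuffled_kfold(train_idx):
    print(len(train), len(val))               # 189 21 on each of the 10 folds
```

Every training row is used for validation exactly once across the 10 folds. That is why the cross-validated scores reported later are averages over 10 different validation sets.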

Compare Models

Once you have confirmed that the data types are correct, run the comparison with a single line of code:

best = compare_models(exclude = ['ransac'])  # cross-validate the default set of regressors, skipping RANSAC
  Model MAE MSE RMSE R2 RMSLE MAPE TT (Sec)
rf Random Forest Regressor 3.8383 24.2928 4.7893 0.7416 0.0713 0.0590 0.0180
ada AdaBoost Regressor 4.0621 26.1749 5.0051 0.7203 0.0747 0.0628 0.0090
et Extra Trees Regressor 4.0285 27.5968 5.0967 0.7128 0.0759 0.0620 0.0140
lasso Lasso Regression 3.8163 27.7310 4.9796 0.7117 0.0743 0.0590 0.0030
llar Lasso Least Angle Regression 3.8166 27.7349 4.9799 0.7117 0.0743 0.0590 0.0040
en Elastic Net 3.8289 28.0368 5.0033 0.7079 0.0747 0.0593 0.0040
huber Huber Regressor 3.8708 28.4140 5.0481 0.7048 0.0754 0.0595 0.0050
br Bayesian Ridge 3.8655 28.4961 5.0447 0.7037 0.0754 0.0599 0.0100
lightgbm Light Gradient Boosting Machine 4.0354 27.4807 5.1129 0.7018 0.0767 0.0624 0.0880
gbr Gradient Boosting Regressor 4.0385 28.0064 5.0722 0.6964 0.0756 0.0623 0.0090
ridge Ridge Regression 3.9747 30.2205 5.1822 0.6869 0.0774 0.0617 0.0040
lr Linear Regression 3.9802 30.3045 5.1893 0.6861 0.0775 0.0618 0.2340
knn K Neighbors Regressor 4.1305 30.5617 5.3863 0.6781 0.0806 0.0638 0.0070
xgboost Extreme Gradient Boosting 4.2272 31.1954 5.3840 0.6609 0.0809 0.0658 0.0100
dt Decision Tree Regressor 5.1238 49.7619 6.8983 0.4722 0.1038 0.0795 0.0030
omp Orthogonal Matching Pursuit 5.2630 52.6945 6.7951 0.4646 0.1032 0.0841 0.0030
par Passive Aggressive Regressor 8.0310 116.7174 9.3526 0.0639 0.1423 0.1315 0.0040
dummy Dummy Regressor 8.4914 105.2282 10.1498 -0.0453 0.1581 0.1371 0.0030
lar Least Angle Regression 27.9575 6343.8284 35.4660 -53.1037 0.3068 0.4584 0.0040

Get the Best Model

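It looks great! PyCaret did all the work under the hood and gave us the best model. Focus on the RMSE and R2 columns of the comparison table. The Random Forest Regressor achieves both the lowest RMSE (4.7893) and the highest R2 (0.7416), so the choice is clear and saves you a lot of manual comparison. Note that the ranking depends on the metric: Lasso Regression has a slightly lower MAE (3.8163 versus 3.8383). The scores are averages over 10-fold cross-validation on the training portion rather than a single train/test split, which makes the comparison more reliable. Still, every model here uses its default hyperparameters, and small differences between the top models may not be meaningful on a dataset of this size.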
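The error columns in the comparison table follow the standard definitions sketched below in NumPy. This is an illustration, not PyCaret's own code, and its implementations may handle edge cases such as negative values in RMSLE differently. Because the table reports averages over the 10 cross-validation folds, the RMSE column is not simply the square root of the MSE column. For rf, √24.2928 ≈ 4.93, while the reported RMSE is 4.7893. The average of per-fold RMSEs can never exceed the square root of the average MSE, since the square root is concave.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Standard definitions of the error metrics listed in the comparison table."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_true - y_pred
    mse = np.mean(err ** 2)
    ss_res = np.sum(err ** 2)                      # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return {
        "MAE": np.mean(np.abs(err)),
        "MSE": mse,
        "RMSE": np.sqrt(mse),
        "R2": 1.0 - ss_res / ss_tot,
        # RMSLE assumes non-negative values (log1p is undefined below -1)
        "RMSLE": np.sqrt(np.mean((np.log1p(y_pred) - np.log1p(y_true)) ** 2)),
        # MAPE as a fraction (0.059 means 5.9%), the scale used in the table
        "MAPE": np.mean(np.abs(err / y_true)),
    }


m = regression_metrics([1, 2, 3], [1, 2, 4])
print(round(m["MAE"], 4), round(m["R2"], 4))  # 0.3333 0.5

# Averaging per fold: the mean of per-fold RMSEs is at most sqrt(mean of per-fold MSEs)
fold_mse = np.array([16.0, 36.0])
print(np.mean(np.sqrt(fold_mse)), np.sqrt(fold_mse.mean()))  # 5.0 vs. about 5.10
```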

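The next step is to inspect the configuration of the best model. Keep in mind that compare_models evaluates every model with its default hyperparameters (the output below shows that only n_jobs and random_state are set), so it selects a model type but does not tune it. You can tune the chosen model further, for example with PyCaret's tune_model function, and then train your final model.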

best
RandomForestRegressor(n_jobs=-1, random_state=123)
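Hyperparameter search of the kind tune_model performs can be sketched in a few lines. By default, tune_model runs a randomized search with 10 iterations (n_iter=10) using scikit-learn. In PyCaret, a call would look like tuned = tune_model(best, optimize='RMSE'). The plain-Python sketch below shows the core idea of random search: sample configurations from a search space, score each one, and keep the best. The search space, the stand-in loss function, and the random_search name are hypothetical. In practice, the loss would be a cross-validated error such as RMSE.

```python
import random


def random_search(score_fn, space, n_iter=10, seed=123):
    """Sample n_iter configurations from `space` and return the one with the lowest score.

    space maps each hyperparameter name to a list of candidate values;
    score_fn maps a configuration dict to a loss (lower is better), e.g. a CV RMSE.
    """
    rng = random.Random(seed)
    best_config, best_score = None, float("inf")
    for _ in range(n_iter):
        config = {name: rng.choice(values) for name, values in space.items()}
        score = score_fn(config)
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score


# Hypothetical random-forest-style search space (keys mirror scikit-learn parameter names)
space = {
    "n_estimators": [50, 100, 200, 400],
    "max_depth": [2, 4, 8, None],
    "min_samples_leaf": [1, 2, 5],
}


def toy_loss(cfg):
    """Stand-in for a cross-validated error: pretends depth 8 and about 200 trees work best."""
    depth = cfg["max_depth"] or 16  # treat "no limit" as a deep tree
    return abs(cfg["n_estimators"] - 200) / 100 + abs(depth - 8) / 4 + cfg["min_samples_leaf"] / 10


best_cfg, best_loss = random_search(toy_loss, space, n_iter=30)
print(best_cfg, round(best_loss, 3))
```

This sketch samples with replacement, so a configuration can be evaluated twice. scikit-learn's ParameterSampler avoids repeats when every entry in the space is a list.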

If the best model is not the most cost-effective choice for you and you want to consider more candidates, return several models with top3 = compare_models(exclude = ['ransac'], n_select = 3). top3 is then a list containing the top three models of the ranking.
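The ranking behind n_select can be reproduced from the comparison table. For regression, compare_models sorts by R2 by default (the R2 column above is in descending order), and n_select=3 keeps the first three rows. The pandas sketch below copies the first six rows of that table.

```python
import pandas as pd

# First six rows of the comparison table above (cross-validated means)
results = pd.DataFrame(
    {
        "Model": ["rf", "ada", "et", "lasso", "llar", "en"],
        "MAE": [3.8383, 4.0621, 4.0285, 3.8163, 3.8166, 3.8289],
        "RMSE": [4.7893, 5.0051, 5.0967, 4.9796, 4.9799, 5.0033],
        "R2": [0.7416, 0.7203, 0.7128, 0.7117, 0.7117, 0.7079],
    }
).set_index("Model")

top3_by_r2 = results.sort_values("R2", ascending=False).head(3).index.tolist()
top3_by_mae = results.sort_values("MAE").head(3).index.tolist()  # lower MAE is better
print(top3_by_r2)   # ['rf', 'ada', 'et']
print(top3_by_mae)  # ['lasso', 'llar', 'en']
```

Ranking by MAE instead of R2 shortlists the linear models (Lasso, LLAR, Elastic Net) rather than the tree ensembles, and this also holds for the full table. PyCaret lets you choose the ranking metric with the sort argument, for example compare_models(exclude = ['ransac'], n_select = 3, sort = 'MAE').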

Model Interpretation

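You can get more insight into how the best model makes its predictions. PyCaret provides a function called interpret_model, which produces a figure showing the influence of each input variable on the predictions. It relies on the SHAP library, which PyCaret integrates. SHAP is an optional dependency that is not installed with pycaret by default. If it is missing, interpret_model raises the ModuleNotFoundError shown below, so run pip install shap (or pip install pycaret[analysis]) first and then rerun the cell.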

interpret_model(best)
ModuleNotFoundError: 
'shap' is a soft dependency and not included in the pycaret installation. Please run: `pip install shap` to install.
Alternately, you can install this by running `pip install pycaret[analysis]`
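The SHAP figure is built from Shapley values, which come from cooperative game theory. Each feature's contribution is its average marginal effect on the prediction over all orders in which features could be switched on, starting from a baseline input. The brute-force sketch below computes exact Shapley values for a toy model by enumerating feature subsets. The model, the input, the baseline, and the helper name are assumptions for illustration, not the tutorial's random forest. Exact enumeration needs a number of model evaluations that grows as 2^n with the number of features. That is why the shap library uses specialized algorithms (for example, for tree ensembles) instead.

```python
from itertools import combinations
from math import factorial


def shapley_values(predict, x, baseline):
    """Exact Shapley values of predict(x) relative to predict(baseline).

    Features outside a coalition are set to their baseline value.
    The cost grows as 2**n_features, so this is only feasible for a handful of features.
    """
    n = len(x)

    def value(coalition):
        z = [x[j] if j in coalition else baseline[j] for j in range(n)]
        return predict(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


# Hypothetical toy model with an interaction term; features are (temp_1, average, friend)
def toy_model(z):
    temp_1, average, friend = z
    return 0.7 * temp_1 + 0.3 * average + 0.01 * temp_1 * friend


x = [65.0, 52.4, 42.0]         # made-up input
baseline = [60.0, 55.0, 50.0]  # made-up reference point
phi = shapley_values(toy_model, x, baseline)
print([round(p, 3) for p in phi])  # [5.8, -0.78, -5.0]
# Efficiency property: contributions add up to prediction minus baseline prediction
print(round(sum(phi), 6), round(toy_model(x) - toy_model(baseline), 6))
```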

Evaluate More Metrics

PyCaret provides interactive widgets and plots that make it easy to visualize and check many other useful diagnostics of the trained model, such as residual and prediction-error plots.

evaluate_model(best)

Troubleshooting

  1. First-time users on M1 Macs may run into this issue (https://github.com/microsoft/LightGBM/issues/1369). Reinstall pycaret and lightgbm and check whether the problem goes away. If it does not, please create a new issue on the GitHub repository's issue page.
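When reinstalling or filing an issue, it helps to record exactly which versions you have and whether Python is running natively on the M1 chip. The helper below (a hypothetical name) uses only the standard library: importlib.metadata, available since Python 3.8, and platform.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# 'arm64' for a native Python on an M1 Mac, 'x86_64' when running under Rosetta
print(platform.python_version(), platform.machine())
print(installed_versions(["pycaret", "lightgbm", "scikit-learn", "shap"]))
```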