{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Grid Search\n\nCross-validation is not the whole story, as it can only tell you the expected performance of a\nsingle set of (hyper)parameters. Luckily, sklearn also provides a way of trying out multiple\nsettings and returning the CV scores for each of them. We can use `GridSearchCV` for this.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.datasets import load_iris\nfrom sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import StandardScaler\n\nfrom sklvq import GMLVQ\n\ndata, labels = load_iris(return_X_y=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We first need to create a pipeline and initialize a parameter grid we want to search.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Create the standard scaler instance\nstandard_scaler = StandardScaler()\n\n# Create the GMLVQ model instance\nmodel = GMLVQ()\n\n# Link them together by using sklearn's pipeline\npipeline = make_pipeline(standard_scaler, model)\n\n# We want to compare the performance of the following two solvers\nsolvers_types = [\n    \"steepest-gradient-descent\",\n    \"waypoint-gradient-descent\",\n]\n\n# Currently, the sklvq package contains only the following distance function compatible with\n# GMLVQ. However, see the customization examples for writing your own.\ndistance_types = [\"adaptive-squared-euclidean\"]\n\n# Because we are using a pipeline, we need to prefix each parameter with the name of the\n# pipeline step it belongs to. make_pipeline names each step after its lowercased class name,\n# hence the \"gmlvq__\" prefix.\nparam_grid = [\n    {\n        \"gmlvq__solver_type\": solvers_types,\n        \"gmlvq__distance_type\": distance_types,\n        \"gmlvq__activation_type\": [\"identity\"],\n    },\n    {\n        \"gmlvq__solver_type\": solvers_types,\n        \"gmlvq__distance_type\": distance_types,\n        \"gmlvq__activation_type\": [\"sigmoid\"],\n        \"gmlvq__activation_params\": [{\"beta\": beta} for beta in range(1, 4, 1)],\n    },\n]\n# This grid can be read as: for each solver, try each distance type with the identity activation\n# function, and with the sigmoid activation function for each beta in range(1, 4, 1).\n\n# Initialize a RepeatedStratifiedKFold object\nrepeated_kfolds = RepeatedStratifiedKFold(n_splits=5, n_repeats=5)\n\n# Initialize the GridSearchCV instance that will fit the pipeline (standard scaler, GMLVQ) to\n# the data for each of the parameter sets in the grid, where each fit is a 5 times\n# repeated stratified 5-fold cross-validation. For each set, the testing accuracy is reported.\nsearch = GridSearchCV(\n    pipeline,\n    param_grid,\n    scoring=\"accuracy\",\n    cv=repeated_kfolds,\n    n_jobs=4,\n    return_train_score=False,\n    verbose=10,\n)\n\n# The gridsearch object can be fitted to the data.\nsearch.fit(data, labels)\n\n# Print the best CV score and parameters.\nprint(\"\\nBest parameters (CV score=%0.3f):\" % search.best_score_)\nprint(search.best_params_)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When inspecting the resulting classifier and its prototypes,\ne.g., in a plot overlaid on a scatter plot of the data, don't forget to apply the scaling to the data:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "transformed_data = search.transform(data)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}