{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Learning Behaviour\n\nThese examples use GMLVQ, but the same approach applies to all the other algorithms. However,\nnot every solver provides the same state variables. Additionally, the \"lbfgs\" and \"bfgs\" solvers are\nimplemented in SciPy, and their callbacks differ from the others. See SciPy's documentation\nfor further information.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom sklearn.datasets import load_iris\nfrom sklearn.metrics import classification_report\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import StandardScaler\n\nfrom sklvq import GMLVQ\n\nmatplotlib.rc(\"xtick\", labelsize=\"small\")\nmatplotlib.rc(\"ytick\", labelsize=\"small\")\n\ndata, labels = load_iris(return_X_y=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We create a process logger object and pass it to the model's solver as its callback.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ProcessLogger:\n    def __init__(self):\n        # Collected solver states, one entry per callback invocation.\n        self.states = []\n\n    # A callback function has to accept a single argument, state, which contains a number of\n    # the optimizer's variables.\n    def __call__(self, state):\n        self.states.append(state)\n        # Returning True stops training early, e.g., when some condition is met.\n        return False\n\n\n# Instantiate the \"logger\".\nlogger = ProcessLogger()\n\nscaler = StandardScaler()\n\nmodel = GMLVQ(\n    distance_type=\"adaptive-squared-euclidean\",\n    activation_type=\"swish\",\n    activation_params={\"beta\": 2},\n    solver_type=\"waypoint-gradient-descent\",\n    solver_params={\n        \"max_runs\": 15,\n        \"k\": 3,\n        \"step_size\": np.array(\n            [0.75, 0.85]\n        ),  # Note: we chose very large step sizes here to show\n        # the usefulness of waypoint averaging.\n        \"callback\": logger,\n    },\n    random_state=1428,\n)\n\npipeline = make_pipeline(scaler, model)\n\npipeline.fit(data, labels)\n\npredicted_labels = pipeline.predict(data)\n\n# Print a classification report (sklearn)\nprint(classification_report(labels, predicted_labels))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Additionally, we can study the cost at each iteration of the solver's progress. The curve is not\nvery smooth, and the cost even gets worse at times. This is because the chosen step_size is too\nlarge.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "iteration, fun = zip(*[(state[\"nit\"], state[\"fun\"]) for state in logger.states])\n\nax = plt.axes()\n\nax.set_title(\"Learning Curve (Less is better)\")\nax.plot(iteration, fun)\n_ = ax.legend([\"Cost per iteration\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the case of waypoint-gradient-descent, a cost (tfun) is computed for the average of the last\nk=3 updates, alongside the cost of a regular update (nfun). Whichever of the two updates has the\nlower cost is applied.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "tfun, nfun = zip(*[(state[\"tfun\"], state[\"nfun\"]) for state in logger.states])\n\nax = plt.axes()\n\nax.set_title(\"Learning Curves (Less is better)\")\nax.plot(iteration, nfun)\nax.plot(iteration, tfun)\n_ = ax.legend([\"Cost of regular gradient update\", \"Cost of average gradient update\"])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}