{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Generalized LVQ (GLVQ)\n\nThis example shows how to fit the GLVQ `[1]`_ algorithm to the classic Iris dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_iris\nfrom sklearn.metrics import classification_report\nfrom sklearn.preprocessing import StandardScaler\n\nfrom sklvq import GLVQ\n\nmatplotlib.rc(\"xtick\", labelsize=\"small\")\nmatplotlib.rc(\"ytick\", labelsize=\"small\")\n\n# The returned Bunch also contains target_names and feature_names, which are used for the plots.\niris = load_iris()\n\ndata = iris.data\nlabels = iris.target"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Fitting the Model\nScale the data and create a GLVQ object with, e.g., a custom distance function, activation\nfunction and solver. See the API reference in the documentation for the defaults and other\npossible parameters.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# scikit-learn's StandardScaler performs the z-transform\nscaler = StandardScaler()\n\n# Compute (fit) and apply (transform) the z-transform\ndata = scaler.fit_transform(data)\n\n# Create the model object that will be fitted to the data.\nmodel = GLVQ(\n    distance_type=\"squared-euclidean\",\n    activation_type=\"swish\",\n    activation_params={\"beta\": 2},\n    solver_type=\"steepest-gradient-descent\",\n    solver_params={\"max_runs\": 20, \"step_size\": 0.1},\n)"
      ]
    },
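    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the roles of `distance_type`, `activation_type` and `activation_params` more concrete, the following NumPy sketch evaluates a GLVQ-style cost on toy data. It assumes the relative distance $\\mu = (d^+ - d^-) / (d^+ + d^-)$ from `[1]`_, where $d^+$ is the distance to the closest prototype with the correct label and $d^-$ the distance to the closest prototype with a different label, passed through a swish activation assumed here to be $f(\\mu) = \\mu / (1 + e^{-\\beta \\mu})$. The function names are hypothetical and this is not sklvq's internal implementation.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def squared_euclidean(x, prototypes):\n",
        "    # Squared Euclidean distance between sample x and every prototype (one per row).\n",
        "    return np.sum((prototypes - x) ** 2, axis=1)\n",
        "\n",
        "\n",
        "def relative_distance(d_plus, d_minus):\n",
        "    # GLVQ discriminant in [-1, 1]; it is negative when x is classified correctly.\n",
        "    return (d_plus - d_minus) / (d_plus + d_minus)\n",
        "\n",
        "\n",
        "def swish(mu, beta=2.0):\n",
        "    # Swish activation (assumed form): mu * sigmoid(beta * mu).\n",
        "    return mu / (1.0 + np.exp(-beta * mu))\n",
        "\n",
        "\n",
        "def glvq_cost(X, y, prototypes, prototype_labels, beta=2.0):\n",
        "    # Sum of the activated discriminants over all samples (lower is better).\n",
        "    total = 0.0\n",
        "    for x, label in zip(X, y):\n",
        "        distances = squared_euclidean(x, prototypes)\n",
        "        correct = prototype_labels == label\n",
        "        d_plus = distances[correct].min()  # closest prototype with the correct label\n",
        "        d_minus = distances[~correct].min()  # closest prototype with a wrong label\n",
        "        total += swish(relative_distance(d_plus, d_minus), beta)\n",
        "    return total\n",
        "\n",
        "\n",
        "# Toy data: two classes in 2D with one prototype per class.\n",
        "toy_prototypes = np.array([[0.0, 0.0], [2.0, 0.0]])\n",
        "toy_prototype_labels = np.array([0, 1])\n",
        "toy_samples = np.array([[0.5, 0.0], [1.5, 0.0]])\n",
        "toy_labels = np.array([0, 1])\n",
        "\n",
        "# Both samples are closer to their own prototype, so the cost is negative.\n",
        "print(glvq_cost(toy_samples, toy_labels, toy_prototypes, toy_prototype_labels))"
      ]
    },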
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next step is to fit the GLVQ object to the data and use the predict method to make the\npredictions. Note that this example only uses the training data and therefore says nothing\nabout how well the fitted model generalizes to unseen data.\n\n"
      ]
    },
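    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to estimate generalization instead is to hold out part of the data before scaling and fitting. The sketch below uses scikit-learn's `train_test_split`, stratifies the split by class, and fits a separate `StandardScaler` on the training split only. The 30% test size and `random_state=0` are arbitrary illustrative choices, and the GLVQ fitting step is only indicated in comments.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.datasets import load_iris\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "\n",
        "X, y = load_iris(return_X_y=True)\n",
        "\n",
        "# Stratify so that all three classes keep their proportions in both splits.\n",
        "X_train, X_test, y_train, y_test = train_test_split(\n",
        "    X, y, test_size=0.3, stratify=y, random_state=0\n",
        ")\n",
        "\n",
        "# Fit the z-transform on the training split only and apply it to both splits,\n",
        "# so that no statistics of the test split leak into the preprocessing.\n",
        "holdout_scaler = StandardScaler()\n",
        "X_train = holdout_scaler.fit_transform(X_train)\n",
        "X_test = holdout_scaler.transform(X_test)\n",
        "\n",
        "# The GLVQ model would then be fitted and evaluated on the held-out data, e.g.:\n",
        "# model.fit(X_train, y_train)\n",
        "# print(classification_report(y_test, model.predict(X_test)))"
      ]
    },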
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Train the model using the iris dataset\nmodel.fit(data, labels)\n\n# Predict the labels using the trained model\npredicted_labels = model.predict(data)\n\n# Print the classification report to get a sense of the training performance.\nprint(classification_report(labels, predicted_labels))"
      ]
    },
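    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "GLVQ is a nearest-prototype classifier: a sample is assigned the label of its closest prototype. The NumPy sketch below illustrates that decision rule with the squared Euclidean distance configured above. `nearest_prototype_predict` is a hypothetical helper, not part of sklvq; applied to `data`, `model.prototypes_` and `model.prototypes_labels_`, it would be expected to agree with `model.predict(data)` up to ties.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def nearest_prototype_predict(X, prototypes, prototype_labels):\n",
        "    # Pairwise squared Euclidean distances, shape (n_samples, n_prototypes).\n",
        "    diffs = X[:, np.newaxis, :] - prototypes[np.newaxis, :, :]\n",
        "    distances = np.sum(diffs ** 2, axis=2)\n",
        "    # Each sample receives the label of its closest prototype.\n",
        "    return prototype_labels[np.argmin(distances, axis=1)]\n",
        "\n",
        "\n",
        "# Toy example with one prototype per class.\n",
        "toy_prototypes = np.array([[0.0, 0.0], [3.0, 3.0], [0.0, 3.0]])\n",
        "toy_prototype_labels = np.array([0, 1, 2])\n",
        "toy_samples = np.array([[0.2, -0.1], [2.5, 3.1], [-0.3, 2.8]])\n",
        "toy_predictions = nearest_prototype_predict(toy_samples, toy_prototypes, toy_prototype_labels)\n",
        "print(toy_predictions)"
      ]
    },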
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Extracting the Prototypes\nThe GLVQ model learns prototypes that represent the different\nclasses. These prototypes can be accessed and, e.g., plotted for visual inspection. Note that\nthe prototypes of the model lie in the z-score space and are transformed back to the original\nfeature space before they are plotted.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "colors = [\"blue\", \"red\", \"green\"]\nnum_prototypes = model.prototypes_.shape[0]\nnum_features = model.prototypes_.shape[1]\n\nfig, ax = plt.subplots(num_prototypes, 1)\nfig.suptitle(\"Prototype of each class\")\n\nfor i, prototype in enumerate(model.prototypes_):\n    # Reverse the z-transform to go back to the original feature space. StandardScaler\n    # expects a 2D array, so the prototype is reshaped to (1, n_features) and flattened again.\n    prototype = scaler.inverse_transform(prototype.reshape(1, -1)).ravel()\n\n    ax[i].bar(\n        range(num_features),\n        prototype,\n        color=colors[i],\n        label=iris.target_names[model.prototypes_labels_[i]],\n    )\n    ax[i].set_xticks(range(num_features))\n    if i == (num_prototypes - 1):\n        # Strip the \" (cm)\" unit suffix from the feature names.\n        ax[i].set_xticklabels([name[:-5] for name in iris.feature_names])\n    else:\n        ax[i].set_xticklabels([], visible=False)\n        ax[i].tick_params(\n            axis=\"x\", which=\"both\", bottom=False, top=False, labelbottom=False\n        )\n    ax[i].set_ylabel(\"cm\")\n    ax[i].legend()"
      ]
    },
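    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The back-transformation used for the plots undoes the z-transform feature by feature: a z-scored value `z` maps back to `z * scale_ + mean_`, where `StandardScaler` stores the per-feature mean in `mean_` and the standard deviation in `scale_`. A minimal NumPy sketch on made-up data, assuming the population standard deviation (`ddof=0`) as used by `StandardScaler`:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Made-up data: 5 samples with 4 features (illustration only).\n",
        "rng = np.random.default_rng(0)\n",
        "toy_data = rng.normal(loc=5.0, scale=2.0, size=(5, 4))\n",
        "\n",
        "# Forward z-transform with the population standard deviation (ddof=0).\n",
        "toy_mean = toy_data.mean(axis=0)\n",
        "toy_std = toy_data.std(axis=0)\n",
        "toy_z = (toy_data - toy_mean) / toy_std\n",
        "\n",
        "# Inverse transform of a single z-scored vector, e.g. one prototype.\n",
        "z_vector = toy_z[0]\n",
        "x_vector = z_vector * toy_std + toy_mean\n",
        "print(np.allclose(x_vector, toy_data[0]))"
      ]
    },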
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## References\n_`[1]` Sato, A., and Yamada, K. (1996). \u201cGeneralized Learning Vector Quantization.\u201d Advances in\nNeural Information Processing Systems 8, 423\u2013429.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}