{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Generalized Matrix LVQ (GMLVQ)\n\nExample of how to use GMLVQ `[1]`_ on the classic iris dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom sklearn.datasets import load_iris\nfrom sklearn.metrics import classification_report\nfrom sklearn.preprocessing import StandardScaler\n\nfrom sklvq import GMLVQ\n\nmatplotlib.rc(\"xtick\", labelsize=\"small\")\nmatplotlib.rc(\"ytick\", labelsize=\"small\")\n\n# Also contains the target_names and feature_names, which we will use for the plots.\niris = load_iris()\n\ndata = iris.data\nlabels = iris.target\n# Strip the trailing \" (cm)\" unit from each feature name.\nfeature_names = [name[:-5] for name in iris.feature_names]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Fitting the Model\nScale the data and create a GMLVQ object with, e.g., a custom distance function, activation\nfunction, and solver. See the API reference in the documentation for the defaults and other\npossible parameters.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Scikit-learn's StandardScaler performs the z-transform\nscaler = StandardScaler()\n\n# Compute (fit) and apply (transform) the z-transform\ndata = scaler.fit_transform(data)\n\n# Create the model object that will be fitted to the data.\nmodel = GMLVQ(\n    distance_type=\"adaptive-squared-euclidean\",\n    activation_type=\"swish\",\n    activation_params={\"beta\": 2},\n    solver_type=\"waypoint-gradient-descent\",\n    solver_params={\"max_runs\": 10, \"k\": 3, \"step_size\": np.array([0.1, 0.05])},\n    random_state=1428,\n)"
      ]
    },
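    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about what the activation function acts on, the cell below is a minimal NumPy sketch of the GLVQ-style relative distance difference `mu = (d_plus - d_minus) / (d_plus + d_minus)`, where `d_plus` is the distance of a sample to its closest prototype with the correct label and `d_minus` the distance to its closest prototype with a wrong label. A negative `mu` means the sample is classified correctly. The per-sample values are passed through a swish function `x / (1 + exp(-beta * x))` and summed. The helper functions and distances are hypothetical illustrations, not sklvq's implementation, whose exact formulation may differ (see the API reference).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef relative_distance_difference(d_plus, d_minus):\n    # Negative when the closest correct prototype is nearer than the closest wrong one.\n    return (d_plus - d_minus) / (d_plus + d_minus)\n\n\ndef swish(x, beta=2.0):\n    # Swish activation x * sigmoid(beta * x), written as x / (1 + exp(-beta * x)).\n    return x / (1.0 + np.exp(-beta * x))\n\n\n# Hypothetical distances of three samples to their closest correct (d_plus)\n# and closest incorrect (d_minus) prototypes.\nd_plus = np.array([0.2, 1.0, 3.0])\nd_minus = np.array([1.8, 1.0, 1.0])\n\nmu = relative_distance_difference(d_plus, d_minus)\ncost = swish(mu, beta=2.0).sum()\nprint(mu)\nprint(cost)"
      ]
    },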
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next step is to fit the GMLVQ object to the data and use the predict method to make the\npredictions. Note that this example only evaluates the model on the training data and therefore\ndoes not say anything about the generalizability of the fitted model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Train the model using the scaled data and true labels\nmodel.fit(data, labels)\n\n# Predict the labels using the trained model\npredicted_labels = model.predict(data)\n\n# Print the classification report to get a sense of the training performance.\nprint(classification_report(labels, predicted_labels))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Extracting the Relevance Matrix\nIn addition to the prototypes (see the GLVQ example), GMLVQ learns a relevance matrix `lambda_`,\nwhich indicates which features are most relevant for the classification.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# The relevance matrix is available after fitting the model.\nrelevance_matrix = model.lambda_\n\n# Plot the diagonal of the relevance matrix\nfig, ax = plt.subplots()\nfig.suptitle(\"Relevance Matrix Diagonal\")\nax.bar(feature_names, np.diagonal(relevance_matrix))\nax.set_ylabel(\"Weight\")\nax.grid(False)"
      ]
    },
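    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In GMLVQ the relevance matrix is parameterized as `lambda_ = omega.T @ omega` (see reference [1] below), which keeps it symmetric and positive semi-definite, and the adaptive squared Euclidean distance between a sample `x` and a prototype `w` is `(x - w).T @ lambda_ @ (x - w)`, i.e., the squared Euclidean distance after mapping both with `omega`. Normalizing `lambda_` to a trace of one is what makes its diagonal sum to one. The NumPy sketch below checks these identities on a random, hypothetical `omega`; it does not use the fitted model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nrng = np.random.default_rng(0)\nn_features = 4\n\n# A random (hypothetical) omega, rescaled so that trace(omega.T @ omega) == 1.\nomega = rng.standard_normal((n_features, n_features))\nomega /= np.sqrt(np.trace(omega.T @ omega))\n\n# Relevance matrix: symmetric and positive semi-definite by construction.\nrelevance = omega.T @ omega\n\nx = rng.standard_normal(n_features)\nw = rng.standard_normal(n_features)\ndiff = x - w\n\n# Adaptive squared Euclidean distance, computed in two equivalent ways.\nd_lambda = diff @ relevance @ diff\nd_omega = np.sum((omega @ diff) ** 2)\n\nprint(np.diagonal(relevance).sum())  # one, up to floating point error\nprint(np.isclose(d_lambda, d_omega))"
      ]
    },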
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that the diagonal of the relevance matrix sums to one, i.e., `lambda_` is normalized to a\ntrace of one. The most relevant features for distinguishing between the classes in the iris\ndataset seem to be (in decreasing order) the petal length, petal width, sepal length, and sepal\nwidth. Although this is not very interesting for the iris dataset, one could use this information\nto select only the most relevant features for the classification and thus reduce the\ndimensionality of the problem.\n\n"
      ]
    },
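    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the feature selection idea mentioned above: rank the features by the diagonal of a relevance matrix and keep the `k` highest-ranked columns. The helper `select_top_features` is hypothetical (not part of sklvq), and the toy relevance matrix and data are made up; with the fitted model one would pass the scaled `data` and `model.lambda_` instead. Since the diagonal ignores the off-diagonal entries, this ranking is only a heuristic.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef select_top_features(data, relevance_matrix, k):\n    # Rank features by their diagonal relevance, highest first.\n    order = np.argsort(np.diagonal(relevance_matrix))[::-1]\n    top = order[:k]\n    # Return the reduced data and the indices of the kept features.\n    return data[:, top], top\n\n\n# Toy (made-up) relevance matrix with unit trace and toy data with 4 features.\ntoy_relevance = np.diag([0.1, 0.05, 0.5, 0.35])\ntoy_data = np.arange(12, dtype=float).reshape(3, 4)\n\nreduced, kept = select_top_features(toy_data, toy_relevance, k=2)\nprint(kept)  # [2 3]\nprint(reduced.shape)  # (3, 2)"
      ]
    },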
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Transforming the data\nIn addition to making predictions, GMLVQ can be used to transform the data using the\neigenvectors of the relevance matrix.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Transform the data (scaled by the square root of the eigenvalues, \"scale=True\")\ntransformed_data = model.transform(data, scale=True)\n\nx_d = transformed_data[:, 0]\ny_d = transformed_data[:, 1]\n\n# Transform the model, i.e., the prototypes (scaled by the square root of the eigenvalues, \"scale=True\")\ntransformed_model = model.transform(model.prototypes_, scale=True)\n\nx_m = transformed_model[:, 0]\ny_m = transformed_model[:, 1]\n\n# Plot the data per class and the prototypes (c=colors assumes one prototype per class)\nfig, ax = plt.subplots()\nfig.suptitle(\"Discriminative projection of the iris data and GMLVQ prototypes\")\ncolors = [\"blue\", \"red\", \"green\"]\nfor i, cls in enumerate(model.classes_):\n    ii = cls == labels\n    ax.scatter(\n        x_d[ii],\n        y_d[ii],\n        c=colors[i],\n        s=100,\n        alpha=0.7,\n        edgecolors=\"white\",\n        label=iris.target_names[cls],\n    )\nax.scatter(x_m, y_m, c=colors, s=180, alpha=0.8, edgecolors=\"black\", linewidth=2.0)\nax.set_xlabel(\"First eigenvector\")\nax.set_ylabel(\"Second eigenvector\")\nax.legend()\nax.grid(True)"
      ]
    },
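    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, the transformation can be sketched as an eigendecomposition `lambda_ = V @ diag(eigenvalues) @ V.T`, followed by projecting the data onto the eigenvectors (columns of `V`) and, optionally, scaling each direction by the square root of its eigenvalue. With that scaling and all eigenvectors kept, squared Euclidean distances in the transformed space equal the adaptive distances. The function `transform_sketch` below is hypothetical and uses a random relevance matrix; sklvq's own `transform` may differ in details such as the signs or ordering of the eigenvectors.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef transform_sketch(X, relevance_matrix, scale=True):\n    # eigh returns the eigenvalues of a symmetric matrix in ascending order.\n    eigenvalues, eigenvectors = np.linalg.eigh(relevance_matrix)\n    # Reorder so that the direction with the largest eigenvalue comes first.\n    order = np.argsort(eigenvalues)[::-1]\n    eigenvalues, eigenvectors = eigenvalues[order], eigenvectors[:, order]\n    projected = X @ eigenvectors\n    if scale:\n        # Clip tiny negative eigenvalues caused by floating point error before sqrt.\n        projected = projected * np.sqrt(np.clip(eigenvalues, 0.0, None))\n    return projected\n\n\nrng = np.random.default_rng(1)\nomega = rng.standard_normal((4, 4))\nrelevance = omega.T @ omega\nrelevance /= np.trace(relevance)\n\nX = rng.standard_normal((5, 4))\nZ = transform_sketch(X, relevance, scale=True)\n\n# Squared distance between the first two samples, in both spaces.\ndiff = X[0] - X[1]\nprint(np.isclose(diff @ relevance @ diff, np.sum((Z[0] - Z[1]) ** 2)))"
      ]
    },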
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The transformed data and prototypes can be used to visualize the problem in a lower dimension.\nThe transformed space is the space in which the model computes its distances; the plot shows its\nfirst two dimensions. The axes are the most discriminating directions (linear combinations of the\nfeatures). Hence, inspecting the eigenvalues and eigenvectors (axes) themselves can be interesting.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plot the eigenvalues of the relevance matrix.\nfig, ax = plt.subplots()\nfig.suptitle(\"Eigenvalues\")\nax.bar(range(len(model.eigenvalues_)), model.eigenvalues_)\nax.set_ylabel(\"Weight\")\nax.grid(False)\n\n# Plot the first two eigenvectors of the relevance matrix, which are stored in `omega_hat_`.\nfig, ax = plt.subplots()\nfig.suptitle(\"First Eigenvector\")\nax.bar(feature_names, model.omega_hat_[:, 0])\nax.set_ylabel(\"Weight\")\nax.grid(False)\n\nfig, ax = plt.subplots()\nfig.suptitle(\"Second Eigenvector\")\nax.bar(feature_names, model.omega_hat_[:, 1])\nax.set_ylabel(\"Weight\")\nax.grid(False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The eigenvalue and eigenvector plots show an effect similar to the one seen in the diagonal of\n`lambda_`. The two leading (most relevant or discriminating) eigenvectors mostly use the petal\nlength and petal width. The diagonal of the relevance matrix can therefore be considered a\nsummary of the relevances of the features.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## References\n_`[1]` Schneider, P., Biehl, M., & Hammer, B. (2009). \"Adaptive Relevance Matrices in Learning\nVector Quantization.\" Neural Computation, 21(12), 3532\u20133561.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}