{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Distance Functions\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import TYPE_CHECKING\n\nimport numpy as np\nfrom sklearn.datasets import load_iris\nfrom sklearn.metrics import classification_report\nfrom sklearn.metrics.pairwise import pairwise_distances\n\nfrom sklvq.distances import DistanceBaseClass\n\nif TYPE_CHECKING:\n    from sklvq.models import LVQBaseClass\n\nfrom sklvq import GLVQ\n\ndata, labels = load_iris(return_X_y=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sklvq package already contains a few distance functions; see the API reference under\nDocumentation. Their base class is very similar to that of the activation functions.\nHowever, the structure in which the distances, and especially the gradient with respect\nto the different parameters, are returned is important. Furthermore, not every\ndistance function works with every algorithm. Below is a custom version of\n`sklvq.distances.SquaredEuclidean`, which is suitable for the GLVQ algorithm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class CustomSquaredEuclidean(DistanceBaseClass):\n\n    # The distance implementations use the sklearn pairwise distance function.\n    def __init__(self, **other_kwargs):\n        self.metric_kwargs = {\"metric\": \"euclidean\", \"squared\": True}\n\n        # **other_kwargs is always a dict (possibly empty), so update unconditionally.\n        self.metric_kwargs.update(other_kwargs)\n\n    # The call function needs to return a matrix with one row per sample in X and\n    # one column per prototype, holding the distance between the two.\n    def __call__(self, data: np.ndarray, model: \"LVQBaseClass\") -> np.ndarray:\n        return pairwise_distances(data, model.prototypes_, **self.metric_kwargs)\n\n    # The gradient is slightly more difficult, as the gradient (with respect to one\n    # prototype) needs to be provided in a vector the size of all the prototypes.\n    # Hence, all values are zero except those of the prototype indicated by the index\n    # i_prototype. In the case of GMLVQ and LGMLVQ distance functions, the gradient\n    # of the omega matrix also needs to be returned (in this same vector). See the API\n    # reference under Documentation or GitHub for other distance functions and their\n    # implementation.\n    def gradient(\n        self, data: np.ndarray, model: \"LVQBaseClass\", i_prototype: int\n    ) -> np.ndarray:\n        prototypes = model.get_model_params()\n        (num_samples, num_features) = data.shape\n\n        distance_gradient = np.zeros((num_samples, prototypes.size))\n\n        ip_start = i_prototype * num_features\n        ip_end = ip_start + num_features\n\n        # d/dw ||x - w||^2 = -2 (x - w), written into the slice of prototype i_prototype.\n        distance_gradient[:, ip_start:ip_end] = -2 * (data - prototypes[i_prototype, :])\n\n        return distance_gradient"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The CustomSquaredEuclidean above, accompanied by some tests and documentation, would make a\ngreat addition to the sklvq package. However, the class can also be passed directly to\nthe algorithm through the `distance_type` argument.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = GLVQ(\n    distance_type=CustomSquaredEuclidean,\n    activation_type=\"sigmoid\",\n    activation_params={\"beta\": 2},\n)\n\nmodel.fit(data, labels)\n\n# Predict the labels using the trained model\npredicted_labels = model.predict(data)\n\n# Print a classification report (sklearn)\nprint(classification_report(labels, predicted_labels))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}