{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook. No cell in this notebook creates a plot,\n",
        "# so this magic is not strictly required.\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Solvers\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import TYPE_CHECKING, Callable, Optional\n\nimport numpy as np\nfrom sklearn.datasets import load_iris\nfrom sklearn.metrics import classification_report\nfrom sklearn.utils import shuffle\n\nfrom sklvq import GLVQ\nfrom sklvq.objectives import ObjectiveBaseClass\nfrom sklvq.solvers import SolverBaseClass\nfrom sklvq.solvers._base import _update_state\n\nif TYPE_CHECKING:\n    from sklvq.models import LVQBaseClass\n\nSTATE_KEYS = [\"variables\", \"nit\", \"fun\", \"step_size\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
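        "The sklvq package contains a number of different solvers; see the API reference in the\nDocumentation for the full list. This example shows how to write a custom solver by subclassing\n`SolverBaseClass` and implementing `solve`.\n\n"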
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class CustomSteepestGradientDescent(SolverBaseClass):\n",
        "    \"\"\"Mini-batch steepest gradient descent with a simple step-size annealing schedule.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    objective : ObjectiveBaseClass\n",
        "        Objective whose value and gradient drive the descent. It is handed to the solver\n",
        "        base class and used as ``self.objective``.\n",
        "    max_runs : int, default=10\n",
        "        Number of passes over the (reshuffled) data.\n",
        "    batch_size : int, default=1\n",
        "        Number of samples per update step. Values ``<= 0`` select full-batch descent;\n",
        "        values larger than the number of samples make ``solve`` raise ``ValueError``.\n",
        "    step_size : float, default=0.1\n",
        "        Initial step size. In run ``i`` (0-based) the step size used is\n",
        "        ``step_size / (1 + i / max_runs)``.\n",
        "    callback : callable, optional\n",
        "        Called with a state built by ``_update_state`` from ``STATE_KEYS`` before the\n",
        "        first run and after every run. If it returns a truthy value, ``solve`` stops.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        objective: ObjectiveBaseClass,\n",
        "        max_runs: int = 10,\n",
        "        batch_size: int = 1,\n",
        "        step_size: float = 0.1,\n",
        "        callback: Optional[Callable] = None,\n",
        "    ):\n",
        "        # The base class receives the objective; solve() accesses it as self.objective.\n",
        "        super().__init__(objective)\n",
        "        # Parameters are stored as given. A production solver would validate them here;\n",
        "        # checks that need the data (batch_size vs. number of samples) happen in solve().\n",
        "        self.max_runs = max_runs\n",
        "        self.batch_size = batch_size\n",
        "        self.step_size = step_size\n",
        "        self.callback = callback\n",
        "\n",
        "    def solve(\n",
        "        self, data: np.ndarray, labels: np.ndarray, model: \"LVQBaseClass\",\n",
        "    ):\n",
        "        \"\"\"Update ``model`` in place using mini-batch steepest gradient descent.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        data : np.ndarray\n",
        "            2-D array with one sample per row.\n",
        "        labels : np.ndarray\n",
        "            Labels indexed in parallel with the rows of ``data``.\n",
        "        model : LVQBaseClass\n",
        "            Model whose variables are read with ``get_variables`` and written back with\n",
        "            ``set_variables`` after every batch.\n",
        "\n",
        "        Raises\n",
        "        ------\n",
        "        ValueError\n",
        "            If ``batch_size`` is larger than the number of samples.\n",
        "        \"\"\"\n",
        "        n_samples = data.shape[0]\n",
        "        batch_size = self.batch_size\n",
        "\n",
        "        # These checks need the data, so they cannot be done in __init__. They run before the\n",
        "        # initial callback so that an invalid configuration fails before any callback runs.\n",
        "        if batch_size > n_samples:\n",
        "            raise ValueError(\"Provided batch_size is invalid.\")\n",
        "        if batch_size <= 0:\n",
        "            # A non-positive batch size means full-batch gradient descent.\n",
        "            batch_size = n_samples\n",
        "\n",
        "        # Report the initial state to the callback, if one is provided.\n",
        "        if self.callback is not None:\n",
        "            state = _update_state(\n",
        "                STATE_KEYS,\n",
        "                variables=np.copy(model.get_variables()),\n",
        "                nit=\"Initial\",\n",
        "                fun=self.objective(model, data, labels),\n",
        "            )\n",
        "            if self.callback(state):\n",
        "                return\n",
        "\n",
        "        # Split points for numpy.array_split: consecutive batches of batch_size samples, the\n",
        "        # last one holding the remainder. They do not depend on the run, so compute them once.\n",
        "        split_points = list(range(batch_size, n_samples, batch_size))\n",
        "\n",
        "        for i_run in range(self.max_runs):\n",
        "            # Randomize the sample order using the model's random state.\n",
        "            shuffled_indices = shuffle(np.arange(n_samples), random_state=model.random_state_)\n",
        "            batches = np.array_split(shuffled_indices, split_points, axis=0)\n",
        "\n",
        "            # Simple annealing: the step size decays from step_size towards step_size / 2.\n",
        "            step_size = self.step_size / (1 + i_run / self.max_runs)\n",
        "\n",
        "            for batch_indices in batches:\n",
        "                batch = data[batch_indices, :]\n",
        "                batch_labels = labels[batch_indices]\n",
        "\n",
        "                objective_gradient = self.objective.gradient(model, batch, batch_labels)\n",
        "\n",
        "                # Scale the gradient of each variable by the current step size. The return\n",
        "                # value is not used, so this relies on the model updating the array in place.\n",
        "                model.mul_step_size(step_size, objective_gradient)\n",
        "\n",
        "                # Descent step: variables - scaled gradient. The result is written into\n",
        "                # objective_gradient (out=...) to avoid allocating another array.\n",
        "                model.set_variables(\n",
        "                    np.subtract(\n",
        "                        model.get_variables(), objective_gradient, out=objective_gradient\n",
        "                    )\n",
        "                )\n",
        "\n",
        "            # Report the updated state to the callback, if one is provided.\n",
        "            if self.callback is not None:\n",
        "                state = _update_state(\n",
        "                    STATE_KEYS,\n",
        "                    variables=np.copy(model.get_variables()),\n",
        "                    nit=i_run + 1,\n",
        "                    fun=self.objective(model, data, labels),\n",
        "                    step_size=step_size,\n",
        "                )\n",
        "                # Stop the solver when the callback returns a truthy value.\n",
        "                if self.callback(state):\n",
        "                    return"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The CustomSteepestGradientDescent above, accompanied by some tests and documentation, would\nmake a great addition to the sklvq package. However, it can also be passed directly to a model\nthrough the `solver_type` parameter, as shown below. Some other solvers might require\nfunctionality that the models do not support; this can be added dynamically to the model\ninstances or by extending the required model and creating a custom model class.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "data, labels = load_iris(return_X_y=True)\n",
        "\n",
        "# The custom solver class is passed directly as solver_type. Fixing random_state (from which\n",
        "# the model's random_state_ is derived) makes the solver's per-run shuffling reproducible.\n",
        "model = GLVQ(\n",
        "    solver_type=CustomSteepestGradientDescent,\n",
        "    distance_type=\"squared-euclidean\",\n",
        "    activation_type=\"sigmoid\",\n",
        "    activation_params={\"beta\": 2},\n",
        "    random_state=42,\n",
        ")\n",
        "\n",
        "model.fit(data, labels)\n",
        "\n",
        "# Predict the labels of the training data with the trained model. This reports training\n",
        "# performance only; use a held-out test set to estimate generalization.\n",
        "predicted_labels = model.predict(data)\n",
        "\n",
        "# Print a classification report (sklearn)\n",
        "print(classification_report(labels, predicted_labels))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}