{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook output. Recent Jupyter/IPython versions do\n# this by default, so the magic is harmless but may be redundant.\n%matplotlib inline"
      ]
    },
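    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Receiver Operating Characteristic Curve\n\nExample of plotting the receiver operating characteristic (ROC) curve of a GMLVQ model trained on\nthe breast cancer data set (a binary classification task), and of marking where the model's\nnearest prototype classifier lies on that curve.\n"
      ]
    },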
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom sklearn.datasets import load_breast_cancer\nfrom sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve\nfrom sklearn.preprocessing import StandardScaler\n\nfrom sklvq import GMLVQ\n\n# Use smaller tick labels on both axes for every figure in this notebook.\nmatplotlib.rcParams.update({\"xtick.labelsize\": \"small\", \"ytick.labelsize\": \"small\"})\n\n# Load the breast cancer data set: `data` holds the feature matrix, `labels` the class labels.\ndata, labels = load_breast_cancer(return_X_y=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create a GMLVQ object and pass it a distance function, an activation function and a solver. See\nthe API reference in the documentation for their defaults.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Configure the GMLVQ model. See the sklvq API reference for the meaning and defaults of each\n# option.\n# NOTE(review): `step_size` holds two values; presumably one per group of learned parameters\n# (prototypes and relevance matrix) - confirm against the sklvq documentation.\nmodel = GMLVQ(\n    random_state=31415,  # fixed seed for reproducible results\n    distance_type=\"adaptive-squared-euclidean\",\n    activation_type=\"swish\",\n    activation_params={\"beta\": 2},\n    solver_type=\"waypoint-gradient-descent\",\n    solver_params={\n        \"max_runs\": 10,\n        \"k\": 3,\n        \"step_size\": np.array([0.1, 0.05]),\n    },\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Standardize the data, fit the GMLVQ object to it and plot the ROC curve.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Object to perform the z-transform\nscaler = StandardScaler()\n\n# Compute (fit) and apply (transform) the z-transform. The result gets a new name so the raw\n# `data` stays untouched and re-running this cell does not re-fit the scaler on scaled data.\ndata_scaled = scaler.fit_transform(data)\n\n# Train the model using the scaled data and the true labels\nmodel.fit(data_scaled, labels)\n\n# Get the decision values (which are used in predict) instead of the labels. The values are with\n# respect to the \"greater\" class, i.e., index 1.\nlabel_score = model.decision_function(data_scaled)\n\n# roc_curve expects the y_score to be with respect to the positive class. Note that the curve is\n# computed on the training data, so it is an optimistic estimate of performance on unseen data.\nfpr, tpr, thresholds = roc_curve(\n    y_true=labels, y_score=label_score, pos_label=1, drop_intermediate=True\n)\nroc_auc = roc_auc_score(y_true=labels, y_score=label_score)\n\n# Sometimes it is good to know where the nearest prototype classifier (NPC) is on this curve. This\n# can be computed using the confusion matrix function from sklearn.\ntn, fp, fn, tp = confusion_matrix(y_true=labels, y_pred=model.predict(data_scaled)).ravel()\n\n# The tpr and fpr of the NPC are then given by:\nnpc_tpr = tp / (tp + fn)\nnpc_fpr = fp / (fp + tn)\n\nfig, ax = plt.subplots()\nfig.suptitle(\"Receiver operating characteristic \")\n# Plot the ROC curve\nax.plot(fpr, tpr, color=\"darkorange\", lw=2, label=f\"ROC AUC = {roc_auc:.3f}\")\n# Plot the chance line, i.e., the ROC curve of a classifier that guesses at random\nax.plot([0, 1], [0, 1], color=\"navy\", lw=2, linestyle=\"--\", label=\"Chance\")\n# Plot the NPC operating point as a lone marker (no connecting line), labelled so it appears in\n# the legend\nax.plot(\n    npc_fpr, npc_tpr, color=\"green\", marker=\"o\", markersize=12, linestyle=\"none\", label=\"NPC\"\n)\nax.set_xlabel(\"False Positive Rate\")\nax.set_ylabel(\"True Positive Rate\")\nax.legend(loc=\"lower right\")\nax.grid(False)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}