.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/01_basic_usage/plot_01_glvq.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_01_basic_usage_plot_01_glvq.py:


.. _GLVQ:

======================
Generalized LVQ (GLVQ)
======================

This example shows how to fit the GLVQ `[1]`_ algorithm to the classic iris dataset.

.. GENERATED FROM PYTHON SOURCE LINES 11-30

.. code-block:: Python


    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.metrics import classification_report
    from sklearn.preprocessing import StandardScaler

    from sklvq import GLVQ

    matplotlib.rc("xtick", labelsize="small")
    matplotlib.rc("ytick", labelsize="small")

    # The returned Bunch also contains target_names and feature_names,
    # which are used for the plots below.
    iris = load_iris()

    data = iris.data
    labels = iris.target

.. GENERATED FROM PYTHON SOURCE LINES 31-36

Fitting the Model
.................

Scale the data and create a GLVQ object with, e.g., a custom distance function, activation function
and solver. See the API reference in the documentation for the defaults and other possible parameters.

.. GENERATED FROM PYTHON SOURCE LINES 36-52

.. code-block:: Python


    # scikit-learn's StandardScaler performs the z-transform
    # (zero mean and unit variance per feature).
    scaler = StandardScaler()

    # Compute (fit) and apply (transform) the z-transform.
    data = scaler.fit_transform(data)

    # Create the model object that will be fitted to the data.
    model = GLVQ(
        distance_type="squared-euclidean",
        activation_type="swish",
        activation_params={"beta": 2},
        solver_type="steepest-gradient-descent",
        solver_params={"max_runs": 20, "step_size": 0.1},
    )
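
To give an idea of what these parameter choices compute, the following is a minimal NumPy sketch of the
GLVQ cost of Sato and Yamada `[1]`_ and of nearest-prototype prediction. For a sample ``x``, ``d_plus`` is
the squared Euclidean distance to the closest prototype with the same label and ``d_minus`` the distance to
the closest prototype with a different label. The relative distance difference
``mu = (d_plus - d_minus) / (d_plus + d_minus)`` lies in ``[-1, 1]`` and is negative when ``x`` is
classified correctly; the cost is ``mu`` passed through the activation function. The swish activation is
assumed here to have the common form ``mu / (1 + exp(-beta * mu))``, which may differ in detail from
sklvq's implementation, and all function names are hypothetical helpers, not part of the sklvq API.

.. code-block:: Python

    import numpy as np


    def sq_euclidean(x, prototypes):
        """Squared Euclidean distance from sample x to every prototype (one per row)."""
        diff = prototypes - x
        return np.einsum("ij,ij->i", diff, diff)


    def swish(mu, beta=2.0):
        """Assumed swish form: mu * sigmoid(beta * mu)."""
        return mu / (1.0 + np.exp(-beta * mu))


    def glvq_cost(x, y, prototypes, proto_labels, beta=2.0):
        """Hypothetical helper: GLVQ cost of a single labelled sample (x, y)."""
        d = sq_euclidean(x, prototypes)
        d_plus = d[proto_labels == y].min()   # closest prototype with the same label
        d_minus = d[proto_labels != y].min()  # closest prototype with another label
        mu = (d_plus - d_minus) / (d_plus + d_minus)
        return swish(mu, beta)


    def predict(X, prototypes, proto_labels):
        """Nearest-prototype rule: each sample gets the label of its closest prototype."""
        return np.array([proto_labels[np.argmin(sq_euclidean(x, prototypes))] for x in X])


    # Toy example with one prototype per class in 2-D (made-up numbers).
    prototypes = np.array([[0.0, 0.0], [2.0, 2.0]])
    proto_labels = np.array([0, 1])
    X = np.array([[0.1, -0.2], [1.8, 2.1]])

    print(predict(X, prototypes, proto_labels))          # [0 1]
    print(glvq_cost(X[0], 0, prototypes, proto_labels))  # negative: correctly classified

With these toy prototypes each sample is assigned the label of its nearest prototype, and the first sample's
cost is negative because it lies much closer to the prototype of its own class than to the other one.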
.. GENERATED FROM PYTHON SOURCE LINES 53-56

The next step is to fit the GLVQ object to the data and to use its ``predict`` method to make predictions.
Note that this example evaluates the model on the training data only and therefore says nothing about how
well the fitted model generalizes to unseen data.

.. GENERATED FROM PYTHON SOURCE LINES 56-66

.. code-block:: Python


    # Train the model using the iris dataset.
    model.fit(data, labels)

    # Predict the labels using the trained model.
    predicted_labels = model.predict(data)

    # To get a sense of the training performance, print the classification report.
    print(classification_report(labels, predicted_labels))



.. rst-class:: sphx-glr-script-out

.. code-block:: none

                  precision    recall  f1-score   support

               0       1.00      1.00      1.00        50
               1       0.96      0.94      0.95        50
               2       0.94      0.96      0.95        50

        accuracy                           0.97       150
       macro avg       0.97      0.97      0.97       150
    weighted avg       0.97      0.97      0.97       150

.. GENERATED FROM PYTHON SOURCE LINES 67-73

Extracting the Prototypes
.........................

The GLVQ model learns prototypes that represent the different classes. They are stored in the fitted
model's ``prototypes_`` attribute (with their class labels in ``prototypes_labels_``) and can, e.g., be
plotted for visual inspection. Note that the prototypes live in the z-score space of the scaled data, so
they are transformed back to the original feature space (centimetres) before they are plotted.

.. GENERATED FROM PYTHON SOURCE LINES 73-100

.. code-block:: Python


    colors = ["blue", "red", "green"]
    num_prototypes = model.prototypes_.shape[0]
    num_features = model.prototypes_.shape[1]

    fig, ax = plt.subplots(num_prototypes, 1)
    fig.suptitle("Prototype of each class")

    for i, prototype in enumerate(model.prototypes_):
        # Reverse the z-transform to go back to the original feature space.
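        # Sketch of what the next line computes, assuming StandardScaler's
        # defaults (with_mean=True, with_std=True): inverse_transform maps each
        # feature back via x = z * scaler.scale_ + scaler.mean_, so the bar
        # heights are in centimetres. np.atleast_2d is used because the scaler
        # expects a 2-D array of shape (n_samples, n_features).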
        prototype = scaler.inverse_transform(np.atleast_2d(prototype)).squeeze()

        ax[i].bar(
            range(num_features),
            prototype,
            color=colors[i],
            label=iris.target_names[model.prototypes_labels_[i]],
        )
        ax[i].set_xticks(range(num_features))
        if i == (num_prototypes - 1):
            # Only the bottom subplot gets tick labels; strip the " (cm)" suffix.
            ax[i].set_xticklabels([name[:-5] for name in iris.feature_names])
        else:
            # Hide the x-axis ticks and labels on all other subplots.
            ax[i].set_xticklabels([], visible=False)
            ax[i].tick_params(
                axis="x", which="both", bottom=False, top=False, labelbottom=False
            )
        ax[i].set_ylabel("cm")
        ax[i].legend()


.. image-sg:: /auto_examples/01_basic_usage/images/sphx_glr_plot_01_glvq_001.png
   :alt: Prototype of each class
   :srcset: /auto_examples/01_basic_usage/images/sphx_glr_plot_01_glvq_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 101-105

References
..........

_`[1]` Sato, A., and Yamada, K. (1996). “Generalized Learning Vector Quantization.” Advances in Neural
Information Processing Systems, 423–429.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.545 seconds)