.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/01_basic_usage/plot_02_gmlvq.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_01_basic_usage_plot_02_gmlvq.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_01_basic_usage_plot_02_gmlvq.py:


==============================
Generalized Matrix LVQ (GMLVQ)
==============================
Example of how to use GMLVQ `[1]`_ on the classic iris dataset.

.. GENERATED FROM PYTHON SOURCE LINES 9-29

.. code-block:: Python

    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.metrics import classification_report
    from sklearn.preprocessing import StandardScaler

    from sklvq import GMLVQ

    matplotlib.rc("xtick", labelsize="small")
    matplotlib.rc("ytick", labelsize="small")

    # The dataset also contains target_names and feature_names, which we use for the plots.
    iris = load_iris()

    data = iris.data
    labels = iris.target

    # Strip the trailing " (cm)" unit from each feature name.
    feature_names = [name[:-5] for name in iris.feature_names]

.. GENERATED FROM PYTHON SOURCE LINES 30-35

Fitting the Model
.................
Scale the data and create a GMLVQ object with, e.g., a distance function, an activation
function, and a solver. See the API reference in the documentation for the defaults and
other possible parameters.

.. GENERATED FROM PYTHON SOURCE LINES 35-52

.. code-block:: Python

    # scikit-learn's StandardScaler performs the z-transform
    scaler = StandardScaler()

    # Compute (fit) and apply (transform) the z-transform
    data = scaler.fit_transform(data)

    # Create the model object that will be fitted to the data.
    model = GMLVQ(
        distance_type="adaptive-squared-euclidean",
        activation_type="swish",
        activation_params={"beta": 2},
        solver_type="waypoint-gradient-descent",
        solver_params={"max_runs": 10, "k": 3, "step_size": np.array([0.1, 0.05])},
        random_state=1428,
    )


.. GENERATED FROM PYTHON SOURCE LINES 53-56

The next step is to fit the GMLVQ object to the data and use the predict method to make
predictions. Note that this example evaluates the model only on the training data and
therefore says nothing about how well the fitted model generalizes.

.. GENERATED FROM PYTHON SOURCE LINES 56-66

.. code-block:: Python

    # Train the model using the scaled data and true labels
    model.fit(data, labels)

    # Predict the labels using the trained model
    predicted_labels = model.predict(data)

    # To get a sense of the training performance, print the classification report.
    print(classification_report(labels, predicted_labels))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

                  precision    recall  f1-score   support

               0       1.00      1.00      1.00        50
               1       0.98      0.96      0.97        50
               2       0.96      0.98      0.97        50

        accuracy                           0.98       150
       macro avg       0.98      0.98      0.98       150
    weighted avg       0.98      0.98      0.98       150

.. GENERATED FROM PYTHON SOURCE LINES 67-72

Extracting the Relevance Matrix
...............................
In addition to the prototypes (see the GLVQ example), GMLVQ learns a matrix `lambda_`,
which indicates which features are most relevant for the classification.

.. GENERATED FROM PYTHON SOURCE LINES 72-84

.. code-block:: Python

    # The relevance matrix is available after fitting the model.
    relevance_matrix = model.lambda_

    # Plot the diagonal of the relevance matrix
    fig, ax = plt.subplots()
    fig.suptitle("Relevance Matrix Diagonal")
    ax.bar(feature_names, np.diagonal(relevance_matrix))
    ax.set_ylabel("Weight")
    ax.grid(False)

.. image-sg:: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_001.png
   :alt: Relevance Matrix Diagonal
   :srcset: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 85-91

Note that the relevance diagonal adds up to one.
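
Why the diagonal sums to one can be illustrated with plain NumPy. The sketch below
assumes, following `[1]`_, that the relevance matrix is the product of a matrix with its
own transpose and that this matrix is normalized so that the trace of the product equals
one. The name ``omega`` and its random values are hypothetical stand-ins and are not
taken from the fitted model.

.. code-block:: Python

    import numpy as np

    # Hypothetical stand-in for a learned transformation matrix (4 features).
    rng = np.random.default_rng(0)
    omega = rng.normal(size=(4, 4))

    # Dividing by the Frobenius norm makes the sum of squared entries equal to one,
    # which is exactly trace(omega.T @ omega).
    omega /= np.sqrt(np.sum(omega**2))

    # Relevance matrix under the stated assumption.
    relevance = omega.T @ omega
    diagonal = np.diagonal(relevance)

    # The diagonal entries are non-negative and add up to one (up to rounding).
    print(diagonal, diagonal.sum())

Because every diagonal entry is a sum of squares, the entries can be read as non-negative
shares of a total relevance of one.
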
The most relevant features for distinguishing between the classes in the iris dataset
seem to be (in decreasing order) the petal length, petal width, sepal length, and sepal
width. Although not very interesting for the iris dataset, one could use this information
to select only the most relevant features for the classification, thereby reducing the
dimensionality of the problem.

.. GENERATED FROM PYTHON SOURCE LINES 93-97

Transforming the data
.....................
In addition to making predictions, GMLVQ can be used to transform the data using the
eigenvectors of the relevance matrix.

.. GENERATED FROM PYTHON SOURCE LINES 97-132

.. code-block:: Python

    # Transform the data (scaled by the square root of the eigenvalues, `scale=True`)
    transformed_data = model.transform(data, scale=True)

    x_d = transformed_data[:, 0]
    y_d = transformed_data[:, 1]

    # Transform the model, i.e., the prototypes (again with `scale=True`). The scaled
    # eigenvectors are stored in the attribute called `omega_hat_`.
    transformed_model = model.transform(model.prototypes_, scale=True)

    x_m = transformed_model[:, 0]
    y_m = transformed_model[:, 1]

    # Plot the projected data per class, then the projected prototypes on top.
    fig, ax = plt.subplots()
    fig.suptitle("Discriminative projection Iris data and GMLVQ prototypes")
    colors = ["blue", "red", "green"]
    for i, cls in enumerate(model.classes_):
        ii = cls == labels
        ax.scatter(
            x_d[ii],
            y_d[ii],
            c=colors[i],
            s=100,
            alpha=0.7,
            edgecolors="white",
            label=iris.target_names[model.prototypes_labels_[i]],
        )
    ax.scatter(x_m, y_m, c=colors, s=180, alpha=0.8, edgecolors="black", linewidth=2.0)
    ax.set_xlabel("First eigenvector")
    ax.set_ylabel("Second eigenvector")
    ax.legend()
    ax.grid(True)

.. image-sg:: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_002.png
   :alt: Discriminative projection Iris data and GMLVQ prototypes
   :srcset: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 133-137

The transformed data and prototypes can be used to visualize the problem in a lower
dimension, which is also the space in which the model computes distances. The axes are
the most discriminative directions (combinations of features). Hence, inspecting the
eigenvalues and eigenvectors (axes) themselves can be interesting.

.. GENERATED FROM PYTHON SOURCE LINES 137-158

.. code-block:: Python

    # Plot the eigenvalues of the eigenvectors of the relevance matrix.
    fig, ax = plt.subplots()
    fig.suptitle("Eigenvalues")
    ax.bar(range(len(model.eigenvalues_)), model.eigenvalues_)
    ax.set_ylabel("Weight")
    ax.grid(False)

    # Plot the first two eigenvectors of the relevance matrix, which are stored in the
    # attribute called `eigenvectors_`.
    fig, ax = plt.subplots()
    fig.suptitle("First Eigenvector")
    ax.bar(feature_names, model.eigenvectors_[0, :])
    ax.set_ylabel("Weight")
    ax.grid(False)

    fig, ax = plt.subplots()
    fig.suptitle("Second Eigenvector")
    ax.bar(feature_names, model.eigenvectors_[1, :])
    ax.set_ylabel("Weight")
    ax.grid(False)

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_003.png
         :alt: Eigenvalues
         :srcset: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_004.png
         :alt: First Eigenvector
         :srcset: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_005.png
         :alt: Second Eigenvector
         :srcset: /auto_examples/01_basic_usage/images/sphx_glr_plot_02_gmlvq_005.png
         :class: sphx-glr-multi-img

.. GENERATED FROM PYTHON SOURCE LINES 159-163

In the plots of the eigenvalues and eigenvectors we see an effect similar to the one
visible in the diagonal of `lambda_`.
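
The link between the eigendecomposition and the diagonal can be checked numerically. For
any symmetric matrix, each diagonal entry equals the sum over all eigenvectors of the
eigenvalue times the squared component of that eigenvector for the feature in question.
The sketch below uses a random, normalized stand-in matrix and NumPy's ``eigh`` instead
of the attributes of the fitted model; the names ``omega`` and ``relevance`` are
hypothetical.

.. code-block:: Python

    import numpy as np

    # Hypothetical stand-in for a relevance matrix with unit trace (4 features).
    rng = np.random.default_rng(0)
    omega = rng.normal(size=(4, 4))
    omega /= np.linalg.norm(omega)  # Frobenius norm for 2-D input
    relevance = omega.T @ omega

    # eigh returns the eigenvalues in ascending order and the eigenvectors as columns.
    eigenvalues, eigenvectors = np.linalg.eigh(relevance)

    # Entry i is sum_k eigenvalues[k] * eigenvectors[i, k] ** 2.
    reconstructed_diagonal = (eigenvectors**2) @ eigenvalues

    print(np.allclose(reconstructed_diagonal, np.diagonal(relevance)))

A feature that carries large components in the eigenvectors with large eigenvalues
therefore also receives a large diagonal entry.
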
The two leading (most relevant or discriminating) eigenvectors mostly use the petal
length and petal width. The diagonal of the relevance matrix can therefore be considered
a summary of the relevances of the features.

.. GENERATED FROM PYTHON SOURCE LINES 165-169

References
..........
_`[1]` Schneider, P., Biehl, M., & Hammer, B. (2009). "Adaptive Relevance Matrices in
Learning Vector Quantization." Neural Computation, 21(12), 3532–3561.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.304 seconds)

.. _sphx_glr_download_auto_examples_01_basic_usage_plot_02_gmlvq.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_02_gmlvq.ipynb <plot_02_gmlvq.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_02_gmlvq.py <plot_02_gmlvq.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_02_gmlvq.zip <plot_02_gmlvq.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

     Gallery generated by Sphinx-Gallery