# Source code for the sklvq.solvers._base module.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import scipy.optimize as spo

if TYPE_CHECKING:
    import numpy as np

    from sklvq.models._base import LVQBaseClass
    from sklvq.objectives._base import ObjectiveBaseClass


# Message passed to NotImplementedError by the abstract method(s) in this
# module when a subclass has not provided its own implementation.
ABC_METHOD_NOT_IMPL_MSG = "You should implement this!"


class SolverBaseClass(ABC):
    """Solver base class

    Abstract class for implementing solvers. It stores the objective the
    solver works with and declares the abstract :meth:`solve` method whose
    call signature concrete solvers must follow.

    Parameters
    ----------
    objective : ObjectiveBaseClass
        The objective used by the solver, stored as ``self.objective``.

    See also
    --------
    SteepestGradientDescent, WaypointGradientDescent, AdaptiveMomentEstimation,
    BroydenFletcherGoldfarbShanno, LimitedMemoryBfgs
    """

    def __init__(self, objective: ObjectiveBaseClass) -> None:
        self.objective = objective

    @abstractmethod
    def solve(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        model: LVQBaseClass,
    ) -> None:
        """Update ``model`` in place using ``data`` and ``labels``.

        Nothing is returned; the final result is held by ``model`` itself.
        Subclasses must override this method.

        Parameters
        ----------
        data : ndarray of shape (number of observations, number of dimensions)
            The data.
        labels : ndarray of size (number of observations)
            The labels of the samples in the data.
        model : LVQBaseClass
            The initial model, which will also hold the final result.

        Raises
        ------
        NotImplementedError
            Always, when this base implementation is invoked (e.g. via ``super()``).
        """
        message = ABC_METHOD_NOT_IMPL_MSG
        raise NotImplementedError(message)
class ScipyBaseSolver(SolverBaseClass):
    """ScipyBaseSolver

    Class to wrap around scipy solvers. :meth:`solve` optimizes the model's
    variables with :func:`scipy.optimize.minimize`. The objective's value is
    used as the function to minimize, and its ``gradient`` method is the
    default ``jac``.

    Parameters
    ----------
    objective : ObjectiveBaseClass
        Called as ``objective(model, data, labels)`` for the objective value and
        as ``objective.gradient(model, data, labels)`` for its gradient.
    method : str, default="L-BFGS-B"
        Forwarded as ``method`` to :func:`scipy.optimize.minimize`.
    **kwargs
        Additional keyword arguments forwarded to
        :func:`scipy.optimize.minimize`. They are applied after the default
        ``jac``, so passing ``jac`` here overrides it.

    See also
    --------
    BroydenFletcherGoldfarbShanno, LimitedMemoryBfgs
    """

    def __init__(self, objective, method: str = "L-BFGS-B", **kwargs) -> None:
        self.method = method
        # Extra keyword arguments for scipy.optimize.minimize (always a dict here).
        self.params = kwargs
        super().__init__(objective)

    def _objective_wrapper(self, variables, model, data, labels):
        """Load ``variables`` into ``model`` and return the objective value."""
        model.set_variables(variables)
        return self.objective(model, data, labels)

    def _objective_gradient_wrapper(self, variables, model, data, labels):
        """Load ``variables`` into ``model`` and return the objective gradient."""
        model.set_variables(variables)
        return self.objective.gradient(model, data, labels)

    def solve(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        model: LVQBaseClass,
    ) -> None:
        """
        Solve updates the model it is given and does not return anything.

        Parameters
        ----------
        data : ndarray of shape (number of observations, number of dimensions)
            The data.
        labels : ndarray of size (number of observations)
            The labels of the samples in the data.
        model : LVQBaseClass
            The initial model that will also hold the final result.
        """
        # Default to the objective's gradient; user-supplied keyword arguments
        # are merged last so they can override it. ``self.params`` is a dict
        # after __init__, but a ``None`` assigned later is still tolerated.
        params = {"jac": self._objective_gradient_wrapper}
        params.update(self.params or {})

        result = spo.minimize(
            self._objective_wrapper,
            model.get_variables(),
            method=self.method,
            args=(model, data, labels),
            **params,
        )

        # Every evaluation overwrites the model's variables, so the model may be
        # left at the last point scipy evaluated rather than at the returned
        # solution; load ``result.x`` explicitly.
        model.set_variables(result.x)


def _update_state(state_keys: list[str], **kwargs: Any) -> dict[str, Any]:
    """Build a state dict holding an entry for every key in ``state_keys``.

    Keys in ``state_keys`` that are not given in ``kwargs`` are set to
    ``None``. Keyword arguments whose names are not in ``state_keys`` are
    added to the dict as well.

    Parameters
    ----------
    state_keys : list of str
        The keys the returned dictionary should hold.
    **kwargs
        Values to store in the state, keyed by argument name.

    Returns
    -------
    dict
        A new dictionary with the requested keys and given values.
    """
    state = dict.fromkeys(state_keys)
    state.update(kwargs)
    return state