Source code for sklvq.solvers._base

from abc import ABC, abstractmethod
from typing import List, Any
from typing import TYPE_CHECKING

import numpy as np
import scipy.optimize as spo

from ..objectives import ObjectiveBaseClass

if TYPE_CHECKING:
    from ..models import LVQBaseClass


[docs]class SolverBaseClass(ABC): """Solver base class Abstract class for implementing solvers. Provides abstract methods with expected calls signatures. See also -------- SteepestGradientDescent, WaypointGradientDescent, AdaptiveMomentEstimation, BroydenFletcherGoldfarbShanno, LimitedMemoryBfgs """ def __init__(self, objective: ObjectiveBaseClass): self.objective = objective
[docs] @abstractmethod def solve( self, data: np.ndarray, labels: np.ndarray, model: "LVQBaseClass", ) -> None: """ Solve updates the model it is given and does not return anything. Parameters ---------- data : ndarray of shape (number of observations, number of dimensions) The data. labels : ndarray of size (number of observations) The labels of the samples in the data. model : LVQBaseClass The initial model that will also hold the final result """ raise NotImplementedError("You should implement this!")
class ScipyBaseSolver(SolverBaseClass): """ScipyBaseSolver Class to wrap around scipy solvers. See also -------- BroydenFletcherGoldfarbShanno, LimitedMemoryBfgs """ def __init__(self, objective, method: str = "L-BFGS-B", **kwargs): self.method = method self.params = kwargs super().__init__(objective) def _objective_wrapper(self, variables, model, data, labels): model.set_variables(variables) return self.objective(model, data, labels) def _objective_gradient_wrapper(self, variables, model, data, labels): model.set_variables(variables) return self.objective.gradient(model, data, labels) def solve( self, data: np.ndarray, labels: np.ndarray, model: "LVQBaseClass", ): """ Solve updates the model it is given and does not return anything. Parameters ---------- data : ndarray of shape (number of observations, number of dimensions) The data. labels : ndarray of size (number of observations) The labels of the samples in the data. model : LVQBaseClass The initial model that will also hold the final result """ params = {"jac": self._objective_gradient_wrapper} if self.params is not None: params.update(self.params) result = spo.minimize( self._objective_wrapper, model.get_variables(), method=self.method, args=(model, data, labels), **params ) # Update model model.set_variables(result.x) def _update_state(state_keys: List[str], **kwargs: Any) -> dict: # Helper function that can be used to update state dict. The state_keys is a list of strings # indicating the keys the dictionary should hold. If not provided in the kwargs they are set # to None. state = dict.fromkeys(state_keys) state.update(**kwargs) return state