from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import scipy.optimize as spo
if TYPE_CHECKING:
import numpy as np
from sklvq.models._base import LVQBaseClass
from sklvq.objectives._base import ObjectiveBaseClass
# Message passed to NotImplementedError by abstract method bodies that subclasses must override.
ABC_METHOD_NOT_IMPL_MSG = "You should implement this!"
class SolverBaseClass(ABC):
    """Solver base class

    Abstract class for implementing solvers. Provides abstract methods with expected call
    signatures.

    Parameters
    ----------
    objective : ObjectiveBaseClass
        The objective used by the solver. It is stored as-is on ``self.objective``.

    See also
    --------
    SteepestGradientDescent, WaypointGradientDescent, AdaptiveMomentEstimation,
    BroydenFletcherGoldfarbShanno, LimitedMemoryBfgs
    """

    def __init__(self, objective: ObjectiveBaseClass):
        self.objective = objective

    @abstractmethod
    def solve(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        model: LVQBaseClass,
    ) -> None:
        """
        Solve updates the model it is given and does not return anything.

        Parameters
        ----------
        data : ndarray of shape (number of observations, number of dimensions)
            The data.
        labels : ndarray of size (number of observations)
            The labels of the samples in the data.
        model : LVQBaseClass
            The initial model that will also hold the final result

        Raises
        ------
        NotImplementedError
            If this base implementation is reached (e.g. through ``super().solve``);
            concrete solvers must override it.
        """
        raise NotImplementedError(ABC_METHOD_NOT_IMPL_MSG)
class ScipyBaseSolver(SolverBaseClass):
    """ScipyBaseSolver

    Class to wrap around scipy solvers.

    Parameters
    ----------
    objective : ObjectiveBaseClass
        Objective that is minimized. It is evaluated as ``objective(model, data, labels)``
        and its gradient as ``objective.gradient(model, data, labels)``.
    method : str, default="L-BFGS-B"
        Name of the method passed to ``scipy.optimize.minimize``.
    **kwargs
        Extra keyword arguments forwarded to ``scipy.optimize.minimize``. They are applied
        after this class's defaults, so e.g. passing ``jac`` replaces the gradient wrapper.

    See also
    --------
    BroydenFletcherGoldfarbShanno, LimitedMemoryBfgs
    """

    def __init__(
        self,
        objective: ObjectiveBaseClass,
        method: str = "L-BFGS-B",
        **kwargs: Any,
    ):
        self.method = method
        self.params = kwargs
        super().__init__(objective)

    def _objective_wrapper(self, variables, model, data, labels):
        # Adapts the objective to scipy's ``fun(x, *args)`` calling convention: load the
        # candidate variables into the model, then evaluate the objective on it.
        model.set_variables(variables)
        return self.objective(model, data, labels)

    def _objective_gradient_wrapper(self, variables, model, data, labels):
        # Same adaptation as ``_objective_wrapper`` but for the gradient (used as ``jac``).
        model.set_variables(variables)
        return self.objective.gradient(model, data, labels)

    def solve(
        self,
        data: np.ndarray,
        labels: np.ndarray,
        model: LVQBaseClass,
    ) -> None:
        """
        Solve updates the model it is given and does not return anything.

        Parameters
        ----------
        data : ndarray of shape (number of observations, number of dimensions)
            The data.
        labels : ndarray of size (number of observations)
            The labels of the samples in the data.
        model : LVQBaseClass
            The initial model. Its current variables are the starting point of the
            optimization and are overwritten with the final result.
        """
        # The objective's gradient is the default jacobian; user-supplied kwargs win.
        params = {"jac": self._objective_gradient_wrapper}
        if self.params is not None:
            params.update(self.params)

        result = spo.minimize(
            self._objective_wrapper,
            model.get_variables(),
            method=self.method,
            args=(model, data, labels),
            **params,
        )

        # The wrappers leave the model holding whatever point scipy evaluated last, which
        # is not necessarily the returned solution, so write the final result back.
        model.set_variables(result.x)
def _update_state(state_keys: list[str], **kwargs: Any) -> dict:
# Helper function that can be used to update state dict. The state_keys is a list of strings
# indicating the keys the dictionary should hold. If not provided in the kwargs they are set
# to None.
state = dict.fromkeys(state_keys)
state.update(**kwargs)
return state