# Source code for themis_ml.meta_estimators
"""Module for Fairness-aware base estimators."""
import numpy as np
from sklearn.base import (
BaseEstimator, ClassifierMixin, MetaEstimatorMixin, clone)
from sklearn.utils.validation import check_array, check_X_y, check_is_fitted
from .checks import check_binary, s_is_needed_on_fit, s_is_needed_on_predict
class FairnessAwareMetaEstimator(
        BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
    """Meta-estimator that chains an optional relabeller with an estimator.

    On ``fit`` the target labels are optionally rewritten by ``relabeller``
    and the (possibly relabelled) targets are then used to fit a clone of
    ``estimator``. Prediction is delegated to the fitted estimator clone.

    Fitted attributes:

    - ``estimator_``: fitted clone of ``estimator``.
    - ``relabeller_``: fitted clone of ``relabeller``, or ``None`` when no
      relabeller was supplied.
    """

    def __init__(self, estimator, relabeller=None):
        """Initialize metaestimator for composing fairness-aware methods.

        :param Estimator estimator: estimator that is cloned and fitted on
            the (optionally relabelled) targets.
        :param Transformer|None relabeller: optional object whose
            ``fit_transform(X, y, s=s)`` output replaces ``y`` before the
            estimator is fitted.
        """
        self.relabeller = relabeller
        self.estimator = estimator

    def fit(self, X, y, s=None):
        """Fit the optional relabeller and then the underlying estimator.

        :param X: training input samples, validated with ``check_X_y``.
        :param y: target labels, validated with ``check_binary``.
        :param s: optional array forwarded to the relabeller and, when
            ``s_is_needed_on_fit`` says so, to the estimator's ``fit``
            (presumably the protected-class indicator -- confirm against
            the ``checks`` module).
        :returns: ``self``, so that calls can be chained.
        :raises ValueError: if ``s`` is given but neither a relabeller nor
            the estimator's ``fit`` makes use of it.
        """
        X, y = check_X_y(X, y)
        y = check_binary(y)
        self.relabeller_ = None
        self.estimator_ = clone(self.estimator)

        # Rewrite the y labels with the relabeller, if one was supplied.
        if self.relabeller is not None:
            self.relabeller_ = clone(self.relabeller)
            y = self.relabeller_.fit_transform(X, y, s=s)

        if s_is_needed_on_fit(self.estimator_, s):
            s = check_binary(np.array(s).astype(int))
            self.estimator_.fit(X, y, s)
        else:
            # Since a relabeller by definition needs `s`, only complain
            # about an unused `s` when there is no relabeller either.
            if self.relabeller_ is None and s is not None:
                raise ValueError(
                    "`s` arg provided but %s fit doesn't accept `s`" %
                    self.estimator_)
            self.estimator_.fit(X, y)
        # scikit-learn convention: fit returns the estimator itself.
        return self

    def _call_estimator(self, method_name, X, s):
        """Call ``method_name`` on the fitted estimator, forwarding ``s``
        only when ``s_is_needed_on_predict`` reports that it is required.

        :raises ValueError: if ``s`` is given but is not forwarded.
        """
        if s_is_needed_on_predict(self.estimator_, s):
            s = check_binary(np.array(s).astype(int))
            return getattr(self.estimator_, method_name)(X, s)
        if s is not None:
            raise ValueError(
                "`s` arg provided but %s predict doesn't accept `s`" %
                self.estimator_)
        return getattr(self.estimator_, method_name)(X)

    def predict(self, X, s=None):
        """Predict with the fitted underlying estimator.

        :param X: input samples, validated with ``check_array``.
        :param s: optional array forwarded to the estimator's ``predict``
            when ``s_is_needed_on_predict`` says it is required.
        :returns: whatever the underlying estimator's ``predict`` returns.
        :raises ValueError: if ``s`` is given but the estimator's
            ``predict`` does not take it.
        """
        check_is_fitted(self, ["estimator_", "relabeller_"])
        X = check_array(X)
        return self._call_estimator("predict", X, s)

    def predict_proba(self, X, s=None):
        """Predict probabilities with the fitted underlying estimator.

        :param X: input samples, validated with ``check_array``.
        :param s: optional array forwarded to the estimator's
            ``predict_proba`` when ``s_is_needed_on_predict`` says it is
            required.
        :returns: whatever the underlying estimator's ``predict_proba``
            returns.
        :raises AttributeError: if the fitted estimator has no
            ``predict_proba`` method.
        :raises ValueError: if ``s`` is given but is not forwarded.
        """
        # Check fitted state first: touching `self.estimator_` on an
        # unfitted instance would otherwise fail with a bare AttributeError
        # instead of scikit-learn's NotFittedError.
        check_is_fitted(self, ["estimator_", "relabeller_"])
        if not hasattr(self.estimator_, "predict_proba"):
            raise AttributeError(
                "%s has no method `predict_proba`" % self.estimator_)
        X = check_array(X)
        return self._call_estimator("predict_proba", X, s)