Skip to content

Commit

Permalink
fix: Fixed incorrect pretrained status check in DualPredictor. (#52)
Browse files Browse the repository at this point in the history
* fix: incorrect pretrained status check in DualPredictor.
  • Loading branch information
M-Mouhcine authored Mar 11, 2024
1 parent 6e0a8f8 commit 3b9e4e5
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 9 deletions.
4 changes: 2 additions & 2 deletions deel/puncc/api/conformalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ def fit(
predictor.fit(X_fit, y_fit, **kwargs) # Fit K-fold predictor

# Make sure that predictor is already trained if train arg is False
elif self.train is False and predictor.is_trained is False:
elif self.train is False and predictor.get_is_trained() is False:
raise RuntimeError(
"'train' argument is set to 'False' but model is not pre-trained"
"'train' argument is set to 'False' but model(s) not pre-trained."
)

else: # Skipping training
Expand Down
18 changes: 14 additions & 4 deletions deel/puncc/api/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"""
This module provides standard wrappings for ML models.
"""
import pkgutil
import importlib
from copy import deepcopy
from typing import Any
from typing import Iterable
Expand All @@ -37,7 +37,7 @@
from deel.puncc.api.utils import dual_predictor_check
from deel.puncc.api.utils import supported_types_check

if pkgutil.find_loader("tensorflow") is not None:
if importlib.util.find_spec("tensorflow") is not None:
import tensorflow as tf


Expand Down Expand Up @@ -133,6 +133,10 @@ def __init__(self, model: Any, is_trained: bool = False, **compile_kwargs):
if self.compile_kwargs:
_ = self.model.compile(**self.compile_kwargs)

def get_is_trained(self) -> bool:
    """Tell whether the wrapped model is flagged as pre-trained.

    The stored ``is_trained`` attribute is coerced to a genuine ``bool`` so
    that callers comparing the result by identity (``... is False``) behave
    correctly even when the flag was set with a falsy non-bool value.

    :returns: True if ``self.is_trained`` is truthy, False otherwise.
    :rtype: bool
    """
    return bool(self.is_trained)

def fit(self, X: Iterable, y: Optional[Iterable] = None, **kwargs) -> None:
"""Fit model to the training data.
Expand Down Expand Up @@ -190,7 +194,7 @@ def copy(self):
if (
"tensorflow" in model_type_str
or "keras" in model_type_str
and pkgutil.find_loader("tensorflow") is not None
and importlib.util.find_spec("tensorflow") is not None
):
# pylint: disable=E1101
model = tf.keras.models.clone_model(self.model)
Expand Down Expand Up @@ -319,6 +323,12 @@ def __init__(
if len(self.compile_args[1].keys()) != 0 and is_trained[1] is False:
_ = self.models[1].compile(**self.compile_args[1])

def get_is_trained(self) -> bool:
    """Tell whether both underlying models are flagged as pre-trained.

    ``self.is_trained`` is read at indices 0 and 1 (one flag per model).
    The conjunction is coerced to a genuine ``bool``: a bare ``a and b``
    yields one of its operands, which would defeat callers that compare the
    result by identity (``... is False``) when a flag is a falsy non-bool.

    :returns: True only when both flags are truthy, False otherwise.
    :rtype: bool
    """
    return bool(self.is_trained[0] and self.is_trained[1])

def fit(
self, X: Iterable, y: Iterable, dictargs: List[dict] = [{}, {}]
) -> None:
Expand Down Expand Up @@ -392,7 +402,7 @@ def copy(self):
try:
model_copy = deepcopy(model)
except Exception as e_outer:
if pkgutil.find_loader("tensorflow") is not None:
if importlib.util.find_spec("tensorflow") is not None:
try:
# pylint: disable=E1101
model_copy = tf.keras.models.clone_model(model)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
author = "Mouhcine Mendil, Luca Mossina and Joseba Dalmau"

# The full version, including alpha/beta/rc tags
release = "0.7.4"
release = "0.7.6"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ joblib
matplotlib
numpy
pandas
scikit-learn
scikit-learn~=1.3.0
tqdm
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

setuptools.setup(
name="puncc",
version="0.7.5",
version="0.7.6",
author=", ".join(["Mouhcine Mendil", "Luca Mossina", "Joseba Dalmau"]),
author_email=", ".join(
[
Expand Down
71 changes: 71 additions & 0 deletions tests/api/test_conformalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from deel.puncc.api.calibration import BaseCalibrator
from deel.puncc.api.conformalization import ConformalPredictor
from deel.puncc.api.prediction import BasePredictor
from deel.puncc.api.prediction import DualPredictor
from deel.puncc.api.splitting import KFoldSplitter
from deel.puncc.api.splitting import RandomSplitter

Expand Down Expand Up @@ -72,6 +73,12 @@ def setUp(self):
pred_set_func=prediction_sets.constant_interval,
)

# Definition of a dual calibrator
self.dual_calibrator = BaseCalibrator(
nonconf_score_func=nonconformity_scores.cqr_score,
pred_set_func=prediction_sets.cqr_interval,
)

# Random splitter
self.random_splitter = RandomSplitter(ratio=0.2)

Expand Down Expand Up @@ -166,6 +173,70 @@ def test_pretrained_predictor(self):
)
conformal_predictor.fit(self.X_calib, self.y_calib)

def test_pretrained_dualpredictor(self):
    """Exercise the pre-trained check of ConformalPredictor with DualPredictor.

    A dual predictor whose two models are both flagged as trained must be
    accepted when ``train=False``; one with a model flagged as untrained
    must trigger a RuntimeError, with and without a splitter.
    """
    # Two fitted linear regressors shared by both dual predictors
    first_model = linear_model.LinearRegression()
    first_model.fit(self.X_fit, self.y_fit)
    second_model = linear_model.LinearRegression()
    second_model.fit(self.X_fit, self.y_fit)

    fully_flagged = DualPredictor(
        [first_model, second_model], is_trained=[True, True]
    )
    partly_flagged = DualPredictor(
        [first_model, second_model], is_trained=[False, True]
    )

    def build_and_fit(predictor, splitter, train, X, y):
        # Build the conformal predictor, then compute nonconformity scores
        wrapper = ConformalPredictor(
            predictor=predictor,
            calibrator=self.dual_calibrator,
            splitter=splitter,
            train=train,
        )
        wrapper.fit(X, y)
        return wrapper

    # With a random splitter: flagged-as-trained predictor is accepted
    build_and_fit(
        fully_flagged,
        self.random_splitter,
        False,
        self.X_train,
        self.y_train,
    )

    # With a random splitter: partly untrained predictor is rejected
    with self.assertRaises(RuntimeError):
        build_and_fit(
            partly_flagged,
            self.random_splitter,
            False,
            self.X_train,
            self.y_train,
        )

    # Without splitter: flagged-as-trained predictor fits and predicts
    conformal_predictor = build_and_fit(
        fully_flagged, None, False, self.X_calib, self.y_calib
    )
    conformal_predictor.predict(self.X_test, alpha=0.1)

    # Without splitter: partly untrained predictor is rejected
    with self.assertRaises(RuntimeError):
        build_and_fit(
            partly_flagged, None, False, self.X_calib, self.y_calib
        )

    # Without splitter and with train=True: also expected to be rejected
    with self.assertRaises(RuntimeError):
        build_and_fit(
            partly_flagged, None, True, self.X_calib, self.y_calib
        )

def test_get_nconf_scores_split(self):
# Conformal predictor
conformal_predictor = ConformalPredictor(
Expand Down

0 comments on commit 3b9e4e5

Please sign in to comment.