Skip to content

Commit

Permalink
Update test_common.py
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust authored Jan 13, 2025
1 parent 9de1f51 commit 6eaec9a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def sklearnex_trace(estimator_name, method_name):
of trace._modname.
"""
# get estimator
estimator = (SPECIAL_INSTANCES | PATCHED_MODELS)[estimator_name]
try:
est = PATCHED_MODELS[estimator_name]()
except KeyError:
est = SPECIAL_INSTANCES[estimator_name]

# get dataset
X, y = gen_dataset(est)[0]
Expand Down

0 comments on commit 6eaec9a

Please sign in to comment.