Skip to content

Commit

Permalink
Merge pull request #876 from roboflow/fix-aggregator-init-params
Browse files Browse the repository at this point in the history
Fix model_id bug with InferenceAggregator block
  • Loading branch information
PawelPeczek-Roboflow authored Dec 12, 2024
2 parents ebffba0 + b785224 commit ea7dc65
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def __init__(
api_key: Optional[str],
background_tasks: Optional[BackgroundTasks],
thread_pool_executor: Optional[ThreadPoolExecutor],
model_id: str,
):
if api_key is None:
raise ValueError(
Expand All @@ -228,7 +227,6 @@ def __init__(
self._background_tasks = background_tasks
self._thread_pool_executor = thread_pool_executor
self._predictions_aggregator = PredictionsAggregator()
self._model_id = model_id

@classmethod
def get_init_parameters(cls) -> List[str]:
Expand All @@ -244,10 +242,11 @@ def run(
predictions: Union[sv.Detections, dict],
frequency: int,
unique_aggregator_key: str,
model_id: str,
) -> BlockResult:
self._last_report_time_cache_key = f"workflows:steps_cache:roboflow_core/model_monitoring_inference_aggregator@v1:{unique_aggregator_key}:last_report_time"
if predictions:
self._predictions_aggregator.collect(predictions, self._model_id)
self._predictions_aggregator.collect(predictions, model_id)
if not self._is_in_reporting_range(frequency):
return {
"error_status": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ def test_run_not_in_reporting_range_success(
api_key="my_api_key",
background_tasks=None,
thread_pool_executor=None,
model_id="my_model_id",
)
result = block.run(
fire_and_forget=True,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="my_model_id",
)

# then
Expand Down Expand Up @@ -121,13 +121,13 @@ def test_run_in_reporting_range_success_with_object_detection(
api_key=api_key,
background_tasks=None,
thread_pool_executor=None,
model_id="construction-safety/10",
)
result = block.run(
fire_and_forget=False,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="construction-safety/10",
)

# then
Expand Down Expand Up @@ -217,13 +217,13 @@ def test_run_in_reporting_range_success_with_single_label_classification(
api_key=api_key,
background_tasks=None,
thread_pool_executor=None,
model_id="pills-classification/1",
)
result = block.run(
fire_and_forget=False,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="pills-classification/1",
)

# then
Expand Down Expand Up @@ -313,13 +313,13 @@ def test_run_in_reporting_range_success_with_multi_label_classification(
api_key=api_key,
background_tasks=None,
thread_pool_executor=None,
model_id="animals/32",
)
result = block.run(
fire_and_forget=False,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="animals/32",
)

# then
Expand Down Expand Up @@ -415,13 +415,13 @@ def test_send_inference_results_to_model_monitoring_failure(
api_key=api_key,
background_tasks=None,
thread_pool_executor=None,
model_id="my_model_id",
)
result = block.run(
fire_and_forget=False,
frequency=1,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="my_model_id",
)

# then
Expand Down Expand Up @@ -479,13 +479,13 @@ def test_run_when_not_in_reporting_range(
api_key=api_key,
background_tasks=None,
thread_pool_executor=None,
model_id="my_model_id",
)
result = block.run(
fire_and_forget=False,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="my_model_id",
)

# then
Expand Down Expand Up @@ -545,13 +545,13 @@ def test_run_when_fire_and_forget_with_background_tasks(
api_key=api_key,
background_tasks=background_tasks,
thread_pool_executor=None,
model_id="my_model_id",
)
result = block.run(
fire_and_forget=True,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="my_model_id",
)

# then
Expand Down Expand Up @@ -609,13 +609,13 @@ def test_run_when_fire_and_forget_with_thread_pool(
api_key=api_key,
background_tasks=None,
thread_pool_executor=thread_pool_executor,
model_id="my_model_id",
)
result = block.run(
fire_and_forget=True,
frequency=10,
predictions=predictions,
unique_aggregator_key=unique_aggregator_key,
model_id="my_model_id",
)

# then
Expand Down

0 comments on commit ea7dc65

Please sign in to comment.