Merge pull request #813 from roboflow/pulling_owlv2_config_into_env
bumping owlv2 version and putting cache size in env
PawelPeczek-Roboflow authored Nov 15, 2024
2 parents 85d2bf6 + 1c5ef0d commit fc9c75c
Showing 2 changed files with 21 additions and 7 deletions.
inference/core/env.py (9 additions, 1 deletion)
@@ -74,11 +74,19 @@
 
 # Gaze version ID, default is "L2CS"
 GAZE_VERSION_ID = os.getenv("GAZE_VERSION_ID", "L2CS")
-OWLV2_VERSION_ID = os.getenv("OWLV2_VERSION_ID", "owlv2-base-patch16-ensemble")
 
 # Gaze model ID
 GAZE_MODEL_ID = f"gaze/{CLIP_VERSION_ID}"
 
+# OWLv2 version ID, default is "owlv2-large-patch14-ensemble"
+OWLV2_VERSION_ID = os.getenv("OWLV2_VERSION_ID", "owlv2-large-patch14-ensemble")
+
+# OWLv2 image cache size, default is 1000 since each image has max <MAX_DETECTIONS> boxes at ~4kb each
+OWLV2_IMAGE_CACHE_SIZE = int(os.getenv("OWLV2_IMAGE_CACHE_SIZE", 1000))
+
+# OWLv2 model cache size, default is 100 as memory is num_prompts * ~4kb and num_prompts is rarely above 1000 (but could be much higher)
+OWLV2_MODEL_CACHE_SIZE = int(os.getenv("OWLV2_MODEL_CACHE_SIZE", 100))
+
 # Maximum batch size for GAZE, default is 8
 GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8))
 
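With this change the OWLv2 checkpoint and both cache sizes can be tuned from the environment instead of being hard-coded. A minimal sketch of overriding them, using the variable names added in this diff (the override values below are purely illustrative, and they must be set before inference.core.env is imported, since the defaults are read once at import time):

import os

# Illustrative overrides; values are examples, not recommendations.
os.environ["OWLV2_VERSION_ID"] = "owlv2-base-patch16-ensemble"  # switch back to the smaller checkpoint
os.environ["OWLV2_IMAGE_CACHE_SIZE"] = "500"  # cap the image-embedding cache at 500 entries
os.environ["OWLV2_MODEL_CACHE_SIZE"] = "50"   # cap the class-embedding cache at 50 entries

from inference.core.env import (
    OWLV2_IMAGE_CACHE_SIZE,
    OWLV2_MODEL_CACHE_SIZE,
    OWLV2_VERSION_ID,
)

print(OWLV2_VERSION_ID, OWLV2_IMAGE_CACHE_SIZE, OWLV2_MODEL_CACHE_SIZE)
# owlv2-base-patch16-ensemble 500 50
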
inference/models/owlv2/owlv2.py (12 additions, 6 deletions)
@@ -16,7 +16,13 @@
     ObjectDetectionInferenceResponse,
     ObjectDetectionPrediction,
 )
-from inference.core.env import DEVICE, MAX_DETECTIONS
+from inference.core.env import (
+    DEVICE,
+    MAX_DETECTIONS,
+    OWLV2_IMAGE_CACHE_SIZE,
+    OWLV2_MODEL_CACHE_SIZE,
+    OWLV2_VERSION_ID,
+)
 from inference.core.models.roboflow import (
     DEFAULT_COLOR_PALETTE,
     RoboflowCoreModel,
@@ -256,7 +262,7 @@ class OwlV2(RoboflowCoreModel):
     task_type = "object-detection"
     box_format = "xywh"
 
-    def __init__(self, *args, model_id="owlv2/owlv2-base-patch16-ensemble", **kwargs):
+    def __init__(self, *args, model_id=f"owlv2/{OWLV2_VERSION_ID}", **kwargs):
         super().__init__(*args, model_id=model_id, **kwargs)
         hf_id = os.path.join("google", self.version_id)
         processor = Owlv2Processor.from_pretrained(hf_id)
@@ -281,11 +287,11 @@ def __init__(self, *args, model_id="owlv2/owlv2-base-patch16-ensemble", **kwargs
 
     def reset_cache(self):
         # each entry should be on the order of 300*4KB, so 1000 is 400MB of CUDA memory
-        self.image_embed_cache = LimitedSizeDict(size_limit=1000)
+        self.image_embed_cache = LimitedSizeDict(size_limit=OWLV2_IMAGE_CACHE_SIZE)
         # each entry should be on the order of 10 bytes, so 1000 is 10KB
-        self.image_size_cache = LimitedSizeDict(size_limit=1000)
-        # entry size will vary depending on the number of samples, but 100 should be safe
-        self.class_embeddings_cache = LimitedSizeDict(size_limit=100)
+        self.image_size_cache = LimitedSizeDict(size_limit=OWLV2_IMAGE_CACHE_SIZE)
+        # entry size will vary depending on the number of samples, but 10 should be safe
+        self.class_embeddings_cache = LimitedSizeDict(size_limit=OWLV2_MODEL_CACHE_SIZE)
 
     def draw_predictions(
         self,
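The three caches above share a size-limited dictionary. As a rough sketch of how such a LimitedSizeDict can behave (a FIFO-evicting OrderedDict subclass, shown purely as an illustration, not the library's actual implementation):

from collections import OrderedDict

class LimitedSizeDictSketch(OrderedDict):
    """Hypothetical stand-in for LimitedSizeDict: evicts the oldest
    entries once the dict grows past size_limit."""

    def __init__(self, *args, size_limit=None, **kwargs):
        self.size_limit = size_limit
        super().__init__(*args, **kwargs)
        self._check_size_limit()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._check_size_limit()

    def _check_size_limit(self):
        if self.size_limit is not None:
            while len(self) > self.size_limit:
                self.popitem(last=False)  # drop the oldest entry

cache = LimitedSizeDictSketch(size_limit=2)
cache["a"] = 1
cache["b"] = 2
cache["c"] = 3  # "a" is evicted; only "b" and "c" remain

Making the size limits environment-driven lets deployments trade cache-hit rate against GPU memory without a code change.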
