fix: model calling method
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m44s
Build Worker Base and Application Images / deploy-stack (push) Successful in 9s

ziesorx 2025-09-25 15:06:41 +07:00
parent 5bb68b6e10
commit 2e5316ca01
3 changed files with 82 additions and 33 deletions


@@ -81,8 +81,28 @@ class YOLOWrapper:
        from ultralytics import YOLO

        logger.info(f"Loading YOLO model from {self.model_path}")

        # Load model normally first
        self.model = YOLO(str(self.model_path))

        # Determine if this is a classification model based on filename or model structure
        # Classification models typically have 'cls' in filename
        is_classification = 'cls' in str(self.model_path).lower()

        # For classification models, create a separate instance with task parameter
        if is_classification:
            try:
                # Reload with classification task (like ML engineer's approach)
                self.model = YOLO(str(self.model_path), task="classify")
                logger.info(f"Loaded classification model {self.model_id} with task='classify'")
            except Exception as e:
                logger.warning(f"Failed to load with task='classify', using default: {e}")
                # Fall back to regular loading
                self.model = YOLO(str(self.model_path))
                logger.info(f"Loaded model {self.model_id} with default task")
        else:
            logger.info(f"Loaded detection model {self.model_id}")

        # Move model to device
        if self.device == 'cuda' and torch.cuda.is_available():
            self.model.to('cuda')
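As a standalone illustration of the loading pattern in this hunk, here is a minimal sketch, assuming the ultralytics YOLO constructor accepts a task argument; the weights path below is hypothetical and not part of this commit:

    from pathlib import Path

    from ultralytics import YOLO

    # Hypothetical weights file; any local *-cls.pt checkpoint would do
    model_path = Path("models/yolov8n-cls.pt")

    # Mirror the filename heuristic used above: '-cls' weights are classifiers
    task = "classify" if "cls" in model_path.name.lower() else None

    try:
        # Passing task explicitly makes classification results expose .probs
        model = YOLO(str(model_path), task=task) if task else YOLO(str(model_path))
    except Exception:
        # Fall back to letting ultralytics infer the task from the weights
        model = YOLO(str(model_path))

    print(model.task)  # expected to print 'classify' for *-cls weights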
@@ -141,7 +161,7 @@ class YOLOWrapper:
        import time
        start_time = time.time()

        # Run inference
        # Run inference using direct model call (like ML engineer's approach)
        results = self.model(
            image,
            conf=confidence_threshold,
@ -291,11 +311,11 @@ class YOLOWrapper:
raise RuntimeError(f"Model {self.model_id} not loaded")
try:
# Run inference
results = self.model(image, verbose=False)
# Run inference using predict method for classification (like ML engineer's approach)
results = self.model.predict(source=image, verbose=False)
# For classification models, extract probabilities
if hasattr(results[0], 'probs'):
if results and len(results) > 0 and hasattr(results[0], 'probs') and results[0].probs is not None:
probs = results[0].probs
top_indices = probs.top5[:top_k]
top_conf = probs.top5conf[:top_k].cpu().numpy()
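For context on the probs fields used in this hunk, here is a minimal sketch of turning top5/top5conf into name-to-confidence pairs, assuming a hypothetical classification checkpoint and image path; the top_k value is illustrative and not from this commit:

    from ultralytics import YOLO

    model = YOLO("yolov8n-cls.pt", task="classify")               # hypothetical weights
    results = model.predict(source="sample.jpg", verbose=False)   # hypothetical image

    top_k = 3  # the wrapper takes this as a parameter; value assumed here
    if results and results[0].probs is not None:
        probs = results[0].probs
        names = results[0].names                  # index -> class-name mapping
        indices = probs.top5[:top_k]              # top-5 class indices
        scores = probs.top5conf[:top_k].cpu().numpy()
        predictions = {names[int(i)]: float(s) for i, s in zip(indices, scores)}
        print(predictions)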
@@ -307,7 +327,7 @@ class YOLOWrapper:
                return predictions
            else:
                logger.warning(f"Model {self.model_id} does not support classification")
                logger.warning(f"Model {self.model_id} does not support classification or no probs found")
                return {}

        except Exception as e:
@@ -350,6 +370,10 @@ class YOLOWrapper:
        """Get the number of classes the model can detect"""
        return len(self._class_names)

    def is_classification_model(self) -> bool:
        """Check if this is a classification model"""
        return 'cls' in str(self.model_path).lower() or 'classify' in str(self.model_path).lower()

    def clear_cache(self) -> None:
        """Clear the model cache"""
        with self._cache_lock:
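A possible call site for the new is_classification_model() helper, sketched with a hypothetical import path, constructor signature, and inference method names; only the helper itself comes from this commit:

    import cv2

    from core.models import YOLOWrapper  # hypothetical import path

    # Constructor arguments assumed for illustration
    wrapper = YOLOWrapper(model_path="models/yolov8n-cls.pt", model_id="cls-demo", device="cpu")
    image = cv2.imread("sample.jpg")

    if wrapper.is_classification_model():
        result = wrapper.classify(image, top_k=3)                   # hypothetical method name
    else:
        result = wrapper.infer(image, confidence_threshold=0.25)    # hypothetical method name
    print(result)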