Refactor: done phase 4
This commit is contained in:
parent
7e8034c6e5
commit
9e4c23c75c
8 changed files with 1533 additions and 37 deletions
|
@ -358,4 +358,82 @@ class ModelManager:
|
|||
Returns:
|
||||
Set of model IDs that are currently downloaded
|
||||
"""
|
||||
return self._downloaded_models.copy()
|
||||
return self._downloaded_models.copy()
|
||||
|
||||
def get_pipeline_config(self, model_id: int) -> Optional[Any]:
|
||||
"""
|
||||
Get the pipeline configuration for a model.
|
||||
|
||||
Args:
|
||||
model_id: The model ID
|
||||
|
||||
Returns:
|
||||
PipelineConfig object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
if model_id not in self._downloaded_models:
|
||||
logger.warning(f"Model {model_id} not downloaded")
|
||||
return None
|
||||
|
||||
model_path = self._model_paths.get(model_id)
|
||||
if not model_path:
|
||||
logger.warning(f"Model path not found for model {model_id}")
|
||||
return None
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .pipeline import PipelineParser
|
||||
|
||||
# Load pipeline.json
|
||||
pipeline_file = model_path / "pipeline.json"
|
||||
if not pipeline_file.exists():
|
||||
logger.warning(f"No pipeline.json found for model {model_id}")
|
||||
return None
|
||||
|
||||
# Create PipelineParser object and parse the configuration
|
||||
pipeline_parser = PipelineParser()
|
||||
success = pipeline_parser.parse(pipeline_file)
|
||||
|
||||
if success:
|
||||
return pipeline_parser
|
||||
else:
|
||||
logger.error(f"Failed to parse pipeline.json for model {model_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]:
|
||||
"""
|
||||
Create a YOLOWrapper instance for a specific model file.
|
||||
|
||||
Args:
|
||||
model_id: The model ID
|
||||
model_filename: The .pt model filename
|
||||
|
||||
Returns:
|
||||
YOLOWrapper instance if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get the model file path
|
||||
model_file_path = self.get_model_file_path(model_id, model_filename)
|
||||
if not model_file_path or not model_file_path.exists():
|
||||
logger.error(f"Model file {model_filename} not found for model {model_id}")
|
||||
return None
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .inference import YOLOWrapper
|
||||
|
||||
# Create YOLOWrapper instance
|
||||
yolo_model = YOLOWrapper(
|
||||
model_path=model_file_path,
|
||||
model_id=f"{model_id}_{model_filename}",
|
||||
device=None # Auto-detect device
|
||||
)
|
||||
|
||||
logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}")
|
||||
return yolo_model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True)
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue