ultralytic export
parent bf7b68edb1 · commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
@@ -5,25 +5,28 @@ This module provides high-level connection management for multiple RTSP streams,
 coordinating decoders, batched inference, and tracking with callbacks and threading.
 """
 
-import threading
-import time
-from typing import Dict, Optional, Callable, Tuple, Any, List
-from dataclasses import dataclass
-from enum import Enum
 import logging
 import queue
+import threading
+import time
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 
-from .model_controller import ModelController
-from .stream_decoder import StreamDecoderFactory
+from .base_model_controller import BaseModelController
+from .model_repository import TensorRTModelRepository
+from .stream_decoder import StreamDecoderFactory
+from .tensorrt_model_controller import TensorRTModelController
+from .ultralytics_model_controller import UltralyticsModelController
 
 logger = logging.getLogger(__name__)
 
 
 class ConnectionStatus(Enum):
     """Status of a stream connection"""
 
     CONNECTING = "connecting"
     CONNECTED = "connected"
     DISCONNECTED = "disconnected"
@@ -33,6 +36,7 @@ class ConnectionStatus(Enum):
 @dataclass
 class TrackingResult:
     """Result emitted to user callbacks"""
+
     stream_id: str
     timestamp: float
     tracked_objects: List  # List of TrackedObject from TrackingController
@@ -61,7 +65,7 @@ class StreamConnection:
         self,
         stream_id: str,
         decoder,
-        model_controller: ModelController,
+        model_controller: BaseModelController,
         tracking_controller,
         poll_interval: float = 0.01,
     ):
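The widened parameter type above is what lets a StreamConnection drive either backend through one interface. A minimal sketch of the assumed BaseModelController contract, inferred only from the calls this diff makes on it (start, submit_frame, register_callback, get_stats); the real base class may define more:

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict

import torch


class BaseModelController(ABC):
    """Assumed shared interface for TensorRTModelController / UltralyticsModelController."""

    @abstractmethod
    def start(self) -> None: ...

    @abstractmethod
    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Dict[str, Any]
    ) -> None: ...

    @abstractmethod
    def register_callback(self, stream_id: str, callback: Callable) -> None: ...

    @abstractmethod
    def get_stats(self) -> Dict[str, Any]: ...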
@@ -107,7 +111,9 @@ class StreamConnection:
                 break
         else:
             # Timeout - but don't fail hard, let it try to connect in background
-            logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
+            logger.warning(
+                f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying..."
+            )
             self.status = ConnectionStatus.CONNECTING
 
     def stop(self):
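The break/else pair above relies on Python's for/else idiom: the else branch runs only when the loop never hit break, which is exactly the "timed out without connecting" case. A self-contained sketch of the same wait loop (the is_connected callable and the timing values are illustrative):

import time


def wait_for_connection(is_connected, max_wait: float = 10.0, poll_interval: float = 0.1) -> bool:
    """Poll is_connected() until it returns True or max_wait elapses."""
    connected = False
    for _ in range(int(max_wait / poll_interval)):
        if is_connected():
            connected = True
            break
        time.sleep(poll_interval)
    else:
        # for/else: reached only when the loop completed without `break`,
        # i.e. the timeout expired; keep going instead of failing hard.
        print("not connected yet, will keep trying in the background")
    return connected


start = time.monotonic()
print(wait_for_connection(lambda: time.monotonic() - start > 0.3, max_wait=1.0))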
@@ -144,28 +150,42 @@ class StreamConnection:
                 self.last_frame_time = time.time()
                 self.frame_count += 1
 
+                # CRITICAL: Clone the GPU tensor to decouple from decoder's frame buffer
+                # The decoder reuses frame buffer memory, so we must copy the tensor
+                # before submitting to async batched inference to prevent race conditions
+                # where the decoder overwrites memory while inference is still reading it.
+                cloned_tensor = frame_ref.rgb_tensor.clone()
+
                 # Submit to model controller for batched inference
                 # Pass the FrameReference in metadata so we can free it later
                 self.model_controller.submit_frame(
                     stream_id=self.stream_id,
-                    frame=frame_ref.rgb_tensor,
+                    frame=cloned_tensor,  # Use cloned tensor, not original
                     metadata={
                         "frame_number": self.frame_count,
-                        "shape": tuple(frame_ref.rgb_tensor.shape),
+                        "shape": tuple(cloned_tensor.shape),
                         "frame_ref": frame_ref,  # Store reference for later cleanup
-                    }
+                    },
                 )
 
                 # Update connection status based on decoder status
-                if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
+                if (
+                    self.decoder.is_connected()
+                    and self.status != ConnectionStatus.CONNECTED
+                ):
                     logger.info(f"Stream {self.stream_id} reconnected")
                     self.status = ConnectionStatus.CONNECTED
-                elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
+                elif (
+                    not self.decoder.is_connected()
+                    and self.status == ConnectionStatus.CONNECTED
+                ):
                     logger.warning(f"Stream {self.stream_id} disconnected")
                     self.status = ConnectionStatus.DISCONNECTED
 
             except Exception as e:
-                logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
+                logger.error(
+                    f"Error processing frame for {self.stream_id}: {e}", exc_info=True
+                )
                 self.error_queue.put(e)
                 self.status = ConnectionStatus.ERROR
                 # Free the frame on error
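The clone() introduced above exists because the decoder recycles its frame buffer: a tensor submitted to asynchronous batched inference must not alias memory the decoder is about to overwrite. A self-contained illustration of that aliasing hazard in plain PyTorch (hypothetical buffer, falls back to CPU when no GPU is available):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the decoder's reusable frame buffer.
frame_buffer = torch.zeros(3, 4, device=device)

aliased = frame_buffer           # shares storage: whatever the decoder writes next shows up here
snapshot = frame_buffer.clone()  # independent copy: safe to read while the buffer is reused

frame_buffer.fill_(1.0)  # "decoder" writes the next frame into the same memory

print(aliased[0, 0].item())   # 1.0 - the frame we meant to keep was silently replaced
print(snapshot[0, 0].item())  # 0.0 - the cloned frame is unaffected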
@@ -205,7 +225,10 @@ class StreamConnection:
             self.result_queue.put(tracking_result)
 
         except Exception as e:
-            logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
+            logger.error(
+                f"Error handling inference result for {self.stream_id}: {e}",
+                exc_info=True,
+            )
             self.error_queue.put(e)
         finally:
             # Free the frame reference - this is the last point in the pipeline
@@ -235,12 +258,16 @@ class StreamConnectionManager:
             if confidence < min_confidence:
                 continue
 
-            detection_list.append(Detection(
-                bbox=det[:4].cpu().tolist(),
-                confidence=confidence,
-                class_id=int(det[5]) if det.shape[0] > 5 else 0,
-                class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
-            ))
+            detection_list.append(
+                Detection(
+                    bbox=det[:4].cpu().tolist(),
+                    confidence=confidence,
+                    class_id=int(det[5]) if det.shape[0] > 5 else 0,
+                    class_name=f"class_{int(det[5])}"
+                    if det.shape[0] > 5
+                    else "unknown",
+                )
+            )
 
         # Update tracker with detections (will scale bboxes to frame space)
         return self.tracking_controller.update(detection_list, frame_shape=frame_shape)
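Pulled out as a standalone helper, the conversion above assumes each raw detection row is laid out as [x1, y1, x2, y2, confidence, class_id], which is what the indexing (det[:4], det[5]) implies. The Detection dataclass below is only a stand-in for the project's own type:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Detection:  # stand-in for the project's Detection type
    bbox: List[float]
    confidence: float
    class_id: int
    class_name: str


def to_detections(dets: torch.Tensor, min_confidence: float = 0.5) -> List[Detection]:
    out: List[Detection] = []
    for det in dets:  # each row: [x1, y1, x2, y2, conf, (optional) class_id]
        confidence = float(det[4])
        if confidence < min_confidence:
            continue
        has_cls = det.shape[0] > 5
        out.append(
            Detection(
                bbox=det[:4].cpu().tolist(),
                confidence=confidence,
                class_id=int(det[5]) if has_cls else 0,
                class_name=f"class_{int(det[5])}" if has_cls else "unknown",
            )
        )
    return out


dets = torch.tensor([[10.0, 20.0, 50.0, 80.0, 0.9, 2.0], [0.0, 0.0, 5.0, 5.0, 0.3, 1.0]])
print(to_detections(dets))  # only the 0.9-confidence row survives the 0.5 threshold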
@@ -319,21 +346,38 @@ class StreamConnectionManager:
         force_timeout: float = 0.05,
         poll_interval: float = 0.01,
         enable_pt_conversion: bool = True,
+        backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
     ):
         self.gpu_id = gpu_id
         self.batch_size = batch_size
         self.force_timeout = force_timeout
         self.poll_interval = poll_interval
+        self.backend = backend.lower()
 
         # Factories
         self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
-        self.model_repository = TensorRTModelRepository(
-            gpu_id=gpu_id,
-            enable_pt_conversion=enable_pt_conversion
-        )
 
+        # Initialize inference engine based on backend
+        self.inference_engine = None
+        self.model_repository = None  # Legacy - will be removed
+
+        if self.backend == "ultralytics":
+            # Use Ultralytics native YOLO inference
+            from .inference_engine import UltralyticsEngine
+
+            self.inference_engine = UltralyticsEngine()
+            logger.info("Using Ultralytics inference engine")
+        else:
+            # Use native TensorRT inference
+            self.model_repository = TensorRTModelRepository(
+                gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion
+            )
+            logger.info("Using native TensorRT inference engine")
+
         # Controllers
-        self.model_controller: Optional[ModelController] = None
+        self.model_controller = (
+            None  # Will be TensorRTModelController or UltralyticsModelController
+        )
 
         # Connections
         self.connections: Dict[str, StreamConnection] = {}
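A usage sketch of the constructor as extended above; the import path and argument values are illustrative, only the keyword names come from the signature in this hunk:

from stream_connection_manager import StreamConnectionManager  # import path assumed

# Default path: native TensorRT, engines owned by TensorRTModelRepository.
trt_manager = StreamConnectionManager(gpu_id=0, batch_size=16, backend="tensorrt")

# Path added by this commit: Ultralytics YOLO inference via UltralyticsEngine.
yolo_manager = StreamConnectionManager(gpu_id=0, batch_size=16, backend="ultralytics")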
@@ -350,7 +394,7 @@ class StreamConnectionManager:
         num_contexts: int = 4,
         pt_input_shapes: Optional[Dict] = None,
         pt_precision: Optional[Any] = None,
-        **pt_conversion_kwargs
+        **pt_conversion_kwargs,
     ):
         """
         Initialize the manager with a model.
@@ -382,28 +426,58 @@ class StreamConnectionManager:
         )
         """
         logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
+        logger.info(f"Backend: {self.backend}")
 
-        # Load model (synchronous)
-        self.model_repository.load_model(
-            model_id,
-            model_path,
-            num_contexts=num_contexts,
-            pt_input_shapes=pt_input_shapes,
-            pt_precision=pt_precision,
-            **pt_conversion_kwargs
-        )
-        logger.info(f"Loaded model {model_id} from {model_path}")
+        # Initialize engine based on backend
+        if self.backend == "ultralytics":
+            # Use Ultralytics native inference
+            logger.info("Initializing Ultralytics YOLO engine...")
+            device = torch.device(f"cuda:{self.gpu_id}")
 
-        # Create model controller
-        self.model_controller = ModelController(
-            model_repository=self.model_repository,
-            model_id=model_id,
-            batch_size=self.batch_size,
-            force_timeout=self.force_timeout,
-            preprocess_fn=preprocess_fn,
-            postprocess_fn=postprocess_fn,
-        )
-        self.model_controller.start()
+            metadata = self.inference_engine.initialize(
+                model_path=model_path,
+                device=device,
+                batch=self.batch_size,
+                half=False,  # Use FP32 for now
+                imgsz=640,
+                **pt_conversion_kwargs,
+            )
+            logger.info(f"Ultralytics engine initialized: {metadata}")
+
+            # Create Ultralytics model controller
+            self.model_controller = UltralyticsModelController(
+                inference_engine=self.inference_engine,
+                model_id=model_id,
+                batch_size=self.batch_size,
+                force_timeout=self.force_timeout,
+                preprocess_fn=preprocess_fn,
+                postprocess_fn=postprocess_fn,
+            )
+            self.model_controller.start()
+
+        else:
+            # Use native TensorRT with model repository
+            logger.info("Initializing TensorRT engine...")
+            self.model_repository.load_model(
+                model_id,
+                model_path,
+                num_contexts=num_contexts,
+                pt_input_shapes=pt_input_shapes,
+                pt_precision=pt_precision,
+                **pt_conversion_kwargs,
+            )
+            logger.info(f"Loaded model {model_id} from {model_path}")
+
+            # Create TensorRT model controller
+            self.model_controller = TensorRTModelController(
+                model_repository=self.model_repository,
+                model_id=model_id,
+                batch_size=self.batch_size,
+                force_timeout=self.force_timeout,
+                preprocess_fn=preprocess_fn,
+                postprocess_fn=postprocess_fn,
+            )
+            self.model_controller.start()
 
         # Don't create a shared tracking controller here
         # Each stream will get its own tracking controller to avoid track accumulation
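UltralyticsEngine.initialize() itself is not part of this hunk; the sketch below shows what it plausibly wraps using the public ultralytics API. Everything except the ultralytics package and torch is an assumption:

import torch
from ultralytics import YOLO  # pip install ultralytics


def initialize_yolo(model_path: str, device: torch.device, half: bool = False, imgsz: int = 640):
    """Stand-in for UltralyticsEngine.initialize(): load weights and report basic metadata."""
    model = YOLO(model_path)  # accepts .pt weights or an exported model
    model.to(device)          # move the underlying module to the target GPU
    metadata = {"task": model.task, "imgsz": imgsz, "half": half, "device": str(device)}
    return model, metadata


# model, meta = initialize_yolo("yolov8n.pt", torch.device("cuda:0"))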
@@ -452,12 +526,13 @@ class StreamConnectionManager:
 
         # Create lightweight tracker (NO model_repository dependency!)
         from .tracking_controller import ObjectTracker
+
         tracking_controller = ObjectTracker(
             gpu_id=self.gpu_id,
             tracker_type="iou",
             max_age=30,
             iou_threshold=0.3,
-            class_names=None  # TODO: pass class names if available
+            class_names=None,  # TODO: pass class names if available
         )
         logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")
 
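The per-stream tracker above is configured with tracker_type="iou" and iou_threshold=0.3; ObjectTracker's internals are not shown in this diff, but the matching it names reduces to box IoU. A minimal, self-contained IoU helper for reference:

from typing import Sequence


def iou(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# A detection extends an existing track only when IoU clears the threshold.
print(iou([0, 0, 10, 10], [5, 5, 15, 15]))  # ~0.143, below a 0.3 threshold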
@@ -472,8 +547,7 @@ class StreamConnectionManager:
 
         # Register callback with model controller
         self.model_controller.register_callback(
-            stream_id,
-            connection._handle_inference_result
+            stream_id, connection._handle_inference_result
         )
 
         # Start connection
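register_callback() above is how batched results find their way back to the stream that submitted the frame. A tiny sketch of the assumed dispatch structure (a dict from stream_id to handler); the controller's actual bookkeeping may differ:

from typing import Any, Callable, Dict

callbacks: Dict[str, Callable[[Any], None]] = {}


def register_callback(stream_id: str, callback: Callable[[Any], None]) -> None:
    callbacks[stream_id] = callback


def dispatch(stream_id: str, result: Any) -> None:
    # After a batch completes, each per-frame result is routed to its stream's handler.
    handler = callbacks.get(stream_id)
    if handler is not None:
        handler(result)


register_callback("cam_01", lambda r: print("cam_01 result:", r))
dispatch("cam_01", {"objects": 3})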
@@ -487,14 +561,12 @@ class StreamConnectionManager:
             threading.Thread(
                 target=self._forward_results,
                 args=(connection, on_tracking_result),
-                daemon=True
+                daemon=True,
             ).start()
 
         if on_error:
             threading.Thread(
-                target=self._forward_errors,
-                args=(connection, on_error),
-                daemon=True
+                target=self._forward_errors, args=(connection, on_error), daemon=True
             ).start()
 
         logger.info(f"Stream {stream_id} connected successfully")
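The daemon threads started above simply drain a connection's result iterator and hand each item to the user callback. A self-contained sketch of that forwarding pattern, with a plain queue standing in for connection.tracking_results():

import queue
import threading

results: queue.Queue = queue.Queue()


def forward(q: queue.Queue, callback) -> None:
    # Mirrors _forward_results: block on the queue and hand each item to the callback.
    while True:
        item = q.get()
        if item is None:  # sentinel for this sketch; the real loop runs for the connection's lifetime
            break
        callback(item)


worker = threading.Thread(target=forward, args=(results, print), daemon=True)
worker.start()
results.put("tracking result 1")
results.put(None)
worker.join(timeout=1.0)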
@@ -549,7 +621,10 @@ class StreamConnectionManager:
             for result in connection.tracking_results():
                 callback(result)
         except Exception as e:
-            logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
+            logger.error(
+                f"Error in result forwarding for {connection.stream_id}: {e}",
+                exc_info=True,
+            )
 
     def _forward_errors(self, connection: StreamConnection, callback: Callable):
         """
@@ -563,7 +638,10 @@ class StreamConnectionManager:
             for error in connection.errors():
                 callback(error)
         except Exception as e:
-            logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
+            logger.error(
+                f"Error in error forwarding for {connection.stream_id}: {e}",
+                exc_info=True,
+            )
 
     def get_stats(self) -> Dict[str, Any]:
         """
@@ -581,7 +659,9 @@ class StreamConnectionManager:
                 "force_timeout": self.force_timeout,
                 "poll_interval": self.poll_interval,
             },
-            "model_controller": self.model_controller.get_stats() if self.model_controller else {},
+            "model_controller": self.model_controller.get_stats()
+            if self.model_controller
+            else {},
             "connections": {
                 stream_id: conn.get_stats()
                 for stream_id, conn in self.connections.items()
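From the keys visible in this hunk, the dictionary assembled by get_stats() has roughly the shape below (values illustrative; the name of the enclosing config block is not shown in the hunk and is assumed):

stats = {
    "config": {  # enclosing key name assumed; only the entries below appear in the hunk
        "force_timeout": 0.05,
        "poll_interval": 0.01,
    },
    "model_controller": {},  # self.model_controller.get_stats() when a controller exists, else {}
    "connections": {
        "cam_01": {},  # one conn.get_stats() entry per connected stream
    },
}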