ultralytic export

Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions


@@ -5,25 +5,28 @@ This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with callbacks and threading.
"""
import threading
import time
from typing import Dict, Optional, Callable, Tuple, Any, List
from dataclasses import dataclass
from enum import Enum
import logging
import queue
import threading
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from .model_controller import ModelController
from .stream_decoder import StreamDecoderFactory
from .base_model_controller import BaseModelController
from .model_repository import TensorRTModelRepository
from .stream_decoder import StreamDecoderFactory
from .tensorrt_model_controller import TensorRTModelController
from .ultralytics_model_controller import UltralyticsModelController
logger = logging.getLogger(__name__)
class ConnectionStatus(Enum):
"""Status of a stream connection"""
CONNECTING = "connecting"
CONNECTED = "connected"
DISCONNECTED = "disconnected"
@@ -33,6 +36,7 @@ class ConnectionStatus(Enum):
@dataclass
class TrackingResult:
"""Result emitted to user callbacks"""
stream_id: str
timestamp: float
tracked_objects: List # List of TrackedObject from TrackingController
@@ -61,7 +65,7 @@ class StreamConnection:
self,
stream_id: str,
decoder,
model_controller: ModelController,
model_controller: BaseModelController,
tracking_controller,
poll_interval: float = 0.01,
):
@@ -107,7 +111,9 @@
break
else:
# Timeout - but don't fail hard, let it try to connect in background
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
logger.warning(
f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying..."
)
self.status = ConnectionStatus.CONNECTING
def stop(self):
@@ -144,28 +150,42 @@ class StreamConnection:
self.last_frame_time = time.time()
self.frame_count += 1
# CRITICAL: Clone the GPU tensor to decouple from decoder's frame buffer
# The decoder reuses frame buffer memory, so we must copy the tensor
# before submitting to async batched inference to prevent race conditions
# where the decoder overwrites memory while inference is still reading it.
cloned_tensor = frame_ref.rgb_tensor.clone()
# Submit to model controller for batched inference
# Pass the FrameReference in metadata so we can free it later
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame_ref.rgb_tensor,
frame=cloned_tensor, # Use cloned tensor, not original
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame_ref.rgb_tensor.shape),
"shape": tuple(cloned_tensor.shape),
"frame_ref": frame_ref, # Store reference for later cleanup
}
},
)
# Update connection status based on decoder status
if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
if (
self.decoder.is_connected()
and self.status != ConnectionStatus.CONNECTED
):
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
elif (
not self.decoder.is_connected()
and self.status == ConnectionStatus.CONNECTED
):
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
except Exception as e:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
logger.error(
f"Error processing frame for {self.stream_id}: {e}", exc_info=True
)
self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Free the frame on error
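Note: the clone in this hunk exists because the decoder reuses one GPU frame buffer. A minimal sketch of the hazard it guards against, assuming a CUDA device and using placeholder names (frame_buffer, on_decoded_frame) rather than the real decoder internals:

import torch

frame_buffer = torch.empty((3, 720, 1280), device="cuda:0")  # decoder's reused buffer

def on_decoded_frame(decoded: torch.Tensor):
    frame_buffer.copy_(decoded)  # in-place overwrite on every new frame

# Submitting frame_buffer itself would let a later copy_() race with inference
# that is still reading it; clone() copies the frame into fresh memory first.
safe_frame = frame_buffer.clone()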
@@ -205,7 +225,10 @@ class StreamConnection:
self.result_queue.put(tracking_result)
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
logger.error(
f"Error handling inference result for {self.stream_id}: {e}",
exc_info=True,
)
self.error_queue.put(e)
finally:
# Free the frame reference - this is the last point in the pipeline
@@ -235,12 +258,16 @@ class StreamConnection:
if confidence < min_confidence:
continue
detection_list.append(Detection(
bbox=det[:4].cpu().tolist(),
confidence=confidence,
class_id=int(det[5]) if det.shape[0] > 5 else 0,
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
))
detection_list.append(
Detection(
bbox=det[:4].cpu().tolist(),
confidence=confidence,
class_id=int(det[5]) if det.shape[0] > 5 else 0,
class_name=f"class_{int(det[5])}"
if det.shape[0] > 5
else "unknown",
)
)
# Update tracker with detections (will scale bboxes to frame space)
return self.tracking_controller.update(detection_list, frame_shape=frame_shape)
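Note: the loop in this hunk assumes each raw detection row is laid out as [x1, y1, x2, y2, confidence, class_id]. A self-contained sketch of that parsing step with a fabricated tensor (the extracted values mirror the Detection keyword arguments above):

import torch

min_confidence = 0.5
detections = torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.87, 2.0]])  # one dummy row

for det in detections:
    confidence = float(det[4])
    if confidence < min_confidence:
        continue
    bbox = det[:4].cpu().tolist()                       # [x1, y1, x2, y2]
    class_id = int(det[5]) if det.shape[0] > 5 else 0
    print(bbox, confidence, class_id)  # would be wrapped in Detection(...) and sent to the tracker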
@@ -319,21 +346,38 @@ class StreamConnectionManager:
force_timeout: float = 0.05,
poll_interval: float = 0.01,
enable_pt_conversion: bool = True,
backend: str = "tensorrt", # "tensorrt" or "ultralytics"
):
self.gpu_id = gpu_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.poll_interval = poll_interval
self.backend = backend.lower()
# Factories
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id,
enable_pt_conversion=enable_pt_conversion
)
# Initialize inference engine based on backend
self.inference_engine = None
self.model_repository = None # Legacy - will be removed
if self.backend == "ultralytics":
# Use Ultralytics native YOLO inference
from .inference_engine import UltralyticsEngine
self.inference_engine = UltralyticsEngine()
logger.info("Using Ultralytics inference engine")
else:
# Use native TensorRT inference
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion
)
logger.info("Using native TensorRT inference engine")
# Controllers
self.model_controller: Optional[ModelController] = None
self.model_controller = (
None # Will be TensorRTModelController or UltralyticsModelController
)
# Connections
self.connections: Dict[str, StreamConnection] = {}
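Note: with the new backend argument, constructing the manager might look like this (a sketch using only parameters visible in this hunk; all other arguments keep their defaults):

# Default native TensorRT path: keeps the TensorRTModelRepository
manager = StreamConnectionManager(gpu_id=0, batch_size=16, backend="tensorrt")

# Ultralytics path: skips the repository and creates an UltralyticsEngine instead
manager = StreamConnectionManager(gpu_id=0, batch_size=16, backend="ultralytics")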
@@ -350,7 +394,7 @@ class StreamConnectionManager:
num_contexts: int = 4,
pt_input_shapes: Optional[Dict] = None,
pt_precision: Optional[Any] = None,
**pt_conversion_kwargs
**pt_conversion_kwargs,
):
"""
Initialize the manager with a model.
@@ -382,28 +426,58 @@ class StreamConnectionManager:
)
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
logger.info(f"Backend: {self.backend}")
# Load model (synchronous)
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Initialize engine based on backend
if self.backend == "ultralytics":
# Use Ultralytics native inference
logger.info("Initializing Ultralytics YOLO engine...")
device = torch.device(f"cuda:{self.gpu_id}")
# Create model controller
self.model_controller = ModelController(
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
metadata = self.inference_engine.initialize(
model_path=model_path,
device=device,
batch=self.batch_size,
half=False, # Use FP32 for now
imgsz=640,
**pt_conversion_kwargs,
)
logger.info(f"Ultralytics engine initialized: {metadata}")
# Create Ultralytics model controller
self.model_controller = UltralyticsModelController(
inference_engine=self.inference_engine,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
else:
# Use native TensorRT with model repository
logger.info("Initializing TensorRT engine...")
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs,
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Create TensorRT model controller
self.model_controller = TensorRTModelController(
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
# Don't create a shared tracking controller here
# Each stream will get its own tracking controller to avoid track accumulation
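Note: initialize() now branches on the backend as shown above. A usage sketch (argument names are taken from this hunk; model IDs and file paths are placeholders):

# Ultralytics backend: model_path is handed straight to the YOLO engine
manager.initialize(model_id="yolo", model_path="models/yolov8n.pt")

# TensorRT backend: the repository loads (and, if needed, converts) the model
manager.initialize(
    model_id="yolo",
    model_path="models/yolov8n.engine",
    num_contexts=4,
)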
@@ -452,12 +526,13 @@ class StreamConnectionManager:
# Create lightweight tracker (NO model_repository dependency!)
from .tracking_controller import ObjectTracker
tracking_controller = ObjectTracker(
gpu_id=self.gpu_id,
tracker_type="iou",
max_age=30,
iou_threshold=0.3,
class_names=None # TODO: pass class names if available
class_names=None, # TODO: pass class names if available
)
logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")
@@ -472,8 +547,7 @@ class StreamConnectionManager:
# Register callback with model controller
self.model_controller.register_callback(
stream_id,
connection._handle_inference_result
stream_id, connection._handle_inference_result
)
# Start connection
@@ -487,14 +561,12 @@ class StreamConnectionManager:
threading.Thread(
target=self._forward_results,
args=(connection, on_tracking_result),
daemon=True
daemon=True,
).start()
if on_error:
threading.Thread(
target=self._forward_errors,
args=(connection, on_error),
daemon=True
target=self._forward_errors, args=(connection, on_error), daemon=True
).start()
logger.info(f"Stream {stream_id} connected successfully")
@@ -549,7 +621,10 @@ class StreamConnectionManager:
for result in connection.tracking_results():
callback(result)
except Exception as e:
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
logger.error(
f"Error in result forwarding for {connection.stream_id}: {e}",
exc_info=True,
)
def _forward_errors(self, connection: StreamConnection, callback: Callable):
"""
@@ -563,7 +638,10 @@ class StreamConnectionManager:
for error in connection.errors():
callback(error)
except Exception as e:
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
logger.error(
f"Error in error forwarding for {connection.stream_id}: {e}",
exc_info=True,
)
def get_stats(self) -> Dict[str, Any]:
"""
@ -581,7 +659,9 @@ class StreamConnectionManager:
"force_timeout": self.force_timeout,
"poll_interval": self.poll_interval,
},
"model_controller": self.model_controller.get_stats() if self.model_controller else {},
"model_controller": self.model_controller.get_stats()
if self.model_controller
else {},
"connections": {
stream_id: conn.get_stats()
for stream_id, conn in self.connections.items()