fix: GPU memory leaks

This commit is contained in:
Siwat Sirichai 2025-11-10 22:10:46 +07:00
parent 3a47920186
commit 593611cdb7
13 changed files with 420 additions and 166 deletions

View file

@@ -127,15 +127,17 @@ class StreamConnection:
self.status = ConnectionStatus.DISCONNECTED
logger.info(f"Stream {self.stream_id} stopped")
def _on_frame_decoded(self, frame: torch.Tensor):
def _on_frame_decoded(self, frame_ref):
"""
Event handler called by decoder when a new frame is decoded.
This is the event-driven replacement for polling.
Args:
frame: RGB frame tensor on GPU (3, H, W)
frame_ref: FrameReference object containing the RGB frame tensor
"""
if not self.running:
# If not running, free the frame immediately
frame_ref.free()
return
try:
@@ -143,12 +145,14 @@ class StreamConnection:
self.frame_count += 1
# Submit to model controller for batched inference
# Pass the FrameReference in metadata so we can free it later
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
frame=frame_ref.rgb_tensor,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
"shape": tuple(frame_ref.rgb_tensor.shape),
"frame_ref": frame_ref, # Store reference for later cleanup
}
)
@@ -164,6 +168,8 @@ class StreamConnection:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Free the frame on error
frame_ref.free()
def _handle_inference_result(self, result: Dict[str, Any]):
"""
@@ -173,12 +179,17 @@ class StreamConnection:
Args:
result: Inference result dictionary
"""
frame_ref = None
try:
# Extract detections
detections = result["detections"]
# Run tracking (synchronous)
tracked_objects = self._run_tracking_sync(detections)
# Get FrameReference from metadata (if present)
frame_ref = result["metadata"].get("frame_ref")
# Run tracking (synchronous) with frame shape for bbox scaling
frame_shape = result["metadata"].get("shape")
tracked_objects = self._run_tracking_sync(detections, frame_shape)
# Create tracking result
tracking_result = TrackingResult(
@@ -196,13 +207,18 @@ class StreamConnection:
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
self.error_queue.put(e)
finally:
# Free the frame reference - this is the last point in the pipeline
if frame_ref is not None:
frame_ref.free()
def _run_tracking_sync(self, detections, min_confidence=0.7):
def _run_tracking_sync(self, detections, frame_shape=None, min_confidence=0.7):
"""
Run tracking synchronously (called from executor).
Args:
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
frame_shape: Original frame shape (C, H, W) for scaling bboxes
min_confidence: Minimum confidence threshold for detections
Returns:
@@ -226,8 +242,8 @@ class StreamConnection:
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
))
# Update tracker with detections (lightweight, no model dependency!)
return self.tracking_controller.update(detection_list)
# Update tracker with detections (will scale bboxes to frame space)
return self.tracking_controller.update(detection_list, frame_shape=frame_shape)
def tracking_results(self):
"""
@@ -339,15 +355,31 @@ class StreamConnectionManager:
"""
Initialize the manager with a model.
Supports transparent loading of .pt (YOLO), .engine, and .trt files.
For Ultralytics YOLO models (.pt), metadata is auto-detected - no manual
input_shapes or precision needed! Non-YOLO models still require input_shapes.
Args:
model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
model_path: Path to model file (.trt, .engine, .pt, .pth)
- .engine: Ultralytics native format (recommended)
- .pt: Auto-converts to .engine (YOLO models only)
- .trt: Raw TensorRT engine
model_id: Model identifier (default: "detector")
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
num_contexts: Number of TensorRT execution contexts (default: 4)
pt_input_shapes: Required for PT files - dict of input shapes
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
pt_input_shapes: [Optional] Only required for non-YOLO PyTorch models
YOLO models auto-detect from embedded metadata
pt_precision: [Optional] Precision for PT conversion (auto-detected for YOLO)
**pt_conversion_kwargs: Additional PT conversion arguments
Example:
# YOLO model - no manual parameters needed:
manager.initialize(
model_path="model.pt", # or .engine
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess
)
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")