fix: gpu memory leaks
This commit is contained in:
parent
3a47920186
commit
593611cdb7
13 changed files with 420 additions and 166 deletions
|
|
@ -127,15 +127,17 @@ class StreamConnection:
|
|||
self.status = ConnectionStatus.DISCONNECTED
|
||||
logger.info(f"Stream {self.stream_id} stopped")
|
||||
|
||||
def _on_frame_decoded(self, frame: torch.Tensor):
|
||||
def _on_frame_decoded(self, frame_ref):
|
||||
"""
|
||||
Event handler called by decoder when a new frame is decoded.
|
||||
This is the event-driven replacement for polling.
|
||||
|
||||
Args:
|
||||
frame: RGB frame tensor on GPU (3, H, W)
|
||||
frame_ref: FrameReference object containing the RGB frame tensor
|
||||
"""
|
||||
if not self.running:
|
||||
# If not running, free the frame immediately
|
||||
frame_ref.free()
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -143,12 +145,14 @@ class StreamConnection:
|
|||
self.frame_count += 1
|
||||
|
||||
# Submit to model controller for batched inference
|
||||
# Pass the FrameReference in metadata so we can free it later
|
||||
self.model_controller.submit_frame(
|
||||
stream_id=self.stream_id,
|
||||
frame=frame,
|
||||
frame=frame_ref.rgb_tensor,
|
||||
metadata={
|
||||
"frame_number": self.frame_count,
|
||||
"shape": tuple(frame.shape),
|
||||
"shape": tuple(frame_ref.rgb_tensor.shape),
|
||||
"frame_ref": frame_ref, # Store reference for later cleanup
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -164,6 +168,8 @@ class StreamConnection:
|
|||
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
|
||||
self.error_queue.put(e)
|
||||
self.status = ConnectionStatus.ERROR
|
||||
# Free the frame on error
|
||||
frame_ref.free()
|
||||
|
||||
def _handle_inference_result(self, result: Dict[str, Any]):
|
||||
"""
|
||||
|
|
@ -173,12 +179,17 @@ class StreamConnection:
|
|||
Args:
|
||||
result: Inference result dictionary
|
||||
"""
|
||||
frame_ref = None
|
||||
try:
|
||||
# Extract detections
|
||||
detections = result["detections"]
|
||||
|
||||
# Run tracking (synchronous)
|
||||
tracked_objects = self._run_tracking_sync(detections)
|
||||
# Get FrameReference from metadata (if present)
|
||||
frame_ref = result["metadata"].get("frame_ref")
|
||||
|
||||
# Run tracking (synchronous) with frame shape for bbox scaling
|
||||
frame_shape = result["metadata"].get("shape")
|
||||
tracked_objects = self._run_tracking_sync(detections, frame_shape)
|
||||
|
||||
# Create tracking result
|
||||
tracking_result = TrackingResult(
|
||||
|
|
@ -196,13 +207,18 @@ class StreamConnection:
|
|||
except Exception as e:
|
||||
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
|
||||
self.error_queue.put(e)
|
||||
finally:
|
||||
# Free the frame reference - this is the last point in the pipeline
|
||||
if frame_ref is not None:
|
||||
frame_ref.free()
|
||||
|
||||
def _run_tracking_sync(self, detections, min_confidence=0.7):
|
||||
def _run_tracking_sync(self, detections, frame_shape=None, min_confidence=0.7):
|
||||
"""
|
||||
Run tracking synchronously (called from executor).
|
||||
|
||||
Args:
|
||||
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
|
||||
frame_shape: Original frame shape (C, H, W) for scaling bboxes
|
||||
min_confidence: Minimum confidence threshold for detections
|
||||
|
||||
Returns:
|
||||
|
|
@ -226,8 +242,8 @@ class StreamConnection:
|
|||
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
|
||||
))
|
||||
|
||||
# Update tracker with detections (lightweight, no model dependency!)
|
||||
return self.tracking_controller.update(detection_list)
|
||||
# Update tracker with detections (will scale bboxes to frame space)
|
||||
return self.tracking_controller.update(detection_list, frame_shape=frame_shape)
|
||||
|
||||
def tracking_results(self):
|
||||
"""
|
||||
|
|
@ -339,15 +355,31 @@ class StreamConnectionManager:
|
|||
"""
|
||||
Initialize the manager with a model.
|
||||
|
||||
Supports transparent loading of .pt (YOLO), .engine, and .trt files.
|
||||
For Ultralytics YOLO models (.pt), metadata is auto-detected - no manual
|
||||
input_shapes or precision needed! Non-YOLO models still require input_shapes.
|
||||
|
||||
Args:
|
||||
model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
|
||||
model_path: Path to model file (.trt, .engine, .pt, .pth)
|
||||
- .engine: Ultralytics native format (recommended)
|
||||
- .pt: Auto-converts to .engine (YOLO models only)
|
||||
- .trt: Raw TensorRT engine
|
||||
model_id: Model identifier (default: "detector")
|
||||
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
|
||||
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
|
||||
num_contexts: Number of TensorRT execution contexts (default: 4)
|
||||
pt_input_shapes: Required for PT files - dict of input shapes
|
||||
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
|
||||
pt_input_shapes: [Optional] Only required for non-YOLO PyTorch models
|
||||
YOLO models auto-detect from embedded metadata
|
||||
pt_precision: [Optional] Precision for PT conversion (auto-detected for YOLO)
|
||||
**pt_conversion_kwargs: Additional PT conversion arguments
|
||||
|
||||
Example:
|
||||
# YOLO model - no manual parameters needed:
|
||||
manager.initialize(
|
||||
model_path="model.pt", # or .engine
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess
|
||||
)
|
||||
"""
|
||||
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue