new buffer paradigm

Siwat Sirichai 2025-11-11 02:02:12 +07:00
parent fdaeb9981c
commit a519dea130
6 changed files with 341 additions and 327 deletions

@@ -42,6 +42,7 @@ class TrackingResult:
     tracked_objects: List  # List of TrackedObject from TrackingController
     detections: List  # Raw detections
     frame_shape: Tuple[int, int, int]
+    frame_tensor: Optional[torch.Tensor]  # GPU tensor of the frame (C, H, W)
     metadata: Dict
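
For orientation: after this hunk the TrackingResult record carries the new frame_tensor field alongside the existing ones. A minimal self-contained sketch of the resulting shape (the @dataclass decorator, the import block, and the stream_id field are inferred from the constructor call later in this diff, not shown in this hunk):

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch


@dataclass
class TrackingResult:
    stream_id: str
    tracked_objects: List                 # List of TrackedObject from TrackingController
    detections: List                      # Raw detections
    frame_shape: Tuple[int, int, int]
    frame_tensor: Optional[torch.Tensor]  # GPU tensor of the frame (C, H, W)
    metadata: Dict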
@@ -158,6 +159,9 @@ class StreamConnection:
         # Submit to model controller for batched inference
         # Pass the FrameReference in metadata so we can free it later
+        logger.debug(
+            f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller"
+        )
         self.model_controller.submit_frame(
             stream_id=self.stream_id,
             frame=cloned_tensor,  # Use cloned tensor, not original
@@ -167,6 +171,9 @@ class StreamConnection:
                 "frame_ref": frame_ref,  # Store reference for later cleanup
             },
         )
+        logger.debug(
+            f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}"
+        )

         # Update connection status based on decoder status
         if (
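
The new debug line reads len(self.model_controller.frame_queue), so the controller exposes its pending-frame queue directly, bounded by the max_queue_size parameter introduced further down ("Maximum frames in queue before dropping"). A sketch of the drop-oldest behavior that docstring implies; apart from frame_queue, submit_frame, and max_queue_size, every name here is an assumption:

import logging
from collections import deque

logger = logging.getLogger(__name__)


class ModelController:
    def __init__(self, max_queue_size: int = 100):
        self.max_queue_size = max_queue_size
        self.frame_queue = deque()  # frames awaiting batched inference

    def submit_frame(self, stream_id, frame, metadata=None):
        # Drop the oldest frame first so live streams never back up
        # behind a slow inference loop (the exact drop policy is assumed).
        if len(self.frame_queue) >= self.max_queue_size:
            dropped = self.frame_queue.popleft()
            logger.warning("[%s] Queue full, dropped a frame from %s",
                           stream_id, dropped["stream_id"])
        self.frame_queue.append(
            {"stream_id": stream_id, "frame": frame, "metadata": metadata or {}}
        )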
@@ -211,6 +218,12 @@ class StreamConnection:
             frame_shape = result["metadata"].get("shape")
             tracked_objects = self._run_tracking_sync(detections, frame_shape)

+            # Get ORIGINAL frame tensor from metadata (not the preprocessed one in result["frame"])
+            # The frame in result["frame"] is preprocessed (resized, normalized)
+            # We need the original frame for visualization
+            frame_ref = result["metadata"].get("frame_ref")
+            frame_tensor = frame_ref.rgb_tensor if frame_ref else None
+
             # Create tracking result
             tracking_result = TrackingResult(
                 stream_id=self.stream_id,
@@ -218,6 +231,7 @@ class StreamConnection:
                 tracked_objects=tracked_objects,
                 detections=detections,
                 frame_shape=result["metadata"].get("shape"),
+                frame_tensor=frame_tensor,  # Original frame, not preprocessed
                 metadata=result["metadata"],
             )
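
With the original (not resized, not normalized) frame now attached to every result, a consumer can draw directly on it. A hypothetical visualization helper; the tensor-to-image conversion and the bbox attribute on tracked objects are assumptions beyond what this diff shows:

import cv2


def visualize(result: TrackingResult):
    if result.frame_tensor is None:
        return None
    # (C, H, W) GPU tensor -> (H, W, C) numpy array on the CPU
    frame = result.frame_tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    for obj in result.tracked_objects:
        x1, y1, x2, y2 = map(int, obj.bbox)  # bbox layout is assumed
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
    return frame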
@@ -328,7 +342,7 @@ class StreamConnectionManager:
         Args:
             gpu_id: GPU device ID (default: 0)
             batch_size: Maximum batch size for inference (default: 16)
-            force_timeout: Force buffer switch timeout in seconds (default: 0.05)
+            max_queue_size: Maximum frames in queue before dropping (default: 100)
             poll_interval: Frame polling interval in seconds (default: 0.01)

         Example:
@@ -343,14 +357,14 @@ class StreamConnectionManager:
         self,
         gpu_id: int = 0,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         poll_interval: float = 0.01,
         enable_pt_conversion: bool = True,
         backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
     ):
         self.gpu_id = gpu_id
         self.batch_size = batch_size
-        self.force_timeout = force_timeout
+        self.max_queue_size = max_queue_size
         self.poll_interval = poll_interval
         self.backend = backend.lower()
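
Backpressure is now expressed as a queue bound rather than a buffer-switch timeout, so call sites swap force_timeout for max_queue_size. An illustrative construction using the defaults from this diff:

manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=16,
    max_queue_size=100,  # frames buffered before the oldest are dropped
    poll_interval=0.01,
    backend="tensorrt",
)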
@@ -449,7 +463,7 @@ class StreamConnectionManager:
             inference_engine=self.inference_engine,
             model_id=model_id,
             batch_size=self.batch_size,
-            force_timeout=self.force_timeout,
+            max_queue_size=self.max_queue_size,
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
@@ -473,7 +487,7 @@ class StreamConnectionManager:
             model_repository=self.model_repository,
             model_id=model_id,
             batch_size=self.batch_size,
-            force_timeout=self.force_timeout,
+            max_queue_size=self.max_queue_size,
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
@@ -656,7 +670,7 @@ class StreamConnectionManager:
                 "gpu_id": self.gpu_id,
                 "num_connections": len(self.connections),
                 "batch_size": self.batch_size,
-                "force_timeout": self.force_timeout,
+                "max_queue_size": self.max_queue_size,
                 "poll_interval": self.poll_interval,
             },
             "model_controller": self.model_controller.get_stats()