new buffer paradigm
parent fdaeb9981c
commit a519dea130

6 changed files with 341 additions and 327 deletions
@@ -42,6 +42,7 @@ class TrackingResult:
    tracked_objects: List  # List of TrackedObject from TrackingController
    detections: List  # Raw detections
    frame_shape: Tuple[int, int, int]
    frame_tensor: Optional[torch.Tensor]  # GPU tensor of the frame (C, H, W)
    metadata: Dict
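For context, a minimal sketch of how the full dataclass around these fields could look. Only the field list above comes from this diff; the decorator, imports, and the stream_id field (taken from the construction site further down) are assumptions.

# Hypothetical reconstruction; only the field list is taken from the hunk above.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

@dataclass
class TrackingResult:
    stream_id: str                          # assumed from the construction site below
    tracked_objects: List                   # TrackedObject instances from TrackingController
    detections: List                        # raw detections
    frame_shape: Tuple[int, int, int]
    frame_tensor: Optional[torch.Tensor]    # original GPU frame (C, H, W), added in this commit
    metadata: Dict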
@@ -158,6 +159,9 @@ class StreamConnection:
        # Submit to model controller for batched inference
        # Pass the FrameReference in metadata so we can free it later
        logger.debug(
            f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller"
        )
        self.model_controller.submit_frame(
            stream_id=self.stream_id,
            frame=cloned_tensor,  # Use cloned tensor, not original
@@ -167,6 +171,9 @@ class StreamConnection:
                "frame_ref": frame_ref,  # Store reference for later cleanup
            },
        )
        logger.debug(
            f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}"
        )

        # Update connection status based on decoder status
        if (
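The two hunks above wire a FrameReference through the inference queue: the frame itself is cloned before submission so the decoder's buffer slot can be recycled, while the reference rides along in metadata until the result comes back. A minimal sketch of that lifecycle, assuming a FrameReference with an rgb_tensor attribute and a free() method (only rgb_tensor is visible in this diff):

# Hypothetical sketch of the "pass the reference, free it later" pattern.
# FrameReference internals and free() are assumptions; rgb_tensor appears in the diff.
import torch

class FrameReference:
    """Handle to an original decoded frame held in the decoder's buffer."""

    def __init__(self, rgb_tensor: torch.Tensor):
        self.rgb_tensor = rgb_tensor  # original (C, H, W) GPU frame
        self.freed = False

    def free(self) -> None:
        # Release the frame so the decoder can reuse the buffer slot.
        self.rgb_tensor = None
        self.freed = True

def release_frame(result: dict) -> None:
    # Called once a result has been fully handled downstream.
    frame_ref = result["metadata"].get("frame_ref")
    if frame_ref is not None and not frame_ref.freed:
        frame_ref.free()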
@@ -211,6 +218,12 @@ class StreamConnection:
        frame_shape = result["metadata"].get("shape")
        tracked_objects = self._run_tracking_sync(detections, frame_shape)

        # Get ORIGINAL frame tensor from metadata (not the preprocessed one in result["frame"])
        # The frame in result["frame"] is preprocessed (resized, normalized)
        # We need the original frame for visualization
        frame_ref = result["metadata"].get("frame_ref")
        frame_tensor = frame_ref.rgb_tensor if frame_ref else None

        # Create tracking result
        tracking_result = TrackingResult(
            stream_id=self.stream_id,
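The comments above hinge on the distinction between result["frame"] (the model input) and the original frame kept alive by frame_ref. A hedged sketch of a typical preprocess_fn under that assumption; the target size and normalization are illustrative, not taken from this repository:

# Illustrative preprocess_fn; target size and normalization are assumptions.
import torch
import torch.nn.functional as F

def preprocess_fn(frame: torch.Tensor, size: int = 640) -> torch.Tensor:
    # frame: original (C, H, W) uint8 RGB tensor on the GPU
    x = frame.unsqueeze(0).float() / 255.0  # (1, C, H, W), scaled to [0, 1]
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return x  # resized and normalized: fine for the model, not for drawing overlays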
@@ -218,6 +231,7 @@ class StreamConnection:
            tracked_objects=tracked_objects,
            detections=detections,
            frame_shape=result["metadata"].get("shape"),
            frame_tensor=frame_tensor,  # Original frame, not preprocessed
            metadata=result["metadata"],
        )
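Downstream consumers can pull the original frame off the TrackingResult for display. A small sketch, assuming the tensor is uint8 RGB in (C, H, W) layout as the field comment states; the helper name is mine:

# Hypothetical visualization helper; function name and uint8 assumption are mine.
from typing import Optional
import numpy as np

def frame_for_display(tracking_result) -> Optional[np.ndarray]:
    t = tracking_result.frame_tensor
    if t is None:
        return None
    # (C, H, W) on the GPU -> (H, W, C) uint8 array on the CPU
    return t.detach().permute(1, 2, 0).contiguous().cpu().numpy()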
@@ -328,7 +342,7 @@ class StreamConnectionManager:
    Args:
        gpu_id: GPU device ID (default: 0)
        batch_size: Maximum batch size for inference (default: 16)
        force_timeout: Force buffer switch timeout in seconds (default: 0.05)
        max_queue_size: Maximum frames in queue before dropping (default: 100)
        poll_interval: Frame polling interval in seconds (default: 0.01)

    Example:
@@ -343,14 +357,14 @@ class StreamConnectionManager:
        self,
        gpu_id: int = 0,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        max_queue_size: int = 100,
        poll_interval: float = 0.01,
        enable_pt_conversion: bool = True,
        backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.max_queue_size = max_queue_size
        self.poll_interval = poll_interval
        self.backend = backend.lower()
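A usage sketch of the constructor as it now reads, with the new backend argument; only the parameter names and defaults come from the hunk above, and the import path is an assumption:

# Hypothetical usage; import path assumed.
from stream_connection_manager import StreamConnectionManager  # path is an assumption

manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=16,          # maximum inference batch size
    force_timeout=0.05,     # force a buffer switch after 50 ms
    max_queue_size=100,     # drop frames beyond this backlog
    poll_interval=0.01,     # poll for new frames every 10 ms
    enable_pt_conversion=True,
    backend="tensorrt",     # or "ultralytics"
)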
@@ -449,7 +463,7 @@ class StreamConnectionManager:
            inference_engine=self.inference_engine,
            model_id=model_id,
            batch_size=self.batch_size,
            force_timeout=self.force_timeout,
            max_queue_size=self.max_queue_size,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
@@ -473,7 +487,7 @@ class StreamConnectionManager:
            model_repository=self.model_repository,
            model_id=model_id,
            batch_size=self.batch_size,
            force_timeout=self.force_timeout,
            max_queue_size=self.max_queue_size,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
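The two hunks above build the controller from different model sources: inference_engine for the TensorRT path, model_repository for the Ultralytics path. A sketch of how the backend field presumably selects between them; the class names and the exact branch condition are assumptions inferred from the diff:

# Hypothetical dispatch on self.backend; class names are assumptions.
common = dict(
    model_id=model_id,
    batch_size=self.batch_size,
    force_timeout=self.force_timeout,
    max_queue_size=self.max_queue_size,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
)
if self.backend == "tensorrt":
    self.model_controller = ModelController(
        inference_engine=self.inference_engine, **common
    )
else:  # "ultralytics"
    self.model_controller = UltralyticsModelController(
        model_repository=self.model_repository, **common
    )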
@@ -656,7 +670,7 @@ class StreamConnectionManager:
            "gpu_id": self.gpu_id,
            "num_connections": len(self.connections),
            "batch_size": self.batch_size,
            "force_timeout": self.force_timeout,
            "max_queue_size": self.max_queue_size,
            "poll_interval": self.poll_interval,
        },
        "model_controller": self.model_controller.get_stats()
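A brief consumer of the stats dictionary above; only the keys visible in the hunk are assumed to exist, and the outer key grouping the manager fields is not shown in this diff:

# Hypothetical get_stats() consumer; nothing beyond the keys shown above is assumed.
stats = manager.get_stats()
print(stats["model_controller"])  # nested stats delegated to the model controller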