profiling
This commit is contained in:
parent
7044b1e588
commit
c0ffa3967b
9 changed files with 354 additions and 1298 deletions
|
|
@ -16,7 +16,6 @@ import torch
|
|||
|
||||
from .model_controller import ModelController
|
||||
from .stream_decoder import StreamDecoderFactory
|
||||
from .tracking_factory import TrackingFactory
|
||||
from .model_repository import TensorRTModelRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -133,28 +132,32 @@ class StreamConnection:
|
|||
|
||||
async def _frame_poller(self):
|
||||
"""Poll frames from threaded decoder and submit to model controller"""
|
||||
last_frame_ptr = None
|
||||
last_decoder_frame_count = -1
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Poll frame from decoder (runs in thread)
|
||||
frame = self.decoder.get_latest_frame(rgb=True)
|
||||
# Get current decoder frame count (no data transfer, just counter)
|
||||
decoder_frame_count = self.decoder.get_frame_count()
|
||||
|
||||
# Check if we got a new frame (avoid reprocessing same frame)
|
||||
if frame is not None and frame.data_ptr() != last_frame_ptr:
|
||||
last_frame_ptr = frame.data_ptr()
|
||||
self.last_frame_time = time.time()
|
||||
self.frame_count += 1
|
||||
# Check if decoder has a new frame (avoid reprocessing same frame)
|
||||
if decoder_frame_count > last_decoder_frame_count:
|
||||
# Poll frame from decoder (zero-copy - stays in VRAM)
|
||||
frame = self.decoder.get_latest_frame(rgb=True)
|
||||
|
||||
# Submit to model controller for batched inference
|
||||
await self.model_controller.submit_frame(
|
||||
stream_id=self.stream_id,
|
||||
frame=frame,
|
||||
metadata={
|
||||
"frame_number": self.frame_count,
|
||||
"shape": tuple(frame.shape),
|
||||
}
|
||||
)
|
||||
if frame is not None:
|
||||
last_decoder_frame_count = decoder_frame_count
|
||||
self.last_frame_time = time.time()
|
||||
self.frame_count += 1
|
||||
|
||||
# Submit to model controller for batched inference
|
||||
await self.model_controller.submit_frame(
|
||||
stream_id=self.stream_id,
|
||||
frame=frame,
|
||||
metadata={
|
||||
"frame_number": self.frame_count,
|
||||
"shape": tuple(frame.shape),
|
||||
}
|
||||
)
|
||||
|
||||
# Check decoder status
|
||||
if not self.decoder.is_connected():
|
||||
|
|
@ -211,53 +214,37 @@ class StreamConnection:
|
|||
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
|
||||
await self.error_queue.put(e)
|
||||
|
||||
def _run_tracking_sync(self, detections):
|
||||
def _run_tracking_sync(self, detections, min_confidence=0.7):
|
||||
"""
|
||||
Run tracking synchronously (called from executor).
|
||||
|
||||
Args:
|
||||
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
|
||||
min_confidence: Minimum confidence threshold for detections
|
||||
|
||||
Returns:
|
||||
List of TrackedObject instances
|
||||
"""
|
||||
# Use the TrackingController's internal tracking with detections
|
||||
# We need to manually update tracks since we already have detections
|
||||
import torch
|
||||
# Convert tensor detections to Detection objects, filtering by confidence
|
||||
from .tracking_controller import Detection
|
||||
|
||||
with self.tracking_controller._lock:
|
||||
self.tracking_controller._frame_count += 1
|
||||
detection_list = []
|
||||
for det in detections:
|
||||
confidence = float(det[4])
|
||||
|
||||
# If no detections, just cleanup and return current tracks
|
||||
if len(detections) == 0:
|
||||
self.tracking_controller._cleanup_stale_tracks()
|
||||
return list(self.tracking_controller._tracks.values())
|
||||
# Filter by confidence threshold (prevents track accumulation)
|
||||
if confidence < min_confidence:
|
||||
continue
|
||||
|
||||
# Run IoU tracking to associate detections with existing tracks
|
||||
associations = self.tracking_controller._iou_tracking(detections)
|
||||
detection_list.append(Detection(
|
||||
bbox=det[:4].cpu().tolist(),
|
||||
confidence=confidence,
|
||||
class_id=int(det[5]) if det.shape[0] > 5 else 0,
|
||||
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
|
||||
))
|
||||
|
||||
# Update or create tracks
|
||||
for (det_idx, track_id), detection in zip(associations, detections):
|
||||
bbox = detection[:4].cpu().tolist()
|
||||
confidence = float(detection[4])
|
||||
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
|
||||
|
||||
if track_id == -1:
|
||||
# Create new track
|
||||
new_track = self.tracking_controller._create_track(
|
||||
bbox, confidence, class_id, self.tracking_controller._frame_count
|
||||
)
|
||||
self.tracking_controller._tracks[new_track.track_id] = new_track
|
||||
else:
|
||||
# Update existing track
|
||||
self.tracking_controller._tracks[track_id].update(
|
||||
bbox, confidence, self.tracking_controller._frame_count
|
||||
)
|
||||
|
||||
# Cleanup stale tracks
|
||||
self.tracking_controller._cleanup_stale_tracks()
|
||||
|
||||
return list(self.tracking_controller._tracks.values())
|
||||
# Update tracker with detections (lightweight, no model dependency!)
|
||||
return self.tracking_controller.update(detection_list)
|
||||
|
||||
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
|
||||
"""
|
||||
|
|
@ -341,7 +328,6 @@ class StreamConnectionManager:
|
|||
|
||||
# Factories
|
||||
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
|
||||
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
|
||||
self.model_repository = TensorRTModelRepository(
|
||||
gpu_id=gpu_id,
|
||||
enable_pt_conversion=enable_pt_conversion
|
||||
|
|
@ -349,7 +335,6 @@ class StreamConnectionManager:
|
|||
|
||||
# Controllers
|
||||
self.model_controller: Optional[ModelController] = None
|
||||
self.tracking_controller = None
|
||||
|
||||
# Connections
|
||||
self.connections: Dict[str, StreamConnection] = {}
|
||||
|
|
@ -454,17 +439,16 @@ class StreamConnectionManager:
|
|||
# Create decoder
|
||||
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
|
||||
|
||||
# Create dedicated tracking controller for THIS stream
|
||||
# This prevents track accumulation across multiple streams
|
||||
tracking_controller = self.tracking_factory.create_controller(
|
||||
model_repository=self.model_repository,
|
||||
model_id=self.model_id_for_tracking,
|
||||
# Create lightweight tracker (NO model_repository dependency!)
|
||||
from .tracking_controller import ObjectTracker
|
||||
tracking_controller = ObjectTracker(
|
||||
gpu_id=self.gpu_id,
|
||||
tracker_type="iou",
|
||||
max_age=30,
|
||||
min_confidence=0.5,
|
||||
iou_threshold=0.3,
|
||||
class_names=None # TODO: pass class names if available
|
||||
)
|
||||
logger.info(f"Created dedicated TrackingController for stream {stream_id}")
|
||||
logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")
|
||||
|
||||
# Create connection
|
||||
connection = StreamConnection(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue