Add PyTorch (.pt/.pth) model conversion support to the model repository and switch to per-stream tracking controllers to prevent cross-stream track accumulation

This commit is contained in:
Siwat Sirichai 2025-11-09 20:35:17 +07:00
parent fd470b3765
commit 2b0cfc4b72
9 changed files with 780 additions and 478 deletions

View file

@@ -332,6 +332,7 @@ class StreamConnectionManager:
batch_size: int = 16,
force_timeout: float = 0.05,
poll_interval: float = 0.01,
enable_pt_conversion: bool = True,
):
self.gpu_id = gpu_id
self.batch_size = batch_size
@@ -341,7 +342,10 @@ class StreamConnectionManager:
# Factories
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id,
enable_pt_conversion=enable_pt_conversion
)
# Controllers
self.model_controller: Optional[ModelController] = None
@@ -360,16 +364,22 @@ class StreamConnectionManager:
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
num_contexts: int = 4,
pt_input_shapes: Optional[Dict] = None,
pt_precision: Optional[Any] = None,
**pt_conversion_kwargs
):
"""
Initialize the manager with a model.
Args:
model_path: Path to TensorRT model file
model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
model_id: Model identifier (default: "detector")
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
num_contexts: Number of TensorRT execution contexts (default: 4)
pt_input_shapes: Required for PT files - dict of input shapes
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
**pt_conversion_kwargs: Additional PT conversion arguments
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
@@ -377,7 +387,14 @@ class StreamConnectionManager:
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
lambda: self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
)
logger.info(f"Loaded model {model_id} from {model_path}")
@@ -392,13 +409,10 @@ class StreamConnectionManager:
)
await self.model_controller.start()
# Create tracking controller
self.tracking_controller = self.tracking_factory.create_controller(
model_repository=self.model_repository,
model_id=model_id,
tracker_type="iou",
)
logger.info("TrackingController created")
# Don't create a shared tracking controller here
# Each stream will get its own tracking controller to avoid track accumulation
self.tracking_controller = None
self.model_id_for_tracking = model_id # Store for later use
self.initialized = True
logger.info("StreamConnectionManager initialized successfully")
@@ -440,12 +454,24 @@ class StreamConnectionManager:
# Create decoder
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
# Create dedicated tracking controller for THIS stream
# This prevents track accumulation across multiple streams
tracking_controller = self.tracking_factory.create_controller(
model_repository=self.model_repository,
model_id=self.model_id_for_tracking,
tracker_type="iou",
max_age=30,
min_confidence=0.5,
iou_threshold=0.3,
)
logger.info(f"Created dedicated TrackingController for stream {stream_id}")
# Create connection
connection = StreamConnection(
stream_id=stream_id,
decoder=decoder,
model_controller=self.model_controller,
tracking_controller=self.tracking_controller,
tracking_controller=tracking_controller,
poll_interval=self.poll_interval,
)