add models and update tracking system
This commit is contained in:
parent
fd470b3765
commit
2b0cfc4b72
9 changed files with 780 additions and 478 deletions
|
|
@ -332,6 +332,7 @@ class StreamConnectionManager:
|
|||
batch_size: int = 16,
|
||||
force_timeout: float = 0.05,
|
||||
poll_interval: float = 0.01,
|
||||
enable_pt_conversion: bool = True,
|
||||
):
|
||||
self.gpu_id = gpu_id
|
||||
self.batch_size = batch_size
|
||||
|
|
@ -341,7 +342,10 @@ class StreamConnectionManager:
|
|||
# Factories
|
||||
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
|
||||
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
|
||||
self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)
|
||||
self.model_repository = TensorRTModelRepository(
|
||||
gpu_id=gpu_id,
|
||||
enable_pt_conversion=enable_pt_conversion
|
||||
)
|
||||
|
||||
# Controllers
|
||||
self.model_controller: Optional[ModelController] = None
|
||||
|
|
@ -360,16 +364,22 @@ class StreamConnectionManager:
|
|||
preprocess_fn: Optional[Callable] = None,
|
||||
postprocess_fn: Optional[Callable] = None,
|
||||
num_contexts: int = 4,
|
||||
pt_input_shapes: Optional[Dict] = None,
|
||||
pt_precision: Optional[Any] = None,
|
||||
**pt_conversion_kwargs
|
||||
):
|
||||
"""
|
||||
Initialize the manager with a model.
|
||||
|
||||
Args:
|
||||
model_path: Path to TensorRT model file
|
||||
model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
|
||||
model_id: Model identifier (default: "detector")
|
||||
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
|
||||
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
|
||||
num_contexts: Number of TensorRT execution contexts (default: 4)
|
||||
pt_input_shapes: Required for PT files - dict of input shapes
|
||||
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
|
||||
**pt_conversion_kwargs: Additional PT conversion arguments
|
||||
"""
|
||||
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
|
||||
|
||||
|
|
@ -377,7 +387,14 @@ class StreamConnectionManager:
|
|||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
|
||||
lambda: self.model_repository.load_model(
|
||||
model_id,
|
||||
model_path,
|
||||
num_contexts=num_contexts,
|
||||
pt_input_shapes=pt_input_shapes,
|
||||
pt_precision=pt_precision,
|
||||
**pt_conversion_kwargs
|
||||
)
|
||||
)
|
||||
logger.info(f"Loaded model {model_id} from {model_path}")
|
||||
|
||||
|
|
@ -392,13 +409,10 @@ class StreamConnectionManager:
|
|||
)
|
||||
await self.model_controller.start()
|
||||
|
||||
# Create tracking controller
|
||||
self.tracking_controller = self.tracking_factory.create_controller(
|
||||
model_repository=self.model_repository,
|
||||
model_id=model_id,
|
||||
tracker_type="iou",
|
||||
)
|
||||
logger.info("TrackingController created")
|
||||
# Don't create a shared tracking controller here
|
||||
# Each stream will get its own tracking controller to avoid track accumulation
|
||||
self.tracking_controller = None
|
||||
self.model_id_for_tracking = model_id # Store for later use
|
||||
|
||||
self.initialized = True
|
||||
logger.info("StreamConnectionManager initialized successfully")
|
||||
|
|
@ -440,12 +454,24 @@ class StreamConnectionManager:
|
|||
# Create decoder
|
||||
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
|
||||
|
||||
# Create dedicated tracking controller for THIS stream
|
||||
# This prevents track accumulation across multiple streams
|
||||
tracking_controller = self.tracking_factory.create_controller(
|
||||
model_repository=self.model_repository,
|
||||
model_id=self.model_id_for_tracking,
|
||||
tracker_type="iou",
|
||||
max_age=30,
|
||||
min_confidence=0.5,
|
||||
iou_threshold=0.3,
|
||||
)
|
||||
logger.info(f"Created dedicated TrackingController for stream {stream_id}")
|
||||
|
||||
# Create connection
|
||||
connection = StreamConnection(
|
||||
stream_id=stream_id,
|
||||
decoder=decoder,
|
||||
model_controller=self.model_controller,
|
||||
tracking_controller=self.tracking_controller,
|
||||
tracking_controller=tracking_controller,
|
||||
poll_interval=self.poll_interval,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue