add StrongSORT Tracker
parent ffc2e99678
commit b7d8b3266f
93 changed files with 20230 additions and 6 deletions
feeder/trackers/deepocsort/embedding.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import pdb
from collections import OrderedDict
import os
import pickle

import torch
import cv2
import torchvision
import numpy as np

# FastReID is used below but was never imported in this file. The import path
# here is an assumption based on the upstream Deep-OC-SORT layout; adjust it
# to wherever the FastReID wrapper lives in this repo.
from external.adaptors.fastreid_adaptor import FastReID


class EmbeddingComputer:
    def __init__(self, dataset):
        self.model = None
        self.dataset = dataset
        # (width, height) of the ReID input crop
        self.crop_size = (128, 384)
        os.makedirs("./cache/embeddings/", exist_ok=True)
        self.cache_path = "./cache/embeddings/{}_embedding.pkl"
        self.cache = {}
        self.cache_name = ""

    def load_cache(self, path):
        self.cache_name = path
        cache_path = self.cache_path.format(path)
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as fp:
                self.cache = pickle.load(fp)

    def compute_embedding(self, img, bbox, tag, is_numpy=True):
        # Tags look like "<sequence>:<frame>"; the cache is keyed per sequence.
        if self.cache_name != tag.split(":")[0]:
            self.load_cache(tag.split(":")[0])

        if tag in self.cache:
            embs = self.cache[tag]
            if embs.shape[0] != bbox.shape[0]:
                raise RuntimeError(
                    "ERROR: The number of cached embeddings doesn't match the "
                    "number of detections.\nWas the detector model changed? Delete cache if so."
                )
            return embs

        if self.model is None:
            self.initialize_model()

        # Make sure bbox is within image frame
        if is_numpy:
            h, w = img.shape[:2]
        else:
            h, w = img.shape[2:]
        results = np.round(bbox).astype(np.int32)
        results[:, 0] = results[:, 0].clip(0, w)
        results[:, 1] = results[:, 1].clip(0, h)
        results[:, 2] = results[:, 2].clip(0, w)
        results[:, 3] = results[:, 3].clip(0, h)

        # Generate all the crops
        crops = []
        for p in results:
            if is_numpy:
                crop = img[p[1] : p[3], p[0] : p[2]]
                crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                # cv2.resize expects (width, height)
                crop = cv2.resize(crop, self.crop_size, interpolation=cv2.INTER_LINEAR)
                crop = torch.as_tensor(crop.astype("float32").transpose(2, 0, 1))
                crop = crop.unsqueeze(0)
            else:
                crop = img[:, :, p[1] : p[3], p[0] : p[2]]
                # torchvision's resize expects (height, width); swap the tuple
                # so this branch matches the output size of the numpy branch.
                crop = torchvision.transforms.functional.resize(crop, (self.crop_size[1], self.crop_size[0]))

            crops.append(crop)

        crops = torch.cat(crops, dim=0)

        # Create embeddings and l2 normalize them
        with torch.no_grad():
            crops = crops.cuda()
            crops = crops.half()
            embs = self.model(crops)
            embs = torch.nn.functional.normalize(embs)
            embs = embs.cpu().numpy()

        self.cache[tag] = embs
        return embs

    def initialize_model(self):
        """
        model = torchreid.models.build_model(name="osnet_ain_x1_0", num_classes=2510, loss="softmax", pretrained=False)
        sd = torch.load("external/weights/osnet_ain_ms_d_c.pth.tar")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in sd.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        # load params
        model.load_state_dict(new_state_dict)
        model.eval()
        model.cuda()
        """
        if self.dataset == "mot17":
            path = "external/weights/mot17_sbs_S50.pth"
        elif self.dataset == "mot20":
            path = "external/weights/mot20_sbs_S50.pth"
        elif self.dataset == "dance":
            path = None
        else:
            raise RuntimeError("Need the path for a new ReID model.")

        model = FastReID(path)
        model.eval()
        model.cuda()
        model.half()
        self.model = model

    def dump_cache(self):
        if self.cache_name:
            with open(self.cache_path.format(self.cache_name), "wb") as fp:
                pickle.dump(self.cache, fp)