import numpy as np
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingUniform, Uniform
from ppdet.core.workspace import register
from ppdet.modeling.heads.centernet_head import ConvLayer
all = ['SWINFairMOTEmbeddingHead']
@register
class SWINFairMOTEmbeddingHead(nn.Layer):
shared = ['num_classes']
"""
Args:
in_channels (int): the channel number of input to SWINFairMOTEmbeddingHead.
ch_head (int): the channel of features before fed into embedding, 256 by default.
ch_emb (int): the channel of the embedding feature, 128 by default.
num_identities_dict (dict): the number of identities of each category,
support single class and multi-calss, {0: 14455} as default.
"""
def __init__(self,
# in_channels,
in_channels_list,
ch_head=256,
ch_emb=128,
num_classes=1,
num_identities_dict={0: 14455}):
super(SWINFairMOTEmbeddingHead, self).__init__()
assert num_classes >= 1
self.num_classes = num_classes
self.ch_emb = ch_emb
self.num_identities_dict = num_identities_dict
self.reid_convs = nn.LayerList()
for in_channels in in_channels_list:
self.reid_convs.append(
nn.Sequential(
ConvLayer(
in_channels, ch_head, kernel_size=3, padding=1, bias=True),
nn.ReLU(),
ConvLayer(
ch_head, ch_emb, kernel_size=1, stride=1, padding=0, bias=True)))
# self.reid = nn.Sequential(
# ConvLayer(
# in_channels, ch_head, kernel_size=3, padding=1, bias=True),
# nn.ReLU(),
# ConvLayer(
# ch_head, ch_emb, kernel_size=1, stride=1, padding=0, bias=True))
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
bound = 1 / math.sqrt(ch_emb)
bias_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
self.reid_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
if num_classes == 1:
nID = self.num_identities_dict[0] # single class
self.classifier = nn.Linear(
ch_emb, nID, weight_attr=param_attr, bias_attr=bias_attr)
# When num_identities(nID) is 1, emb_scale is set as 1
self.emb_scale = math.sqrt(2) * math.log(nID - 1) if nID > 1 else 1
else:
self.classifiers = dict()
self.emb_scale_dict = dict()
for cls_id, nID in self.num_identities_dict.items():
self.classifiers[str(cls_id)] = nn.Linear(
ch_emb, nID, weight_attr=param_attr, bias_attr=bias_attr)
# When num_identities(nID) is 1, emb_scale is set as 1
self.emb_scale_dict[str(cls_id)] = math.sqrt(2) * math.log(
nID - 1) if nID > 1 else 1
@classmethod
def from_config(cls, cfg, input_shape):
# print('###########################################')
# print(input_shape)
# print()
# if isinstance(input_shape, (list, tuple)):
# input_shape = input_shape[0]
# in_channels_list = []
# for channel in input_shape:
# in_channels_list.append(channel.channels)
# print(in_channels_list)
return {'in_channels_list': cfg['in_channels_list']}
def process_by_class(self, bboxes, embedding, bbox_inds, topk_clses):
pred_dets, pred_embs = [], []
for cls_id in range(self.num_classes):
inds_masks = topk_clses == cls_id
inds_masks = paddle.cast(inds_masks, 'float32')
pos_num = inds_masks.sum().numpy()
if pos_num == 0:
continue
cls_inds_mask = inds_masks > 0
bbox_mask = paddle.nonzero(cls_inds_mask)
cls_bboxes = paddle.gather_nd(bboxes, bbox_mask)
pred_dets.append(cls_bboxes)
cls_inds = paddle.masked_select(bbox_inds, cls_inds_mask)
cls_inds = cls_inds.unsqueeze(-1)
cls_embedding = paddle.gather_nd(embedding, cls_inds)
pred_embs.append(cls_embedding)
return paddle.concat(pred_dets), paddle.concat(pred_embs)
def forward(self,
neck_feats,
inputs,
bboxes=None,
bbox_inds=None,
topk_clses=None):
#######################################################################
reid_feats = []
for feat, reid_conv in zip(neck_feats, self.reid_convs):
reid_feat = reid_conv(feat)
reid_feats.append(reid_feat)
# reid_feats = [] # 创建一个列表来收集每个特征图的reid特征
# for neck_feat in neck_feats:
# reid_feat = self.reid(neck_feat)
# reid_feats.append(reid_feat)
# reid_feat = paddle.concat(reid_feats, axis=2) # 合并所有特征图的reid特征
target_size = max(feat.shape[2:] for feat in reid_feats)
reid_feats_upsampled = [F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False) for feat in
reid_feats]
reid_feat = sum(reid_feats_upsampled) / len(reid_feats_upsampled)
#######################################################################
# reid_feat = self.reid(neck_feat)
if self.training:
if self.num_classes == 1:
loss = self.get_loss(reid_feat, inputs)
else:
loss = self.get_mc_loss(reid_feat, inputs)
return loss
else:
assert bboxes is not None and bbox_inds is not None
reid_feat = F.normalize(reid_feat)
embedding = paddle.transpose(reid_feat, [0, 2, 3, 1])
embedding = paddle.reshape(embedding, [-1, self.ch_emb])
# embedding shape: [bs * h * w, ch_emb]
if self.num_classes == 1:
pred_dets = bboxes
pred_embs = paddle.gather(embedding, bbox_inds)
else:
pred_dets, pred_embs = self.process_by_class(
bboxes, embedding, bbox_inds, topk_clses)
return pred_dets, pred_embs
def get_loss(self, feat, inputs):
index = inputs['index']
mask = inputs['index_mask']
target = inputs['reid']
target = paddle.masked_select(target, mask > 0)
target = paddle.unsqueeze(target, 1)
feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
feat_n, feat_h, feat_w, feat_c = feat.shape
feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(feat_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
feat = paddle.gather_nd(feat, index=index)
mask = paddle.unsqueeze(mask, axis=2)
mask = paddle.expand_as(mask, feat)
mask.stop_gradient = True
feat = paddle.masked_select(feat, mask > 0)
feat = paddle.reshape(feat, shape=[-1, feat_c])
feat = F.normalize(feat)
feat = self.emb_scale * feat
logit = self.classifier(feat)
target.stop_gradient = True
loss = self.reid_loss(logit, target)
valid = (target != self.reid_loss.ignore_index)
valid.stop_gradient = True
count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
count.stop_gradient = True
if count > 0:
loss = loss / count
return loss
def get_mc_loss(self, feat, inputs):
# feat.shape = [bs, ch_emb, h, w]
assert 'cls_id_map' in inputs and 'cls_tr_ids' in inputs
index = inputs['index']
mask = inputs['index_mask']
cls_id_map = inputs['cls_id_map'] # [bs, h, w]
cls_tr_ids = inputs['cls_tr_ids'] # [bs, num_classes, h, w]
feat = paddle.transpose(feat, perm=[0, 2, 3, 1])
feat_n, feat_h, feat_w, feat_c = feat.shape
feat = paddle.reshape(feat, shape=[feat_n, -1, feat_c])
index = paddle.unsqueeze(index, 2)
batch_inds = list()
for i in range(feat_n):
batch_ind = paddle.full(
shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
batch_inds.append(batch_ind)
batch_inds = paddle.concat(batch_inds, axis=0)
index = paddle.concat(x=[batch_inds, index], axis=2)
feat = paddle.gather_nd(feat, index=index)
print('=======/home/aistudio/PaddleDetection/ppdet/modeling/reid/SWINFairMOTEmbeddingHead.py========')
print("mask shape:", mask.shape)
print("feat shape:", feat.shape)
mask = paddle.unsqueeze(mask, axis=2)
mask = paddle.expand(mask, shape=feat.shape)
# mask = paddle.expand_as(mask, feat)
mask.stop_gradient = True
feat = paddle.masked_select(feat, mask > 0)
feat = paddle.reshape(feat, shape=[-1, feat_c])
reid_losses = 0
for cls_id, id_num in self.num_identities_dict.items():
# target
cur_cls_tr_ids = paddle.reshape(
cls_tr_ids[:, cls_id, :, :], shape=[feat_n, -1]) # [bs, h*w]
cls_id_target = paddle.gather_nd(cur_cls_tr_ids, index=index)
mask = inputs['index_mask']
cls_id_target = paddle.masked_select(cls_id_target, mask > 0)
cls_id_target.stop_gradient = True
# feat
cls_id_feat = self.emb_scale_dict[str(cls_id)] * F.normalize(feat)
cls_id_pred = self.classifiers[str(cls_id)](cls_id_feat)
loss = self.reid_loss(cls_id_pred, cls_id_target)
valid = (cls_id_target != self.reid_loss.ignore_index)
valid.stop_gradient = True
count = paddle.sum((paddle.cast(valid, dtype=np.int32)))
count.stop_gradient = True
if count > 0:
loss = loss / count
reid_losses += loss
return reid_losses
发生异常: OSError
(External) CUDA error(719), unspecified launch failure.
[Hint: 'cudaErrorLaunchFailure'. An exception occurred on the device while executing a kernel. Common causes include dereferencing an invalid device pointerand accessing out of bounds shared memory. Less common cases can be system specific - more information about these cases canbe found in the system specific user guide. This leaves the process in an inconsistent state and any further CUDA work willreturn the same error. To continue using CUDA, the process must be terminated and relaunched.] (at /paddle/paddle/phi/backends/gpu/cuda/cuda_info.cc:272)
File "/home/aistudio/PaddleDetection/ppdet/modeling/reid/SWINFairMOTEmbeddingHead.py", line 223, in get_mc_loss
mask = paddle.unsqueeze(mask, axis=2)
File "/home/aistudio/PaddleDetection/ppdet/modeling/reid/SWINFairMOTEmbeddingHead.py", line 141, in forward
loss = self.get_mc_loss(reid_feat, inputs)
File "/home/aistudio/PaddleDetection/ppdet/modeling/architectures/fairmot.py", line 79, in _forward
reid_loss = self.reid(neck_feat, self.inputs)
File "/home/aistudio/PaddleDetection/ppdet/modeling/architectures/fairmot.py", line 100, in get_loss
loss = self._forward()
File "/home/aistudio/PaddleDetection/ppdet/modeling/architectures/meta_arch.py", line 60, in forward
out = self.get_loss()
File "/home/aistudio/PaddleDetection/ppdet/engine/trainer.py", line 577, in train
outputs = model(data)
File "/home/aistudio/PaddleDetection/tools/train.py", line 159, in run
trainer.train(FLAGS.eval)
File "/home/aistudio/PaddleDetection/tools/train.py", line 207, in main
run(FLAGS, cfg)
File "/home/aistudio/PaddleDetection/tools/train.py", line 211, in
main()
OSError: (External) CUDA error(719), unspecified launch failure.
[Hint: 'cudaErrorLaunchFailure'. An exception occurred on the device while executing a kernel. Common causes include dereferencing an invalid device pointerand accessing out of bounds shared memory. Less common cases can be system specific - more information about these cases canbe found in the system specific user guide. This leaves the process in an inconsistent state and any further CUDA work willreturn the same error. To continue using CUDA, the process must be terminated and relaunched.] (at /paddle/paddle/phi/backends/gpu/cuda/cuda_info.cc:272)
请提出你的问题 Please ask your question
运行是在aistudio上运行的
发生异常的代码:
import numpy as np import math import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle.nn.initializer import KaimingUniform, Uniform from ppdet.core.workspace import register from ppdet.modeling.heads.centernet_head import ConvLayer
all = ['SWINFairMOTEmbeddingHead']
@register class SWINFairMOTEmbeddingHead(nn.Layer): shared = ['num_classes'] """ Args: in_channels (int): the channel number of input to SWINFairMOTEmbeddingHead. ch_head (int): the channel of features before fed into embedding, 256 by default. ch_emb (int): the channel of the embedding feature, 128 by default. num_identities_dict (dict): the number of identities of each category, support single class and multi-calss, {0: 14455} as default. """
发生异常: OSError (External) CUDA error(719), unspecified launch failure. [Hint: 'cudaErrorLaunchFailure'. An exception occurred on the device while executing a kernel. Common causes include dereferencing an invalid device pointerand accessing out of bounds shared memory. Less common cases can be system specific - more information about these cases canbe found in the system specific user guide. This leaves the process in an inconsistent state and any further CUDA work willreturn the same error. To continue using CUDA, the process must be terminated and relaunched.] (at /paddle/paddle/phi/backends/gpu/cuda/cuda_info.cc:272) File "/home/aistudio/PaddleDetection/ppdet/modeling/reid/SWINFairMOTEmbeddingHead.py", line 223, in get_mc_loss mask = paddle.unsqueeze(mask, axis=2) File "/home/aistudio/PaddleDetection/ppdet/modeling/reid/SWINFairMOTEmbeddingHead.py", line 141, in forward loss = self.get_mc_loss(reid_feat, inputs) File "/home/aistudio/PaddleDetection/ppdet/modeling/architectures/fairmot.py", line 79, in _forward reid_loss = self.reid(neck_feat, self.inputs) File "/home/aistudio/PaddleDetection/ppdet/modeling/architectures/fairmot.py", line 100, in get_loss loss = self._forward() File "/home/aistudio/PaddleDetection/ppdet/modeling/architectures/meta_arch.py", line 60, in forward out = self.get_loss() File "/home/aistudio/PaddleDetection/ppdet/engine/trainer.py", line 577, in train outputs = model(data) File "/home/aistudio/PaddleDetection/tools/train.py", line 159, in run trainer.train(FLAGS.eval) File "/home/aistudio/PaddleDetection/tools/train.py", line 207, in main run(FLAGS, cfg) File "/home/aistudio/PaddleDetection/tools/train.py", line 211, in
main()
OSError: (External) CUDA error(719), unspecified launch failure.
[Hint: 'cudaErrorLaunchFailure'. An exception occurred on the device while executing a kernel. Common causes include dereferencing an invalid device pointerand accessing out of bounds shared memory. Less common cases can be system specific - more information about these cases canbe found in the system specific user guide. This leaves the process in an inconsistent state and any further CUDA work willreturn the same error. To continue using CUDA, the process must be terminated and relaunched.] (at /paddle/paddle/phi/backends/gpu/cuda/cuda_info.cc:272)