This commit is contained in:
cyanguwa 2021-11-09 10:59:19 +01:00 committed by GitHub
commit e8a142dcf8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions

View file

@ -108,7 +108,7 @@ class SSD300(nn.Module):
def bbox_view(self, src, loc, conf):
ret = []
for s, l, c in zip(src, loc, conf):
ret.append((l(s).view(s.size(0), 4, -1), c(s).view(s.size(0), self.label_num, -1)))
ret.append((l(s).reshape(s.size(0), 4, -1), c(s).reshape(s.size(0), self.label_num, -1)))
locs, confs = list(zip(*ret))
locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()

View file

@ -46,6 +46,7 @@ def train_loop(model, loss_func, epoch, optim, train_dataloader, val_dataloader,
bbox = bbox.view(N, M, 4)
label = label.view(N, M)
img = img.to(memory_format=torch.channels_last)
ploc, plabel = model(img)
ploc, plabel = ploc.float(), plabel.float()
@ -116,7 +117,7 @@ def benchmark_train_loop(model, loss_func, epoch, optim, train_dataloader, val_d
img = img.to(memory_format=torch.channels_last)
ploc, plabel = model(img)
ploc, plabel = ploc.float(), plabel.float()