Merge 07d6234c72
into ffad84899b
This commit is contained in:
commit
e8a142dcf8
|
@ -108,7 +108,7 @@ class SSD300(nn.Module):
|
|||
def bbox_view(self, src, loc, conf):
|
||||
ret = []
|
||||
for s, l, c in zip(src, loc, conf):
|
||||
ret.append((l(s).view(s.size(0), 4, -1), c(s).view(s.size(0), self.label_num, -1)))
|
||||
ret.append((l(s).reshape(s.size(0), 4, -1), c(s).reshape(s.size(0), self.label_num, -1)))
|
||||
|
||||
locs, confs = list(zip(*ret))
|
||||
locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()
|
||||
|
|
|
@ -46,6 +46,7 @@ def train_loop(model, loss_func, epoch, optim, train_dataloader, val_dataloader,
|
|||
bbox = bbox.view(N, M, 4)
|
||||
label = label.view(N, M)
|
||||
|
||||
img = img.to(memory_format=torch.channels_last)
|
||||
ploc, plabel = model(img)
|
||||
ploc, plabel = ploc.float(), plabel.float()
|
||||
|
||||
|
@ -116,7 +117,7 @@ def benchmark_train_loop(model, loss_func, epoch, optim, train_dataloader, val_d
|
|||
|
||||
|
||||
|
||||
|
||||
img = img.to(memory_format=torch.channels_last)
|
||||
ploc, plabel = model(img)
|
||||
ploc, plabel = ploc.float(), plabel.float()
|
||||
|
||||
|
|
Loading…
Reference in a new issue