45 lines
1.7 KiB
Python
45 lines
1.7 KiB
Python
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import torch.nn as nn
|
|
from monai.losses import DiceCELoss, DiceFocalLoss, DiceLoss, FocalLoss
|
|
|
|
|
|
class Loss(nn.Module):
    """Compound segmentation loss: Dice plus either focal or cross-entropy.

    Delegates entirely to MONAI. Both variants are constructed with
    ``softmax=True``, ``to_onehot_y=True`` and ``batch=True``; the focal
    variant additionally uses ``gamma=2.0``.

    NOTE(review): given the softmax / one-hot settings, presumably ``y_pred``
    holds raw multi-class logits and ``y_true`` integer label maps — confirm
    against callers.

    Args:
        focal: truthy selects ``DiceFocalLoss``; falsy selects ``DiceCELoss``.
    """

    def __init__(self, focal):
        super().__init__()
        # Settings common to both MONAI compound losses.
        shared_kwargs = {"softmax": True, "to_onehot_y": True, "batch": True}
        if not focal:
            self.loss = DiceCELoss(**shared_kwargs)
        else:
            self.loss = DiceFocalLoss(gamma=2.0, **shared_kwargs)

    def forward(self, y_pred, y_true):
        """Evaluate the configured MONAI loss on ``(y_pred, y_true)`` and return it."""
        criterion = self.loss
        return criterion(y_pred, y_true)
|
|
|
|
|
|
class LossBraTS(nn.Module):
    """Sum of per-region (Dice + classification) losses for BraTS-style labels.

    The integer label map ``y`` is turned into three binary region masks:
    ``wt = y > 0``, ``tc = y in {1, 3}`` and ``et = y == 3``. The names are
    presumably whole tumor / tumor core / enhancing tumor.
    NOTE(review): this assumes any original label 4 was remapped to 3 upstream;
    confirm against the preprocessing.
    Prediction channel 0 is scored against ``wt``, channel 1 against ``tc`` and
    channel 2 against ``et``. The three losses are summed.

    Args:
        focal: truthy uses MONAI ``FocalLoss(gamma=2.0, to_onehot_y=False)`` as
            the classification term; falsy uses ``nn.BCEWithLogitsLoss()``.
    """

    def __init__(self, focal):
        super().__init__()
        # Registration order (dice before ce) is kept so state_dict keys match.
        self.dice = DiceLoss(sigmoid=True, batch=True)
        if focal:
            self.ce = FocalLoss(gamma=2.0, to_onehot_y=False)
        else:
            self.ce = nn.BCEWithLogitsLoss()

    def _loss(self, p, y):
        """Dice on the boolean mask ``y`` plus the classification term on ``y.float()``."""
        dice_term = self.dice(p, y)
        return dice_term + self.ce(p, y.float())

    def forward(self, p, y):
        """Return the summed wt/tc/et region losses for predictions ``p`` and labels ``y``.

        ``p`` must have at least 3 channels on dim 1. ``y`` presumably carries a
        singleton channel dim matching ``p[:, k].unsqueeze(1)``; TODO confirm
        with callers.
        """
        masks = [
            y > 0,                # wt: any non-zero label
            (y == 1) | (y == 3),  # tc: labels 1 or 3
            y == 3,               # et: label 3 only
        ]
        # Slice every channel before computing any loss. With fewer than 3
        # channels this fails with IndexError before any loss is evaluated.
        channels = [p[:, k].unsqueeze(1) for k in range(len(masks))]
        l_wt, l_tc, l_et = (
            self._loss(pred, mask) for pred, mask in zip(channels, masks)
        )
        return l_wt + l_tc + l_et
|