class SNPReductionNet(ltn.LightningModule):
def __init__(
self,
output_dim: int,
blocks_gt: List[List[int]],
snp_onehot_bits: int,
dense_layer_dims: List[int],
learning_rate: float,
regression: bool,
):
r"""A PyTorch Lightning module for SNP reduction and phenotype prediction.
"""
super().__init__()
self.save_hyperparameters()
self.output_dim = output_dim
self.learning_rate = learning_rate
self.regression = regression
self._define_metrics()
self.model = SNPReductionNetModel(
output_dim=output_dim,
blocks_gt=blocks_gt,
snp_onehot_bits=snp_onehot_bits,
dense_layer_dims=dense_layer_dims,
)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped SNPReductionNetModel on a batch of encoded SNPs."""
        return self.model(x)
def training_step(self, batch, batch_idx):
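        # `batch` is a dict keyed via `const.dkey`; the omics entry holds a list
        # of modality tensors, of which the first (assumed here to be the SNP
        # tensor) is used.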
x = batch[const.dkey.litdata_omics][0]
y = batch[const.dkey.litdata_label]
y_pred = self.forward(x)
loss = self._loss(y_pred, y, const.title_train)
return loss
def validation_step(self, batch, batch_idx):
x = batch[const.dkey.litdata_omics][0]
y = batch[const.dkey.litdata_label]
y_pred = self.forward(x)
loss = self._loss(y_pred, y, const.title_val)
return loss
def test_step(self, batch, batch_idx):
x = batch[const.dkey.litdata_omics][0]
y = batch[const.dkey.litdata_label]
y_pred = self.forward(x)
loss = self._loss(y_pred, y, const.title_test)
return loss
def predict_step(self, batch, batch_idx, dataloader_idx=None):
x = batch[const.dkey.litdata_omics][0]
y_pred = self.forward(x)
return y_pred
def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.learning_rate)
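        # Plain Adam, no LR scheduler. (Sketch, not used here: Lightning would also
        # accept a dict such as {"optimizer": optimizer, "lr_scheduler": scheduler}
        # if a scheduler were added later.)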
return optimizer
def _define_metrics(self):
"""
Define the loss function and the metrics.
"""
if self.output_dim == 1:
self.loss_fn = nn.MSELoss()
self.mae = MeanAbsoluteError()
self.r2 = R2Score()
self.pcc = PearsonCorrCoef()
else:
if self.regression:
                # Multi-label regression: keep per-element squared errors
                # (reduction='none') so _loss can average them over the batch and
                # sum across targets. Per-target MAE/R2/PCC are not tracked here.
                self.loss_fn = nn.MSELoss(reduction='none')
else:
self.loss_fn = nn.CrossEntropyLoss()
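                # Multi-class classification metrics. Averaging: "micro" pools all
                # predictions before scoring, "macro" averages per-class scores
                # equally, and "weighted" weights per-class scores by class support.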
self.recall_micro = MulticlassRecall(average="micro", num_classes=self.output_dim)
self.recall_macro = MulticlassRecall(average="macro", num_classes=self.output_dim)
self.recall_weighted = MulticlassRecall(average="weighted", num_classes=self.output_dim)
#
self.precision_micro = MulticlassPrecision(average="micro", num_classes=self.output_dim)
self.precision_macro = MulticlassPrecision(average="macro", num_classes=self.output_dim)
self.precision_weighted = MulticlassPrecision(average="weighted", num_classes=self.output_dim)
#
self.f1_micro = MulticlassF1Score(average="micro", num_classes=self.output_dim)
self.f1_macro = MulticlassF1Score(average="macro", num_classes=self.output_dim)
self.f1_weighted = MulticlassF1Score(average="weighted", num_classes=self.output_dim)
#
self.accuracy_micro = MulticlassAccuracy(average="micro", num_classes=self.output_dim)
self.accuracy_macro = MulticlassAccuracy(average="macro", num_classes=self.output_dim)
self.accuracy_weighted = MulticlassAccuracy(average="weighted", num_classes=self.output_dim)
#
self.auroc_macro = MulticlassAUROC(average="macro", num_classes=self.output_dim)
self.auroc_weighted = MulticlassAUROC(average="weighted", num_classes=self.output_dim)
def _loss(self, y_pred: torch.Tensor, y: torch.Tensor, which_step: str):
        # For classification, convert one-hot targets to class indices, as expected
        # by CrossEntropyLoss and the torchmetrics classification metrics
        # (e.g. [[0, 1, 0], [1, 0, 0]] -> [1, 0]).
        if self.output_dim > 1 and not self.regression:
            y = y.argmax(dim=-1)
if self.output_dim == 1:
loss = self.loss_fn(y_pred, y)
self.log(f"{which_step}_loss", loss, sync_dist=True)
self.log(f"{which_step}_mae", self.mae(y_pred, y), sync_dist=True)
            # PCC and R2 are undefined for a single sample, so skip them.
            if y.shape[0] < 2:
                return loss
self.log(f"{which_step}_pcc", self.pcc(y_pred, y), sync_dist=True)
self.log(f"{which_step}_r2", self.r2(y_pred, y), sync_dist=True)
else:
            if self.regression:
                # Multi-label regression: average the per-element squared errors
                # over the batch for each target, then sum across targets.
                loss = self.loss_fn(y_pred, y)  # (batch, n_targets)
                loss = loss.mean(dim=0)         # per-target MSE, (n_targets,)
                loss = loss.sum()               # scalar loss
                self.log(f"{which_step}_loss", loss, sync_dist=True)
else:
loss = self.loss_fn(y_pred, y)
self.log(f"{which_step}_loss", loss, sync_dist=True)
self.log(f"{which_step}_f1_micro", self.f1_micro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_f1_macro", self.f1_macro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_f1_weighted", self.f1_weighted(y_pred, y), sync_dist=True)
#
self.log(f"{which_step}_recall_micro", self.recall_micro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_recall_macro", self.recall_macro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_recall_weighted", self.recall_weighted(y_pred, y), sync_dist=True)
#
self.log(f"{which_step}_precision_micro", self.precision_micro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_precision_macro", self.precision_macro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_precision_weighted", self.precision_weighted(y_pred, y), sync_dist=True)
#
self.log(f"{which_step}_accuracy_micro", self.accuracy_micro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_accuracy_macro", self.accuracy_macro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_accuracy_weighted", self.accuracy_weighted(y_pred, y), sync_dist=True)
#
self.log(f"{which_step}_auroc_macro", self.auroc_macro(y_pred, y), sync_dist=True)
self.log(f"{which_step}_auroc_weighted", self.auroc_weighted(y_pred, y), sync_dist=True)
return loss
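

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes `ltn` is the Lightning
# import alias used above, and that `train_loader` / `val_loader` are
# hypothetical DataLoaders yielding dicts with the `const.dkey.litdata_omics`
# and `const.dkey.litdata_label` entries consumed by the *_step methods; the
# argument values below are placeholders, not project defaults.
#
#     model = SNPReductionNet(
#         output_dim=1,
#         blocks_gt=[[0, 1, 2], [3, 4]],
#         snp_onehot_bits=3,
#         dense_layer_dims=[128, 64],
#         learning_rate=1e-3,
#         regression=True,
#     )
#     trainer = ltn.Trainer(max_epochs=10)
#     trainer.fit(model, train_loader, val_loader)
#     predictions = trainer.predict(model, val_loader)
# ---------------------------------------------------------------------------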