Skip to content

Commit

Permalink
Merge pull request #237 from dzenanz/master
Browse files Browse the repository at this point in the history
Remove SNR and CNR from training and inference
  • Loading branch information
dzenanz authored Nov 29, 2021
2 parents 8bb7f8f + 95fe0a5 commit dd222c0
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 27 deletions.
26 changes: 26 additions & 0 deletions miqa/learning/correlator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3

import pandas as pd
from sklearn.metrics import confusion_matrix

df = pd.read_csv('M:/MIQA/data.csv') # manually converted TRUE/FALSE into 1/0
print(f'count NaN: {df.isnull().sum().sum()}')
correlation_df = df.corr()
correlation_df.to_csv('M:/MIQA/correlations2.csv')
print(correlation_df)

cm = pd.DataFrame(confusion_matrix(df['overall_qa_assessment'], df['cnr']))
cm.to_csv('M:/MIQA/oQA_CNR.csv')
print(cm)

cm = pd.DataFrame(confusion_matrix(df['overall_qa_assessment'], df['snr']))
cm.to_csv('M:/MIQA/oQA_SNR.csv')
print(cm)

# df = pd.read_csv('M:/MIQA/PredictHD_small/phenotype/bids_image_qc_information.tsv', sep='\t')
# df = df.drop(columns=['participant_id', 'session_id', 'series_number'])
# df.to_csv('M:/MIQA/dataQA.csv')
# print(f"count NaN: {df.isnull().sum().sum()}")
# correlation_df = df.corr()
# correlation_df.to_csv('M:/MIQA/correlations.csv')
# print(correlation_df)
4 changes: 2 additions & 2 deletions miqa/learning/models/miqaT1-val0.pth
Git LFS file not shown
4 changes: 2 additions & 2 deletions miqa/learning/models/miqaT1-val1.pth
Git LFS file not shown
4 changes: 2 additions & 2 deletions miqa/learning/models/miqaT1-val2.pth
Git LFS file not shown
4 changes: 2 additions & 2 deletions miqa/learning/models/miqaT1-val3.pth
Git LFS file not shown
4 changes: 2 additions & 2 deletions miqa/learning/models/miqaT1-val4.pth
Git LFS file not shown
8 changes: 2 additions & 6 deletions miqa/learning/nn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

logger = logging.getLogger(__name__)

regression_count = 3 # first 3 values are overall QA, SNR and CNR
regression_count = 1 # use QA, ignore SNR and CNR
artifacts = [
'normal_variants',
'lesions',
Expand Down Expand Up @@ -154,11 +154,7 @@ def evaluate_model(model, data_loader, device, writer, epoch, run_name):


def label_results(result):
labeled_results = {
'overall_quality': clamp(result[0] / 10.0, 0.0, 1.0),
'signal_to_noise_ratio': clamp(result[1] / 10.0, 0.0, 1.0),
'contrast_to_noise_ratio': clamp(result[2] / 10.0, 0.0, 1.0),
}
labeled_results = {'overall_quality': clamp(result[0] / 10.0, 0.0, 1.0)}
for artifact_name, value in zip(artifacts, result[regression_count:]):
labeled_results[artifact_name] = clamp(value, 0.0, 1.0)
return labeled_results
Expand Down
14 changes: 3 additions & 11 deletions miqa/learning/nn_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,8 @@ def forward(self, output, target):
qa_target = target[..., 0]
qa_loss = torch.mean((qa_out - qa_target) ** 2)

snr_out = output[..., 1]
snr_target = target[..., 1]
snr_loss = torch.mean((snr_out - snr_target) ** 2)

cnr_out = output[..., 2]
cnr_target = target[..., 2]
cnr_loss = torch.mean((cnr_out - cnr_target) ** 2)

# overall QA is more important than SNR and CNR
loss = 10 * qa_loss + snr_loss + cnr_loss
# overall QA is more important than individual artifacts
loss = 10 * qa_loss

for i in range(self.presence_count):
i_target = target[..., i + regression_count]
Expand Down Expand Up @@ -213,7 +205,7 @@ def create_train_and_test_data_loaders(df, count_train):
if exists:
images.append(row.file_path)

row_targets = [row.overall_qa_assessment, row.snr, row.cnr]
row_targets = [row.overall_qa_assessment]
for i in range(len(artifacts)):
artifact_value = row[artifact_column_indices[i]]
converted_result = convert_bool_to_int(artifact_value)
Expand Down

0 comments on commit dd222c0

Please sign in to comment.