Skip to content

Commit

Permalink
use myutil as the external package for some common medical image oper…
Browse files Browse the repository at this point in the history
…ations.
  • Loading branch information
Jingnan-Jia committed Oct 29, 2020
1 parent 51664ab commit 27998a7
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 142 deletions.
11 changes: 6 additions & 5 deletions .idea/segmentation_metrics.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

151 changes: 104 additions & 47 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 3 additions & 63 deletions seg_metrics/seg_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,11 @@
import PySimpleGUI as gui
import matplotlib.pyplot as plt
import glob
# %%
def load_itk(filename):
'''
:param filename: absolute file path
:return: ct, origin, spacing, all of them has coordinate (z,y,x) if filename exists. Otherwise, 3 empty list.
'''
# print('start load data')
# Reads the image using SimpleITK
if (os.path.isfile(filename)):
itkimage = sitk.ReadImage(filename)

else:
print('nonfound:', filename)
return [], [], []

# Convert the image to a numpy array first ands then shuffle the dimensions to get axis in the order z,y,x
ct_scan = sitk.GetArrayFromImage(itkimage)

# ct_scan[ct_scan>4] = 0 #filter trachea (label 5)
# Read the origin of the ct_scan, will be used to convert the coordinates from world to voxel and vice versa.
origin = np.array(list(reversed(itkimage.GetOrigin()))) # note: after reverseing, origin=(z,y,x)

# Read the spacing along each dimension
spacing = np.array(list(reversed(itkimage.GetSpacing()))) # note: after reverseing, spacing =(z,y,x)
orientation = itkimage.GetDirection()
if (orientation[-1] == -1):
ct_scan = ct_scan[::-1]

return ct_scan, origin, spacing


def get_gdth_pred_names(gdth_path, pred_path):
gdth_files = sorted(glob.glob(gdth_path + '/*' + '.nrrd'))
gdth_files.extend(sorted(glob.glob(gdth_path + '/*' + '.mhd')))
gdth_files.extend(sorted(glob.glob(gdth_path + '/*' + '.mha')))
import sys
from myutil.myutil import load_itk, get_gdth_pred_names, one_hot_encode_3d

pred_files = sorted(glob.glob(pred_path + '/*' + '.nrrd'))
pred_files.extend(sorted(glob.glob(pred_path + '/*' + '.mhd')))
pred_files.extend(sorted(glob.glob(pred_path + '/*' + '.mha')))

if len(gdth_files) == 0:
raise Exception('ground truth files are None, Please check the directories', gdth_path)
if len(pred_files) == 0:
raise Exception(' predicted files are None, Please check the directories', pred_path)

if len(pred_files) < len(gdth_files): # only predict several ct
gdth_files = gdth_files[:len(pred_files)]

return gdth_files, pred_files

# %%
def show_itk(itk, idx):
ref_surface_array = sitk.GetArrayViewFromImage(itk)
plt.figure()
Expand Down Expand Up @@ -257,20 +211,6 @@ def get_metrics_dict_all_labels(labels, gdth, pred, spacing, metrics_type):

return metrics_dict

def one_hot_encode_3D(patch, labels):

labels = np.array(labels) # i.e. [0,4,5,6,7,8]
patches = []
for i, l in enumerate(labels):
a = np.where(patch != l, 0, 1)
patches.append(a)

patches = np.array(patches)
patches = np.rollaxis(patches, 0, len(patches.shape)) # from [6, 64, 128, 128] to [64, 128, 128, 6]?

return np.float64(patches)



def write_metrics(labels, gdth_path, pred_path, csv_file, metrics=None):
"""
Expand Down
Loading

0 comments on commit 27998a7

Please sign in to comment.