Examples PyTorch SimcatsDataset
Imports
from simcats_datasets.loading.pytorch import SimcatsDataset
from simcats_datasets.support_functions.data_preprocessing import min_max_0_1, add_newaxis, only_two_classes
from simcats_datasets.loading.load_ground_truth import load_tct_masks, load_tct_by_dot_masks, load_idt_masks, load_ct_masks, load_ct_by_dot_masks, load_tc_region_masks, load_tc_region_minus_tct_masks, load_c_region_masks
from simcats_datasets.support_functions.pytorch_format_output import format_dict_csd_float_ground_truth_long
from matplotlib import pyplot as plt
from time import process_time
DIPlib -- a quantitative image analysis library
Version 3.5.1 (Jul 5 2024)
For more information see https://diplib.org
Configs Examples: Using Strings or Callables
Preprocessors, load_ground_truth, etc. can be supplied to the PyTorch dataset class either as strings or as Callables.
To load only part of the data, pass the IDs of the samples to load via specific_ids (as a list or range). See the examples in the following config.
# Every setting below is assigned twice on purpose: once per supported form.
# Only the last assignment of each name is in effect after this cell runs.
from simcats_datasets.loading.load_ground_truth import load_tct_masks
from simcats_datasets.support_functions.data_preprocessing import add_newaxis, min_max_0_1, only_two_classes

# specific_ids: either a range ...
specific_ids = range(100)
# ... or an explicit list of sample IDs
specific_ids = [0, 10, 90, 120]

# load_ground_truth: either the Callable itself ...
load_ground_truth = load_tct_masks
# ... or its name as a string
load_ground_truth = "load_tct_masks"

# data_preprocessors: either a list of Callables ...
data_preprocessors = [min_max_0_1, add_newaxis]
# ... or a list of their names
data_preprocessors = ["min_max_0_1", "add_newaxis"]

# ground_truth_preprocessors: either a list of Callables ...
ground_truth_preprocessors = [only_two_classes]
# ... or a list of their names
ground_truth_preprocessors = ["only_two_classes"]
Examples for Different Ground Truth Types
The implemented PyTorch dataset class supports several types of ground truth, which can be used to train machine learning models for different tasks. For further information on the different ground truth types, please have a look at simcats_datasets.loading.load_ground_truth.
# Relative path to the example HDF5 dataset used throughout this notebook.
h5_path = r"example_dataset_GaAs_v1_random_variations_v2.h5"
# IDs of the samples that are loaded and plotted in the ground truth examples below.
specific_ids = [1, 4, 8]
# No preprocessors are passed (None) for the CSD data or the ground truth.
data_preprocessors = None
ground_truth_preprocessors = None
# Using Callables
format_output = format_dict_csd_float_ground_truth_long
# set to True if the plots in this notebook should be saved as files
save_plots = False
if save_plots:
    from pathlib import Path

    # Output directory for the saved example plots; created only when saving is enabled.
    Path("./example_plots").mkdir(parents=True, exist_ok=True)
TCT Masks
# Using Callables
load_ground_truth = load_tct_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # File name contains the sample ID from specific_ids, not the loop position.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



TCT by Dot Masks
# Using Callables
load_ground_truth = load_tct_by_dot_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



IDT Masks
# Using Callables
load_ground_truth = load_idt_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



CT Masks
# Using Callables
load_ground_truth = load_ct_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



CT by Dot Masks
# Using Callables
load_ground_truth = load_ct_by_dot_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



TC Region Masks
# Using Callables
load_ground_truth = load_tc_region_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



TC Region Minus TCT Masks
# Using Callables
load_ground_truth = load_tc_region_minus_tct_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



C Region Masks
# Using Callables
load_ground_truth = load_c_region_masks
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, format_output=format_output)
# Plot every selected CSD next to its ground truth mask.
for i in range(len(ds)):
    # Fetch the sample once: every ds[i] invokes __getitem__ again
    # (which may reload and re-preprocess the data).
    sample = ds[i]
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].set_title('CSD')
    ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
    ax[1].set_title('Ground Truth')
    ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
    if save_plots:
        # Use the sample ID from specific_ids (not the loop position), as in the TCT
        # section, so files of all ground truth types refer to the same sample IDs.
        plt.savefig(f"./example_plots/example_{load_ground_truth.__name__}_{specific_ids[i]}.png", dpi=300)
    plt.show()



Examples for Preprocessors
The dataset class allows supplying preprocessors for the CSD data and for the ground truth. For example, it might make sense to standardize the input data for machine learning models, or to add a new axis to the CSD data, as it lacks the color channel axis that is typically expected. A common example where ground truth preprocessing makes sense is training a model with just two classes, for example only differentiating between charge transitions and background, without labeling the transitions individually. Please have a look at simcats_datasets.support_functions.data_preprocessing for further preprocessors.
Adding the Color-Channel-Axis
# Load sample 0 twice, without and with the "add_newaxis" data preprocessor,
# and print the resulting CSD shapes for comparison.
ds = SimcatsDataset(
    h5_path=h5_path,
    load_ground_truth="load_tct_masks",
    specific_ids=[0],
)
print(f"Original CSD shape: {ds[0]['csd'].shape}")
ds = SimcatsDataset(
    h5_path=h5_path,
    load_ground_truth="load_tct_masks",
    specific_ids=[0],
    data_preprocessors=["add_newaxis"],
)
print(f"CSD shape with preprocessor that adds a new axis: {ds[0]['csd'].shape}")
Original CSD shape: torch.Size([96, 96])
CSD shape with preprocessor that adds a new axis: torch.Size([1, 96, 96])
Restrict Ground Truth to Only Two Classes
# First configured sample with the TCT mask ground truth as loaded
# (no ground_truth_preprocessors passed).
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids[:1], load_ground_truth="load_tct_masks")
# Fetch the sample once instead of calling ds[0] for each subplot.
sample = ds[0]
fig, ax = plt.subplots(1, 2, figsize=(6, 3))
ax[0].set_title('CSD')
ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
ax[1].set_title('Ground Truth')
ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
plt.suptitle("CSD with original ground truth (TCT mask)")
plt.show()
# Same sample, with the "only_two_classes" ground truth preprocessor applied.
ds = SimcatsDataset(h5_path=h5_path, specific_ids=specific_ids[:1], ground_truth_preprocessors=["only_two_classes"], load_ground_truth="load_tct_masks")
sample = ds[0]
fig, ax = plt.subplots(1, 2, figsize=(6, 3))
ax[0].set_title('CSD')
ax[0].imshow(sample["csd"], origin="lower", interpolation="none")
ax[1].set_title('Ground Truth')
ax[1].imshow(sample["ground_truth"], origin="lower", interpolation="none", vmin=0)
plt.suptitle("CSD with ground truth (TCT mask) limited to two classes")
plt.show()


Preloading vs. Loading on Demand
SimCATS-Datasets allows preloading all data into memory. This is especially useful if the data is required multiple times, for example when training a machine learning model for multiple epochs.
Config
# Settings for the preloading vs. on-demand timing comparison below.
repetitions = 10000  # how many times __getitem__ is called per dataset
load_ground_truth = load_tct_masks  # passed as a Callable
data_preprocessors = [min_max_0_1, add_newaxis]  # passed as Callables
ground_truth_preprocessors = [only_two_classes]  # passed as a Callable
Load and Measure Time
from time import perf_counter


def _time_repeated_requests(dataset, n_requests):
    """Return the elapsed wall-clock seconds for ``n_requests`` calls of ``dataset[0]``.

    perf_counter is used rather than process_time: process_time only counts CPU
    time of this process (time spent waiting, e.g. on file I/O, is excluded) and
    can have a coarse resolution on some platforms, so it distorts a comparison
    of loading strategies.
    """
    start = perf_counter()
    for _ in range(n_requests):
        _ = dataset[0]
    return perf_counter() - start


# with preloading
ds_preload = SimcatsDataset(h5_path=h5_path, specific_ids=None, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth)
elapsed_preload = _time_repeated_requests(ds_preload, repetitions)
print(f"Time [requesting data, with preloading, {repetitions} images]: {elapsed_preload} seconds")
# without preloading
ds_no_preload = SimcatsDataset(h5_path=h5_path, specific_ids=None, data_preprocessors=data_preprocessors, ground_truth_preprocessors=ground_truth_preprocessors, load_ground_truth=load_ground_truth, preload=False)
elapsed_no_preload = _time_repeated_requests(ds_no_preload, repetitions)
print(f"Time [requesting data, without preloading, {repetitions} images]: {elapsed_no_preload} seconds")
Time [requesting data, with preloading, 10000 images]: 0.140625 seconds
Time [requesting data, without preloading, 10000 images]: 15.03125 seconds