Giới thiệu
Table of Content
- Giới thiệu
- Import & configure
- 0. Initial stage (Create meta file)
- I. First stage: Depth inference
- II. Second Stage (xy inference)
- III. Third Stage (calc. location of axial t2)
Import & configure
!pip install kaggle
from google.colab import files
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
# Then move kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle/ && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json
from google.colab import drive
drive.mount('./gdrive')!kaggle datasets download -d brendanartley/lumbar-coordinate-pretraining-dataset
!unzip -qq "/content/lumbar-coordinate-pretraining-dataset.zip"!pip install -q pytorch-lightning & pip install -q -U albumentations & pip install -q iterative-stratification
!pip install -q timm & pip install -q einops & pip install -q pytorch-lightning wandb & pip install torch-ema
!git clone https://github.com/mlpc-ucsd/CoaT
!pip install -q pydicomimport gc
import wandb
from pytorch_lightning.loggers import WandbLogger
import os
import yaml
import sys
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from glob import glob
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW, Adam
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
import torchvision.transforms as T
import albumentations as A
import pandas.api.types
import sklearn.metrics
import timm
import scipy
import albumentations as A
from torchvision.transforms import v2
from torchvision import models
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from torch.utils.data import default_collate
import pydicom as dcm
import transformersSEED = 126 # friend's birthday
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True # Fix the network according to random seed
print('Finish seeding with seed {}'.format(seed))
seed_everything(SEED)
print('Training on device {}'.format(device))time
if config['debug']:
\#meta_df = pd.read_parquet('/kaggle/input/rsna-newmeta/meta.parquet')
meta_df_list = []
meta_df_list = Parallel(n_jobs=-1)([delayed(create_dcm_df)(row.study_id, row.series_id, row.series_description) for _, row in series.iterrows()])
\#for _, row in tqdm(test_series.iterrows(), total=len(test_series)):
# meta_df_list.append(create_dcm_df(row.study_id, row.series_id, row.series_description))
meta_df = pd.concat(meta_df_list)
del meta_df_list
gc.collect()
meta_df.to_parquet('meta.parquet')
else:
# spend about 20 min for train data
meta_df_list = []
meta_df_list = Parallel(n_jobs=-1)([delayed(create_dcm_df)(row.study_id, row.series_id, row.series_description) for _, row in series.iterrows()])
\#for _, row in tqdm(test_series.iterrows(), total=len(test_series)):
# meta_df_list.append(create_dcm_df(row.study_id, row.series_id, row.series_description))
meta_df = pd.concat(meta_df_list)
del meta_df_list
gc.collect()
meta_df.to_parquet('meta.parquet')CPU times: user 290 ms, sys: 99.4 ms, total: 390 ms
Wall time: 3.09 s
`time prefix = ” import warnings warnings.filterwarnings(“ignore”) depth_predict = {‘scs’: { ‘L1/L2’:[], ‘L2/L3’: [], ‘L3/L4’: [], ‘L4/L5’: [], ‘L5/S1’: [] }, ‘nfn’: { ‘left_L1/L2’: [], ‘left_L2/L3’: [], ‘left_L3/L4’: [], ‘left_L4/L5’: [], ‘left_L5/S1’: [], ‘right_L1/L2’: [], ‘right_L2/L3’: [], ‘right_L3/L4’: [], ‘right_L4/L5’: [], ‘right_L5/S1’: [], } } model_path_dict = { ‘scs’: [ ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_4.ckpt’, ], ‘nfn’: [ ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_4.ckpt’, ] } ##############DEPTH DETECT######################### for condition in [‘nfn’, ‘scs’]: print(condition) model_path_list = model_path_dict[condition] for model_path in model_path_list: _meta_df = meta_df.copy() #_series = series.copy() dataset_test = DepthDetectDataset(_meta_df, condition, ‘sub’) data_loader_test = DataLoader( dataset_test, batch_size=config[“test_bs”], shuffle=False, num_workers=4, pin_memory=False ) model_name = model_path.split(’/’)[-1] if ‘1024’ in model_name: widths = [128, 256, 512, 1024] else: widths = [64, 128, 256, 512] if ‘l1’ in model_name: model_type = ‘regression’ else: model_type = ‘classification’ model = DepthDetectModule.load_from_checkpoint(model_path, condition=condition, widths=widths, model_type=model_type) model.eval() model.zero_grad() model.to(device) pred_temp = {} for k in depth_predict[condition].keys(): pred_temp[k] = [] study_id_list = [] with torch.no_grad(): for data in tqdm(data_loader_test, total=len(data_loader_test)): images, study_id = data images = images.to(device) preds = model.forward(images) #print(preds) if model_type == ‘regression’: for k, v in preds.items(): pred_temp[k].append((v[:, -1]*32).to(‘cpu’).detach().numpy()) else: for k, v in preds.items(): pred_temp[k].append(torch.argmax(v, dim=1).to(‘cpu’).detach().numpy()) study_id_list.append(study_id.to(‘cpu’).reshape(-1).detach().numpy()) del images, study_id, preds gc.collect() for k, v in pred_temp.items(): depth_predict[condition][k].append(np.concatenate(v)) study_id = np.concatenate(study_id_list) del pred_temp, study_id_list gc.collect()
for k, v in depth_predict[condition].items():
depth_predict[condition][k] = np.median(np.array(depth_predict[condition][k]), axis=0)
depth_predict[condition]['study_id'] = study_id
del study_id
gc.collect()
**Pipeline inference và ensemble** cho mô hình **Depth Detection**, với các mô hình và checkpiont khác nhau trên cả hai condition: `'scs'` và `'nfn'`.
Ensemble là **kỹ thuật kết hợp nhiều mô hình** lại với nhau để tạo ra **kết quả cuối cùng tốt hơn, ổn định hơn** so với chỉ dùng một mô hình duy nhất.
Tổng thể:
- **Duyệt qua từng loại mô hình (**`**scs**`**,** `**nfn**`**),** load từng checkpoint mô hình, chạy dự đoán trên toàn bộ test set.
- **Gộp (ensemble)** kết quả dự đoán từ nhiều mô hình khác nhau (nhiều checkpoint, nhiều fold) → lấy giá trị trung vị (median) của từng đốt sống cho mỗi study (patient).
- **Tối ưu cho đánh giá chính xác nhất trên leaderboard hoặc trong thực tế.**
**1. Khởi tạo dictionary lưu kết quả dự đoán**
```Python
depth_predict = {
'scs': {'L1/L2':[], ...},
'nfn': {'left_L1/L2':[], ..., 'right_L5/S1':[]}
}
- Mỗi key sẽ lưu lại tất cả dự đoán từng mô hình, từng batch, từng study.
- Giúp gộp lại cho bước ensemble.
2.
**model_path_dict** - Dict chứa danh sách các checkpoint mô hình đã train cho từng condition (
scs,nfn). - Các checkpoint này có thể là nhiều fold, nhiều config, nhiều loại (classification/regression).
3. Vòng lặp chính qua 2 condition:
for condition in ['nfn', 'scs']: ….→ Chạy lần lượt cho từng loại bệnh hoặc label khác nhau. 4. Lặp qua từng checkpoint mô hình Với mỗi mô hình đã train (checkpoint khác nhau), sẽ chạy dự đoán full test set. 5. Chuẩn bị dữ liệu test, xác định backbone và loại task:
dataset_test = DepthDetectDataset(_meta_df, condition, 'sub')
data_loader_test = DataLoader(dataset_test, batch_size=config["test_bs"], ...)
if '1024' in model_name:
widths = [128, 256, 512, 1024]
else:
widths = [64, 128, 256, 512]
if 'l1' in model_name:
model_type = 'regression'
else:
model_type = 'classification'- Chọn cấu trúc backbone phù hợp với mô hình (dựa vào tên file checkpoint).
- Dùng đúng loại regression/classification tương ứng. 6. Load model từ checkpoint
model = DepthDetectModule.load_from_checkpoint(
model_path, condition=condition, widths=widths, model_type=model_type
)
model.eval()
model.zero_grad()
model.to(device)- Load mô hình đã huấn luyện.
- Đưa sang chế độ eval, chuyển lên GPU/CPU tùy cấu hình. 7. Vòng lặp inference từng batch
with torch.no_grad():
for data in tqdm(data_loader_test, total=len(data_loader_test)):
images, study_id = data
images = images.to(device)
preds = model.forward(images)
if model_type == 'regression':
for k, v in preds.items():
pred_temp[k].append((v[:, -1]*32).to('cpu').detach().numpy())
else:
for k, v in preds.items():
pred_temp[k].append(torch.argmax(v, dim=1).to('cpu').detach().numpy())
study_id_list.append(study_id.to('cpu').reshape(-1).detach().numpy())- Nếu là regression, lấy giá trị cuối cùng trên trục class, nhân lại với depth chuẩn (
32), convert về numpy. - Nếu classification, lấy index lớn nhất (class dự đoán).
- Lưu lại dự đoán và study_id từng batch. 8. Gộp kết quả từng model
for k, v in pred_temp.items():
depth_predict[condition][k].append(np.concatenate(v))
study_id = np.concatenate(study_id_list)- Sau mỗi mô hình, gộp kết quả các batch thành một mảng lớn cho mỗi đốt sống. 9. Ensemble kết quả nhiều model/checkpoint
for k, v in depth_predict[condition].items():
depth_predict[condition][k] = np.median(np.array(depth_predict[condition][k]), axis=0)
depth_predict[condition]['study_id'] = study_id- Sau khi chạy hết tất cả checkpoint, tính giá trị trung vị (median) trên trục model cho mỗi đốt sống, mỗi study.
- Kết quả là dự đoán đã được ensemble tối ưu.
10. Dọn bộ nhớ liên tục với
**gc.collect()**
6. Create label coordinate & align
def create_label_ins(study_id, depth, level, condition, desc):
coor_dict = {'study_id': [], 'series_id': [], 'instance_number': []}
_meta = meta_df.loc[meta_df.series_description==desc]
for s, d in zip(study_id, depth):
sub_meta = _meta.loc[_meta.study_id==s]
sub_meta = sub_meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
if len(sub_meta) > 32:
d = (d/32)*len(sub_meta)
try:
row = sub_meta.iloc[round(d)]
except:
if condition == 'Spinal Canal Stenosis':
row = sub_meta.iloc[int(len(sub_meta)//2)]
elif condition == 'Left Neural Foraminal Narrowing':
row = sub_meta.iloc[int(2*(len(sub_meta)//3))]
elif condition == 'Right Neural Foraminal Narrowing':
row = sub_meta.iloc[int(len(sub_meta)//3)]
print(s)
coor_dict['study_id'].append(s)
coor_dict['series_id'].append(row.series_id)
coor_dict['instance_number'].append(row.instance_number)
coor_dict['condition'] = condition
coor_dict['level'] = level.split('_')[-1]
return pd.DataFrame(coor_dict)Từ một list study_id và dự đoán depth (tọa độ slice/lát), tìm lại thông tin DICOM slice tương ứng (series_id, instance_number, …) cho từng case.
Dùng để:
-
join với label thật trong evaluate
-
submit lên hệ thống chấm điểm tự động
-
hiển thị hình ảnh lát cắt đúng trên giao diện
scs_study_id = depth_predict['scs']['study_id']
scs_coor_list = []
for k, v in depth_predict['scs'].items():
if k != 'study_id':
scs_coor_list.append(create_label_ins(scs_study_id, v, k, 'Spinal Canal Stenosis', 'Sagittal T2/STIR'))
nfn_study_id = depth_predict['nfn']['study_id']
nfn_coor_list = []
for k, v in depth_predict['nfn'].items():
if k != 'study_id':
if k.split('_')[0] == 'left':
condition = 'Left Neural Foraminal Narrowing'
else:
condition = 'Right Neural Foraminal Narrowing'
nfn_coor_list.append(create_label_ins(nfn_study_id, v, k, condition, 'Sagittal T1'))
scs_coor = pd.concat(scs_coor_list)
nfn_coor = pd.concat(nfn_coor_list)
pred_coor = pd.concat([scs_coor, nfn_coor]).sort_values(['study_id', 'series_id', 'level'])- Lấy kết quả dự đoán từ hai nhóm bệnh: SCS (Spinal Canal Stenosis) và NFN (Neural Foraminal Narrowing).
- Duyệt qua từng vị trí (level) của từng nhóm, sử dụng hàm
create_label_insđể tìm lại thông tin lát ảnh thật trong DICOM ứng với dự đoán. - Gộp toàn bộ kết quả lại thành một bảng duy nhất, sắp xếp chuẩn bị cho các bước xử lý tiếp theo.
del scs_coor, nfn_coor, scs_coor_list, nfn_coor_list, depth_predict
gc.collect()
pred_coor.head()
pred_coor.to_csv('stage1_coor.csv', index=False)- Giải phóng RAM, tránh chiếm dụng không cần thiết (rất quan trọng khi làm việc với dữ liệu y tế lớn).
gc.collect(): ép Python thực hiện “garbage collection” ngay.- Hiển thị 5 dòng đầu của DataFrame
pred_coorđể kiểm tra lại dữ liệu trước khi lưu (optional, hay dùng trong notebook). - Ghi toàn bộ DataFrame
pred_coorra file CSV (không kèm chỉ số dòng)
II. Second Stage (xy inference)
Infer xy-coordinate of locations of sagittal t1 & t2
- Mục tiêu: Với mỗi lát ảnh sagittal (T1 hoặc T2), mô hình dự đoán tọa độ (x, y) của vị trí cần xác định (ví dụ: vị trí tổn thương, vị trí giữa hai đốt sống,…). Ensemble or align (rule base)
- Ensemble: Kết hợp nhiều mô hình hoặc nhiều phép thử lại để lấy kết quả ổn định hơn (trung bình, trung vị, voting…).
- Align (rule base): Có thể dùng các quy tắc (rule-based) để điều chỉnh lại vị trí, ví dụ:
- Nếu T1 và T2 khác biệt nhiều về tọa độ, chọn điểm gần trung tâm nhất.
- Nếu kết quả nằm ngoài vùng hợp lệ của ảnh, ép vào vùng hợp lệ.
- Nếu giữa hai lát (slice) dự đoán gần nhau, có thể lấy trung bình, hoặc chọn điểm đã có trong ground-truth.
1. Coordinate prediction dataset
Khởi tạo
class CoorDetectDataset(Dataset):
def __init__(self, coor, meta, condition, usage='train'):
if condition == 'scs':
coor = coor.loc[coor.condition=='Spinal Canal Stenosis']
elif condition == 'ss':
coor = coor.loc[(coor.condition=='Left Subarticular Stenosis') | (coor.condition=='Right Subarticular Stenosis')]
elif condition == 'nfn':
coor = coor.loc[(coor.condition=='Right Neural Foraminal Narrowing') | (coor.condition=='Left Neural Foraminal Narrowing')]
\#g_coor = coor.groupby('study_id').count()
\#if condition == 'scs':
# self.id = g_coor.loc[g_coor.series_id==5].reset_index().study_id.unique()
\#else:
# self.id = g_coor.loc[g_coor.series_id==10].reset_index().study_id.unique()
self.id = coor.study_id.unique()
self.coor = coor
self.meta = meta
self.condition = condition
self.usage = usage
if 3637444890 in self.id:
self.id.remove(3637444890)
\#self.id = [2773343225]
\#self.id = [1782095928]
self.resize = v2.Resize((384, 384))
def __getitem__(self, index):
study_id = self.id[index]
\#print(study_id)
\#try:
if self.condition == 'scs':
volume = self.for_scs(study_id)
elif self.condition == 'nfn':
volume = self.for_nfn(study_id)
if self.condition == 'ss':
volume = self.for_ss(study_id)
return volume, torch.tensor(study_id)Mục tiêu chính:
- Chọn ra các study phù hợp theo từng loại bệnh học (
**condition**). - Chuẩn bị sample để huấn luyện hoặc inference: với mỗi
study_id, trả về một volume ảnh đã tiền xử lý sẵn, cùng label hoặc thông tin kèm theo. - Dễ dàng mở rộng/biến đổi theo từng loại task (
scs,ss,nfn).
1. Lọc dữ liệu theo condition
- Lấy đúng các dòng label trong DataFrame
coorthuộc nhóm bệnh cần train/infer. 2. Lấy danh sách study hợp lệself.id = coor.study_id.unique() - Tạo mảng gồm tất cả study_id không trùng lặp cho bài toán. 3. Loại bỏ study lỗi hoặc không hợp lệ
- Tránh lỗi khi train/infer (study này thiếu hoặc hỏng file). 4. Lưu các thuộc tính cần thiết
self.coor: label filteredself.meta: metadata của DICOM (để truy xuất file, series, vị trí slice)self.condition: conditionself.usage: usageself.resize: transform resize ảnh về kích thước chuẩn
getitem
- Dựa vào
condition, gọi đúng hàm load volume (ảnh 3D) đã tiền xử lý:for_scs(study_id)cho Spinal Canal Stenosisfor_nfn(study_id)cho Neural Foraminal Narrowingfor_ss(study_id)cho Subarticular Stenosis
- Trả về:
volume: tensor ảnh 3D, đã chuẩn hóa kích thước (thường là[D, H, W])study_id: tensor chứa id bệnh nhân, dùng để join kết quả hoặc debug.
Func for_scs
def for_scs(self, study_id):
meta = self.meta.loc[(self.meta.study_id==study_id) & (self.meta.series_description=='Sagittal T2/STIR')]
meta = meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
\#img = [self.normalize(self.load_dicom(f'/content/train_images/{row.study_id}/{row.series_id}/{row.instance_number}.dcm')) for _, row in meta.iterrows()]
coor = self.coor.loc[(self.coor.study_id==study_id) & (self.coor.condition=='Spinal Canal Stenosis')]
meta_list = []
for _, row in coor.iterrows():
series_id, instance_number = row.series_id, row.instance_number
meta_list.append(meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)])
sub_meta = pd.concat(meta_list)
idx = meta.loc[meta.ipp_x == sub_meta.ipp_x.median()].index[0]
\#print(old_idx)
img_row = meta.iloc[idx]
before_img_row = meta.iloc[idx-1]
after_img_row = meta.iloc[idx+1]
img = self.normalize(self.load_dicom(IMAGE_PATH + f'{img_row.study_id}/{img_row.series_id}/{img_row.instance_number}.dcm'))
bimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{before_img_row.study_id}/{before_img_row.series_id}/{before_img_row.instance_number}.dcm'))
aimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{after_img_row.study_id}/{after_img_row.series_id}/{after_img_row.instance_number}.dcm'))
img = self.resize(torch.tensor(img[None, ...]))
bimg = self.resize(torch.tensor(bimg[None, ...]))
aimg = self.resize(torch.tensor(aimg[None, ...]))
img = torch.cat([bimg, img, aimg]).to(torch.float32)
return img1. Lọc metadata của study SCS với đúng series
meta = self.meta.loc[
(self.meta.study_id==study_id) &
(self.meta.series_description=='Sagittal T2/STIR')
]
meta = meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)- Lấy tất cả lát ảnh thuộc đúng study và đúng loại MRI (
Sagittal T2/STIR). - Sắp xếp theo trục không gian (ipp_x) để đảm bảo đúng thứ tự lát.
- Tìm vị trí các lát liên quan đến tổn thương
coor = self.coor.loc[
(self.coor.study_id==study_id) &
(self.coor.condition=='Spinal Canal Stenosis')
]
meta_list = []
for _, row in coor.iterrows():
series_id, instance_number = row.series_id, row.instance_number
meta_list.append(meta.loc[
(meta.series_id==series_id) & (meta.instance_number==instance_number)
])
sub_meta = pd.concat(meta_list)- Tìm trong bảng label (self.coor) tất cả các instance của study này với condition SCS.
- Với từng label, tìm lại metadata của lát ảnh đó.
- Ghép lại thành bảng con
sub_meta– tập hợp các lát ảnh liên quan.
- Tìm slice trung tâm nhất trong vùng tổn thương
idx = meta.loc[meta.ipp_x == sub_meta.ipp_x.median()].index[0]- Chọn tọa độ ipp_x trung vị trong tập các lát liên quan (giữa vùng tổn thương) →
idxlà chỉ số lát trung tâm.
- Lấy 3 lát ảnh: trung tâm, trước và sau
img_row = meta.iloc[idx]
before_img_row = meta.iloc[idx-1]
after_img_row = meta.iloc[idx+1]img_row: lát trung tâmbefore_img_row: lát trước (idx-1)after_img_row: lát sau (idx+1)
- Đọc và chuẩn hóa 3 lát ảnh
img = self.normalize(self.load_dicom(IMAGE_PATH + f'{img_row.study_id}/{img_row.series_id}/{img_row.instance_number}.dcm'))
bimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{before_img_row.study_id}/{before_img_row.series_id}/{before_img_row.instance_number}.dcm'))
aimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{after_img_row.study_id}/{after_img_row.series_id}/{after_img_row.instance_number}.dcm'))- Đọc từng file DICOM, chuẩn hóa giá trị pixel (theo hàm bạn định nghĩa).
- Resize các lát về kích thước chuẩn
img = self.resize(torch.tensor(img[None, ...]))
bimg = self.resize(torch.tensor(bimg[None, ...]))
aimg = self.resize(torch.tensor(aimg[None, ...]))- Đưa ảnh từ numpy → tensor, thêm chiều channel
[1, H, W], resize về[1, 384, 384]
- Ghép 3 lát thành một tensor đầu vào
img = torch.cat([bimg, img, aimg]).to(torch.float32)- Output cuối: tensor shape
[3, 384, 384](3 channel: lát trước, trung tâm, sau). - Dùng làm input cho mô hình detect xy (mô hình sẽ nhìn cận cảnh hơn, tăng độ chính xác).
- Trả về ảnh
return img
def for_ss(self, study_id):
meta = self.meta.loc[(self.meta.study_id==study_id) & (self.meta.series_description=='Axial T2')]
meta = meta.sort_values('ipp_z', ascending=False).reset_index(drop=True)
img = [self.normalize(self.load_dicom(f'/content/train_images/{row.study_id}/{row.series_id}/{row.instance_number}.dcm')) for _, row in meta.iterrows()]
coor = self.coor.loc[(self.coor.study_id==study_id)]
coor_dict = {}
for _, row in coor.iterrows():
series_id, instance_number = row.series_id, row.instance_number
target_row = meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)]
idx = target_row.index[0]
\#print(row.level, idx, idx/len(img))
\#plt.title(row.level)
\#plt.imshow(img[idx])
\#mask = torch.zeros(img[idx].shape)
\#mask[int(row.y)-10:int((row.y))+10, int(row.x)-10:int((row.x))+10] = 1
\#plt.imshow(mask, alpha=0.5)
\#plt.show()
height, width = img[idx].shape
z = idx/depth if len(img) < depth else idx/len(img)
x = row.x/width
y = row.y/height
if row.condition == 'Right Subarticular Stenosis':
coor_dict['right_' + row.level] = torch.tensor([x, y, z]).to(torch.float32)
else:
coor_dict['left_' + row.level] = torch.tensor([x, y, z]).to(torch.float32)
volume = torch.cat([self.resize(torch.tensor(i)[None, ...]).to(torch.float32) for i in img]).contiguous()
if volume.shape[0] < depth:
volume = torch.cat([volume, torch.zeros(depth-volume.shape[0], volume.shape[1], volume.shape[2])])
elif volume.shape[0] > depth:
volume = torch.nn.functional.interpolate(volume[None, None, ...], (depth, volume.shape[1], volume.shape[2])).squeeze()
return volume, coor_dictdef for_nfn(self, study_id):
meta = self.meta.loc[(self.meta.study_id==study_id) & (self.meta.series_description=='Sagittal T1')]
meta = meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
\#img = [self.normalize(self.load_dicom(f'/content/train_images/{row.study_id}/{row.series_id}/{row.instance_number}.dcm')) for _, row in meta.iterrows()]
coor = self.coor.loc[(self.coor.study_id==study_id)]
right_meta_list = []
left_meta_list = []
for _, row in coor.iterrows():
series_id, instance_number = row.series_id, row.instance_number
if row.condition == 'Right Neural Foraminal Narrowing':
right_meta_list.append(meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)])
else:
left_meta_list.append(meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)])
right_sub_meta = pd.concat(right_meta_list)
left_sub_meta = pd.concat(left_meta_list)
ridx = meta.loc[meta.ipp_x == right_sub_meta.ipp_x.median()].index[0]
lidx = meta.loc[meta.ipp_x == left_sub_meta.ipp_x.median()].index[0]
right_img_row = meta.iloc[min(max(ridx, 0), len(meta)-1)]
\#display(right_img_row)
right_before_img_row = meta.iloc[min(max(ridx-1, 0), len(meta)-1)]
rightafter_img_row = meta.iloc[min(max(ridx+1, 0), len(meta)-1)]
left_img_row = meta.iloc[min(max(lidx, 0), len(meta)-1)]
left_before_img_row = meta.iloc[min(max(lidx-1, 0), len(meta)-1)]
leftafter_img_row = meta.iloc[min(max(lidx+1, 0), len(meta)-1)]
rimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{right_img_row.study_id}/{right_img_row.series_id}/{right_img_row.instance_number}.dcm'))
rbimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{right_before_img_row.study_id}/{right_before_img_row.series_id}/{right_before_img_row.instance_number}.dcm'))
raimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{rightafter_img_row.study_id}/{rightafter_img_row.series_id}/{rightafter_img_row.instance_number}.dcm'))
limg = self.normalize(self.load_dicom(IMAGE_PATH + f'{left_img_row.study_id}/{left_img_row.series_id}/{left_img_row.instance_number}.dcm'))
lbimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{left_before_img_row.study_id}/{left_before_img_row.series_id}/{left_before_img_row.instance_number}.dcm'))
laimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{leftafter_img_row.study_id}/{leftafter_img_row.series_id}/{leftafter_img_row.instance_number}.dcm'))
rimg = torch.cat([self.resize(torch.tensor(i)[None, ...]).to(torch.float32) for i in [rbimg, rimg, raimg]])
limg = torch.cat([self.resize(torch.tensor(i)[None, ...]).to(torch.float32) for i in [lbimg, limg, laimg]])
img = torch.stack([limg, rimg]).to(torch.float32).contiguous()
return img def normalize(self, x):
lower, upper = np.percentile(x, (1, 99))
x = np.clip(x, lower, upper)
x = x - np.min(x)
x = x / np.max(x)
return x
def __len__(self):
return len(self.id)
def load_dicom(self, path):
dicom = dcm.read_file(path)
data = dicom.pixel_array
return data2. Coordinate prediction models
class ConvNextSCSDetect(nn.Module):
def __init__(self, encoder):
super().__init__()
\#self.size = 384
if encoder == 'convnext':
self.encoder = timm.create_model('convnext_base.fb_in22k_ft_in1k_384', in_chans=3, pretrained=False, num_classes=0)
elif encoder == 'efficientnetv2-l':
self.encoder = timm.create_model('tf_efficientnetv2_l.in21k_ft_in1k', in_chans=3, pretrained=False, num_classes=0, drop_rate=0.)
self.in_features = self.encoder.num_features
self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(1),
\#nn.LayerNorm(self.in_features)
)
self.l1 = nn.Linear(self.in_features, 2)
self.l2 = nn.Linear(self.in_features, 2)
self.l3 = nn.Linear(self.in_features, 2)
self.l4 = nn.Linear(self.in_features, 2)
self.l5 = nn.Linear(self.in_features, 2)
def forward(self, x, label=None):
\#for loc, img in x.items():
\#print(img.shape)
# img = self.encoder.forward_features(img)
# img = self.flatten(img)
# x[loc] = img
x = self.encoder.forward_features(x)
x = self.flatten(x)
l1 = self.l1(x)
l2 = self.l2(x)
l3 = self.l3(x)
l4 = self.l4(x)
l5 = self.l5(x)
return {'L1/L2': l1.sigmoid(), 'L2/L3': l2.sigmoid(), 'L3/L4': l3.sigmoid(), 'L4/L5': l4.sigmoid(), 'L5/S1': l5.sigmoid()}class ConvNextNFNDetect(nn.Module):
def __init__(self, encoder):
super().__init__()
if encoder == 'convnext':
self.encoder = timm.create_model('convnext_base.fb_in22k_ft_in1k_384', in_chans=3, pretrained=False, num_classes=0)
elif encoder == 'efficientnetv2-l':
self.encoder = timm.create_model('tf_efficientnetv2_l.in21k_ft_in1k', in_chans=3, pretrained=False, num_classes=0, drop_rate=0.)
self.in_features = self.encoder.num_features
self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(1),
\#nn.LayerNorm(self.in_features)
)
self.ll1 = nn.Linear(self.in_features, 2)
self.ll2 = nn.Linear(self.in_features, 2)
self.ll3 = nn.Linear(self.in_features, 2)
self.ll4 = nn.Linear(self.in_features, 2)
self.ll5 = nn.Linear(self.in_features, 2)
self.rl1 = nn.Linear(self.in_features, 2)
self.rl2 = nn.Linear(self.in_features, 2)
self.rl3 = nn.Linear(self.in_features, 2)
self.rl4 = nn.Linear(self.in_features, 2)
self.rl5 = nn.Linear(self.in_features, 2)
def forward(self, x, label=None):
shape = x.shape
x = x.reshape(shape[0]*shape[1], 3, shape[-2], shape[-1])
x = self.encoder.forward_features(x)
x = self.flatten(x)
x = x.reshape(shape[0], shape[1], -1)
x_left = x[:, 0, :]
x_right = x[:, 1, :]
ll1 = self.ll1(x_left)
ll2 = self.ll2(x_left)
ll3 = self.ll3(x_left)
ll4 = self.ll4(x_left)
ll5 = self.ll5(x_left)
rl1 = self.rl1(x_right)
rl2 = self.rl2(x_right)
rl3 = self.rl3(x_right)
rl4 = self.rl4(x_right)
rl5 = self.rl5(x_right)
return {'left_L1/L2': ll1.sigmoid(),'left_L2/L3': ll2.sigmoid(),'left_L3/L4': ll3.sigmoid(), 'left_L4/L5': ll4.sigmoid(), 'left_L5/S1': ll5.sigmoid(),
'right_L1/L2': rl1.sigmoid(), 'right_L2/L3': rl2.sigmoid(), 'right_L3/L4': rl3.sigmoid(), 'right_L4/L5': rl4.sigmoid(), 'right_L5/S1': rl5.sigmoid()}3. Coordinate detection lightning module
class DetectModule(pl.LightningModule):
def __init__(self, condition, encoder):
super().__init__()
self.config = condition
if condition == 'scs':
self.model = ConvNextSCSDetect(encoder)
elif condition == 'nfn':
self.model = ConvNextNFNDetect(encoder)
elif condition == 'ss':
pass
\#self.ema = ExponentialMovingAverage(self.model.parameters(), decay=0.995)
\#self.ema.to(device)
\#self.model = torch.optim.swa_utils.AveragedModel(self.model,
# multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
def forward(self, batch):
preds = self.model(batch)
return preds4. Coordinate inference
time
def create_label_coor(study_id, coor_df, coor, level, condition, desc):
_meta = meta_df.loc[meta_df.series_description==desc].copy()
_coor = coor_df.loc[coor_df.condition == condition]
_coor_df = {'study_id': [], 'series_id': [], 'x': [], 'y': []}
for s, c in zip(study_id, coor):
sub_meta = _meta.loc[_meta.study_id == s]
sub_coor = _coor.loc[(_coor.study_id==s) & (_coor.level==level.split('_')[-1])].squeeze(axis=0)
\#display(sub_coor)
meta_row = sub_meta.loc[(sub_meta.instance_number==sub_coor.instance_number) & (sub_meta.series_id==sub_coor.series_id)].squeeze(axis=0)
x = round(meta_row.width * c[0])
y = round(meta_row.height * c[1])
_coor_df['study_id'].append(s)
_coor_df['series_id'].append(sub_coor.series_id)
_coor_df['x'].append(x)
_coor_df['y'].append(y)
_coor_df['level'] = level.split('_')[-1]
_coor_df['condition'] = condition
del _meta, _coor, sub_meta, sub_coor, meta_row
return pd.DataFrame(_coor_df)scs_study_id = coor_predict['scs']['study_id']
scs_coor_list = []
for k, v in coor_predict['scs'].items():
if k != 'study_id':
scs_coor_list.append(create_label_coor(scs_study_id, pred_coor, v, k, 'Spinal Canal Stenosis', 'Sagittal T2/STIR'))
nfn_study_id = coor_predict['nfn']['study_id']
nfn_coor_list = []
for k, v in coor_predict['nfn'].items():
if k != 'study_id':
if k.split('_')[0] == 'left':
condition = 'Left Neural Foraminal Narrowing'
else:
condition = 'Right Neural Foraminal Narrowing'
nfn_coor_list.append(create_label_coor(nfn_study_id, pred_coor, v, k, condition, 'Sagittal T1'))
scs_coor = pd.concat(scs_coor_list)
nfn_coor = pd.concat(nfn_coor_list)
_pred_coor = pd.concat([scs_coor, nfn_coor]).sort_values(['study_id', 'series_id', 'level'])pred_coor_stage2 = pd.merge(pred_coor, _pred_coor, on=['study_id', 'series_id', 'level', 'condition'], how='inner')display(pred_coor_stage2.head())
pred_coor_stage2.to_csv('stage2_coor.csv', index=False)| study_id | series_id | instance_number | condition | level | x | y | |
|---|---|---|---|---|---|---|---|
| 0 | 44036939 | 2828203845 | 9 | Left Neural Foraminal Narrowing | L1/L2 | 387 | 201 |
| 1 | 44036939 | 2828203845 | 18 | Right Neural Foraminal Narrowing | L1/L2 | 384 | 179 |
| 2 | 44036939 | 2828203845 | 9 | Left Neural Foraminal Narrowing | L2/L3 | 348 | 267 |
| 3 | 44036939 | 2828203845 | 18 | Right Neural Foraminal Narrowing | L2/L3 | 350 | 243 |
| 4 | 44036939 | 2828203845 | 9 | Left Neural Foraminal Narrowing | L3/L4 | 311 | 326 |
III. Third Stage (calc. location of axial t2)
- calcurate depth of axial t2 for each location roughly, using xyz-coordinate (refered to @hengck’s transformation from sagittal t2 to axial t2)
- roughly separate each locations
- infer instance number
- infer xy-coordinate
1. Calculate axial slice
# project 2d to 3d
def project_to_3d(row):
sx, sy, sz = row.ipp_x, row.ipp_y, row.ipp_z
x, y = row.x, row.y
o0, o1, o2, o3, o4, o5 = row.iop
delx, dely = row.ps_x, row.ps_y
xx = o0 * delx * x + o3 * dely * y + sx
yy = o1 * delx * x + o4 * dely * y + sy
zz = o2 * delx * x + o5 * dely * y + sz
return xx,yy,zzdef sag_to_ax(sub_coor, sub_meta):
point = sub_coor[['ipp_x', 'ipp_y', 'ipp_z']].values \#2d
level_list = sub_coor.level.tolist()
# here we project 2d to 3d
center=[]
for _, row in sub_coor.iterrows():
xx,yy,zz = project_to_3d(row)
center.append([xx,yy,zz])
center = np.array(center) \#3d
# == 2. we get closest axial slices to the CSC points =================
\#df = valid_data[0].axial_t2[0].df
orientation = np.array(sub_meta.iop.values.tolist())
position= np.array(sub_meta[['ipp_x', 'ipp_y', 'ipp_z']].values.tolist())
ox = orientation[:, :3]
oy = orientation[:, 3:]
oz = np.cross(ox,oy)
t = center.reshape(-1,1,3) - position.reshape(1,-1,3)
dis = (oz.reshape(1,-1,3) * t).sum(-1) # np.dot(point-s,oz)
dis = np.fabs(dis)
closest = dis.argmin(-1)
closest_df = sub_meta.iloc[closest]
closest_df['level'] = level_list#['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
closest_df['x'] = 0
closest_df['y'] = 0
\#closest_df = pd.concat([closest_df, closest_df])
\#closest_df['condition'] = ['Left Subarticular Stenosis']*5 + ['Right Subarticular Stenosis']*5
return closest_df[['study_id', 'series_id', 'instance_number', 'level']]# sagittal t2 => axial t2
scs_coor = pred_coor_stage2.loc[pred_coor_stage2.condition=='Spinal Canal Stenosis'].copy()
scs_coor = scs_coor.merge(meta_df, on=['study_id', 'series_id', 'instance_number'], how='left')
study_id = scs_coor.study_id.unique()
ax_meta = meta_df.loc[(meta_df.series_description=='Axial T2')]
closest_ax_list = []
for s in tqdm(study_id, total=len(study_id)):
sub_coor = scs_coor.loc[scs_coor.study_id==s]
sub_meta = ax_meta.loc[ax_meta.study_id==s]
closest_ax_list.append(sag_to_ax(sub_coor, sub_meta)) closest_ax = pd.concat(closest_ax_list)
closest_ax.head()| study_id | series_id | instance_number | level | |
|---|---|---|---|---|
| 16 | 44036939 | 3481971518 | 17 | |
| 22 | 44036939 | 3481971518 | 23 | |
| 27 | 44036939 | 3481971518 | 28 | |
| 33 | 44036939 | 3481971518 | 34 | |
| 38 | 44036939 | 3481971518 | 39 |
2. Subarticular stenosis coordinate prediction dataset
class SSDetectDataset(Dataset):
def __init__(self, ax, usage='train'):
self.ax = ax
self.id = ax.study_id.unique()
self.usage = usage
self.id = list(set(self.id) - set([3637444890]))
\#self.id = [2773343225]
\#self.id = [1782095928]
self.resize = v2.Resize((384, 384))
def __getitem__(self, index):
study_id = self.id[index]
volume = self.for_ss(study_id)
return volume, torch.tensor(study_id) def for_ss(self, study_id):
ax = self.ax.loc[self.ax.study_id==study_id]
img_dict = {}
for _, row in ax.iterrows():
series_id, instance_number = row.series_id, row.instance_number
img = self.load_dicom(IMAGE_PATH + f'{study_id}/{series_id}/{instance_number}.dcm').astype(np.float32)
img = self.resize(torch.tensor(img)[None, ...])
img = self.normalize(img)
img_dict[row.level] = img
img_list = []
for k in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']:
img_list.append(img_dict[k])
volume = torch.stack(img_list).contiguous()
return volume
def normalize(self, x):
upper = torch.quantile(x, torch.tensor([0.99]))
lower = torch.quantile(x, torch.tensor([0.01]))
x = torch.clip(x, lower, upper)
x = x - torch.min(x)
x = x / (torch.max(x)+1e-6)
return x