Prerequisite: Embeddings Guide
Task Types: Binary classification, multi-class classification, survival prediction
Environment Activation
source standard_model/bin/activate
See the Quickstart Guide for environment creation and installation.
Setup
First, create dummy patient data and extract embeddings (see the Embeddings Guide for details):
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from smb_biopan_utils import process_ehr_info
# Load model
model_id = "standardmodelbio/SMB-v1-1.7B-Structure"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto"
)
model.eval()
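On a GPU you can load the model in half precision to roughly halve memory use, as the Complete Example below does:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # half precision; omit on CPU
    device_map="auto"
)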
Create Dummy Dataset
Generate a synthetic patient cohort with labels:
def create_dummy_cohort(n_patients=100):
    """Create dummy MEDS data for multiple patients with labels."""
    np.random.seed(42)
    all_data = []
    labels_binary = []      # Readmission within 30 days
    labels_multiclass = []  # Disease phenotype (0-3)
    labels_survival = []    # Overall survival in months
    conditions = ['ICD10:C34.90', 'ICD10:I25.10', 'ICD10:E11.9', 'ICD10:J44.9']
    procedures = ['CPT:71260', 'CPT:93000', 'CPT:99213', 'CPT:43239']
    medications = ['RxNorm:583214', 'RxNorm:197361', 'RxNorm:311671', 'RxNorm:83367']
    labs = ['LOINC:2160-0', 'LOINC:718-7', 'LOINC:2093-3', 'LOINC:1742-6']
    for i in range(n_patients):
        pid = f'patient_{i:03d}'
        n_events = np.random.randint(5, 12)
        # Generate random clinical events at month-end dates
        # ('M' is deprecated in favor of 'ME' in pandas >= 2.2)
        times = pd.date_range('2023-01-01', periods=n_events, freq='M')
        patient_data = pd.DataFrame({
            'subject_id': [pid] * n_events,
            'time': times,
            'code': np.random.choice(
                conditions + procedures + medications + labs,
                n_events
            ),
            'table': np.random.choice(
                ['condition', 'procedure', 'medication', 'lab'],
                n_events
            ),
            'value': [
                np.random.uniform(0.5, 2.0) if np.random.random() > 0.7 else None
                for _ in range(n_events)
            ]
        })
        all_data.append(patient_data)
        # Generate random labels (reproducible via the global seed)
        labels_binary.append(int(np.random.random() > 0.7))
        labels_multiclass.append(np.random.randint(0, 4))
        labels_survival.append(np.random.exponential(24) + 6)  # Months
    df = pd.concat(all_data, ignore_index=True)
    return df, {
        'binary': np.array(labels_binary),
        'multiclass': np.array(labels_multiclass),
        'survival': np.array(labels_survival)
    }
# Create cohort
df_cohort, labels = create_dummy_cohort(n_patients=100)
patient_ids = [f'patient_{i:03d}' for i in range(100)]
print(f"Total events: {len(df_cohort)}")
print(f"Patients: {len(patient_ids)}")
print(f"Binary labels distribution: {np.bincount(labels['binary'])}")
print(f"Multiclass labels distribution: {np.bincount(labels['multiclass'])}")
Extract Embeddings for Cohort
def extract_cohort_embeddings(df, patient_ids, model, tokenizer, end_time):
    """Extract embeddings for all patients in cohort."""
    embeddings = []
    for pid in patient_ids:
        input_text = process_ehr_info(
            df=df,
            subject_id=pid,
            end_time=end_time
        )
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(model.device)
        with torch.no_grad():
            outputs = model(
                input_ids=inputs.input_ids,
                output_hidden_states=True,
                return_dict=True
            )
        # Last-token embedding from the final hidden layer
        emb = outputs.hidden_states[-1][:, -1, :]
        embeddings.append(emb.cpu())
    return torch.cat(embeddings, dim=0)
# Extract embeddings (this may take a few minutes)
embeddings = extract_cohort_embeddings(
df=df_cohort,
patient_ids=patient_ids,
model=model,
tokenizer=tokenizer,
end_time=pd.Timestamp("2024-01-01")
)
print(f"Embeddings shape: {embeddings.shape}") # [100, hidden_dim]
Train/Test Split
from sklearn.model_selection import train_test_split
# Split indices
train_idx, test_idx = train_test_split(
    range(len(patient_ids)),
    test_size=0.2,
    random_state=42
)
X_train = embeddings[train_idx]
X_test = embeddings[test_idx]
print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")
Binary Classification
Predict 30-day hospital readmission:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
# Prepare labels
y_train = labels['binary'][train_idx]
y_test = labels['binary'][test_idx]
# Train logistic regression
clf_binary = LogisticRegression(max_iter=1000, random_state=42)
clf_binary.fit(X_train.numpy(), y_train)
# Evaluate
y_pred = clf_binary.predict(X_test.numpy())
y_prob = clf_binary.predict_proba(X_test.numpy())[:, 1]
print("Binary Classification (Readmission Prediction)")
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f"AUC-ROC: {roc_auc_score(y_test, y_prob):.3f}")
print(classification_report(y_test, y_pred, target_names=['No Readmit', 'Readmit']))
PyTorch Version
For more control, use a PyTorch linear layer:
class LinearProbe(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)
# Setup
hidden_dim = embeddings.shape[1]
probe_binary = LinearProbe(hidden_dim, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe_binary.parameters(), lr=1e-3)
# Create dataloaders
train_dataset = TensorDataset(X_train, torch.tensor(y_train))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Training loop
probe_binary.train()
for epoch in range(50):
    total_loss = 0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        logits = probe_binary(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
# Evaluate
probe_binary.eval()
with torch.no_grad():
    logits = probe_binary(X_test)
    preds = logits.argmax(dim=1)
    acc = (preds == torch.tensor(y_test)).float().mean()
print(f"Test Accuracy: {acc:.3f}")
Multi-Class Classification
Predict disease phenotype (4 classes):
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# Prepare labels
y_train_mc = labels['multiclass'][train_idx]
y_test_mc = labels['multiclass'][test_idx]
# Train
# multi_class='multinomial' is deprecated since scikit-learn 1.5;
# the default lbfgs solver already fits a multinomial loss
clf_multiclass = LogisticRegression(
    max_iter=1000,
    random_state=42
)
clf_multiclass.fit(X_train.numpy(), y_train_mc)
# Evaluate
y_pred_mc = clf_multiclass.predict(X_test.numpy())
print("Multi-Class Classification (Disease Phenotype)")
print(f"Accuracy: {accuracy_score(y_test_mc, y_pred_mc):.3f}")
print(classification_report(
    y_test_mc,
    y_pred_mc,
    labels=[0, 1, 2, 3],  # keep all four classes even if one is absent from the test split
    target_names=['Lung Cancer', 'CAD', 'Diabetes', 'COPD']
))
# Confusion matrix
cm = confusion_matrix(y_test_mc, y_pred_mc, labels=[0, 1, 2, 3])
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Lung', 'CAD', 'Diabetes', 'COPD'],
            yticklabels=['Lung', 'CAD', 'Diabetes', 'COPD'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Disease Phenotype Classification')
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
PyTorch Version
# Multi-class probe
probe_multiclass = LinearProbe(hidden_dim, 4)
optimizer = torch.optim.Adam(probe_multiclass.parameters(), lr=1e-3)
train_dataset_mc = TensorDataset(X_train, torch.tensor(y_train_mc))
train_loader_mc = DataLoader(train_dataset_mc, batch_size=16, shuffle=True)
probe_multiclass.train()
for epoch in range(50):
    for batch_x, batch_y in train_loader_mc:
        optimizer.zero_grad()
        logits = probe_multiclass(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()
# Evaluate
probe_multiclass.eval()
with torch.no_grad():
    logits = probe_multiclass(X_test)
    preds = logits.argmax(dim=1)
    acc = (preds == torch.tensor(y_test_mc)).float().mean()
print(f"Multi-class Test Accuracy: {acc:.3f}")
Overall Survival (OS) Prediction
Predict survival time using Cox proportional hazards or simple regression.
Linear Regression Approach
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
# Prepare labels (survival in months)
y_train_os = labels['survival'][train_idx]
y_test_os = labels['survival'][test_idx]
# Train ridge regression
reg_survival = Ridge(alpha=1.0)
reg_survival.fit(X_train.numpy(), y_train_os)
# Predict
y_pred_os = reg_survival.predict(X_test.numpy())
# Metrics
mse = mean_squared_error(y_test_os, y_pred_os)
mae = mean_absolute_error(y_test_os, y_pred_os)
corr, _ = pearsonr(y_test_os, y_pred_os)
print("Overall Survival Prediction")
print(f"MSE: {mse:.2f}")
print(f"MAE: {mae:.2f} months")
print(f"Pearson r: {corr:.3f}")
# Plot
plt.figure(figsize=(8, 6))
plt.scatter(y_test_os, y_pred_os, alpha=0.6)
plt.plot([0, max(y_test_os)], [0, max(y_test_os)], 'r--', label='Perfect')
plt.xlabel('Actual Survival (months)')
plt.ylabel('Predicted Survival (months)')
plt.title('Overall Survival Prediction')
plt.legend()
plt.savefig('survival_scatter.png', dpi=150, bbox_inches='tight')
Cox Proportional Hazards
For proper survival analysis with censoring:
# pip install lifelines
from lifelines import CoxPHFitter
# Create dataframe for Cox model
# Assume 20% of patients are censored
censored = np.random.random(len(train_idx)) > 0.8
cox_train = pd.DataFrame(X_train.numpy())
cox_train['duration'] = y_train_os
cox_train['event'] = ~censored  # True = event observed, False = censored
# Fit Cox model
cph = CoxPHFitter(penalizer=0.1)
cph.fit(cox_train, duration_col='duration', event_col='event')
# Get risk scores for test set
cox_test = pd.DataFrame(X_test.numpy())
risk_scores = cph.predict_partial_hazard(cox_test)
print("Cox Proportional Hazards Model")
print(f"Concordance Index: {cph.concordance_index_:.3f}")
PyTorch Survival Regression
class SurvivalProbe(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.linear(x).squeeze(-1)
probe_survival = SurvivalProbe(hidden_dim)
optimizer = torch.optim.Adam(probe_survival.parameters(), lr=1e-3)
criterion_mse = nn.MSELoss()
train_dataset_os = TensorDataset(
    X_train,
    torch.tensor(y_train_os, dtype=torch.float32)
)
train_loader_os = DataLoader(train_dataset_os, batch_size=16, shuffle=True)
probe_survival.train()
for epoch in range(100):
    for batch_x, batch_y in train_loader_os:
        optimizer.zero_grad()
        pred = probe_survival(batch_x)
        loss = criterion_mse(pred, batch_y)
        loss.backward()
        optimizer.step()
# Evaluate
probe_survival.eval()
with torch.no_grad():
    preds = probe_survival(X_test)
    mse = criterion_mse(preds, torch.tensor(y_test_os, dtype=torch.float32))
print(f"Test MSE: {mse:.2f}")
Complete Example
Full pipeline from data to evaluation:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, roc_auc_score, mean_absolute_error
from transformers import AutoModelForCausalLM, AutoTokenizer
from smb_biopan_utils import process_ehr_info
# 1. Load model
model_id = "standardmodelbio/SMB-v1-1.7B-Structure"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()
# 2. Create/load your cohort data
df_cohort, labels = create_dummy_cohort(n_patients=100)
patient_ids = [f'patient_{i:03d}' for i in range(100)]
# 3. Extract embeddings
embeddings = extract_cohort_embeddings(
    df=df_cohort,
    patient_ids=patient_ids,
    model=model,
    tokenizer=tokenizer,
    end_time=pd.Timestamp("2024-01-01")
)
# 4. Train/test split
train_idx, test_idx = train_test_split(range(100), test_size=0.2, random_state=42)
X_train, X_test = embeddings[train_idx], embeddings[test_idx]
# 5. Binary classification
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train.numpy(), labels['binary'][train_idx])
y_prob = clf.predict_proba(X_test.numpy())[:, 1]
print(f"Binary AUC: {roc_auc_score(labels['binary'][test_idx], y_prob):.3f}")
# 6. Multi-class classification
clf_mc = LogisticRegression(max_iter=1000)  # multinomial loss by default with lbfgs
clf_mc.fit(X_train.numpy(), labels['multiclass'][train_idx])
y_pred_mc = clf_mc.predict(X_test.numpy())
print(f"Multiclass Acc: {accuracy_score(labels['multiclass'][test_idx], y_pred_mc):.3f}")
# 7. Survival regression
reg = Ridge(alpha=1.0)
reg.fit(X_train.numpy(), labels['survival'][train_idx])
y_pred_os = reg.predict(X_test.numpy())
print(f"Survival MAE: {mean_absolute_error(labels['survival'][test_idx], y_pred_os):.2f} months")
Tips
Feature Scaling
Embeddings from foundation models are typically well-scaled, but standardization can help:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.numpy())
X_test_scaled = scaler.transform(X_test.numpy())
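The scaled arrays then stand in for the raw embeddings wherever a probe is fit or evaluated, e.g.:
clf = LogisticRegression(max_iter=1000, random_state=42)
clf.fit(X_train_scaled, y_train)
print(f"Accuracy (scaled features): {clf.score(X_test_scaled, y_test):.3f}")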
Regularization
Use regularization to prevent overfitting, especially with small cohorts:
# L2 regularization
clf = LogisticRegression(C=0.1, max_iter=1000) # Lower C = more regularization
# Or use Ridge for regression
reg = Ridge(alpha=10.0) # Higher alpha = more regularization
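Rather than hand-picking C, LogisticRegressionCV selects it by internal cross-validation:
from sklearn.linear_model import LogisticRegressionCV
clf = LogisticRegressionCV(Cs=10, cv=5, scoring='roc_auc', max_iter=1000)
clf.fit(X_train.numpy(), y_train)
print(f"Selected C: {clf.C_[0]:.4f}")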
Cross-Validation
For robust evaluation, use cross-validation:
from sklearn.model_selection import cross_val_score
scores = cross_val_score(
    LogisticRegression(max_iter=1000),
    embeddings.numpy(),
    labels['binary'],
    cv=5,
    scoring='roc_auc'
)
print(f"CV AUC: {scores.mean():.3f} ± {scores.std():.3f}")
Class Imbalance
Handle imbalanced classes with class weights:
clf = LogisticRegression(
    class_weight='balanced',
    max_iter=1000
)
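The PyTorch probes can handle imbalance the same way by weighting the loss; a sketch using sklearn's 'balanced' heuristic, n_samples / (n_classes * class_count), on the binary labels:
counts = np.bincount(y_train)
weights = torch.tensor(len(y_train) / (2 * counts), dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weights)  # up-weights the rarer class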
