Skip to main content
Prerequisite: Embeddings Guide
Task Types: Binary classification, multi-class classification, survival prediction
Linear probing trains a simple linear classifier on top of frozen foundation model embeddings. This approach is fast, requires minimal data, and evaluates the quality of learned representations.

Environment Activation

source standard_model/bin/activate
See the Quickstart Guide for environment creation and installation.

Setup

First, create dummy patient data and extract embeddings (see Embeddings Guide for details):
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from smb_biopan_utils import process_ehr_info

# Load model
# Downloads (or reuses a cached copy of) the SMB structure model and its
# tokenizer from the Hugging Face hub.
model_id = "standardmodelbio/SMB-v1-1.7B-Structure"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,  # repo ships custom model code; required to load it
    device_map="auto"        # place weights on available accelerator(s)/CPU automatically
)
model.eval()  # inference mode (disables dropout etc.); embeddings only, no training

Create Dummy Dataset

Generate synthetic patient cohort with labels:
def create_dummy_cohort(n_patients=100):
    """Create a synthetic MEDS-style event table plus per-patient task labels.

    Parameters
    ----------
    n_patients : int, optional
        Number of synthetic patients to generate (default 100). Patient IDs
        follow the ``patient_{i:03d}`` scheme.

    Returns
    -------
    tuple[pd.DataFrame, dict[str, np.ndarray]]
        ``(df, labels)`` where ``df`` has one row per clinical event with
        columns ``subject_id``, ``time``, ``code``, ``table``, ``value``,
        and ``labels`` maps:
          - ``'binary'``     -> 0/1 array (30-day readmission),
          - ``'multiclass'`` -> int array in 0-3 (disease phenotype),
          - ``'survival'``   -> float array (overall survival in months),
        each of length ``n_patients``, index-aligned with the ID scheme.

    Notes
    -----
    Seeds the *global* NumPy RNG (``np.random.seed(42)``), so repeated calls
    are reproducible but global RNG state is clobbered as a side effect.
    Labels are independent random draws; they are NOT derived from the
    generated events.
    """

    np.random.seed(42)

    all_data = []
    labels_binary = []      # Readmission within 30 days
    labels_multiclass = []  # Disease phenotype (0-3)
    labels_survival = []    # Overall survival in months

    conditions = ['ICD10:C34.90', 'ICD10:I25.10', 'ICD10:E11.9', 'ICD10:J44.9']
    procedures = ['CPT:71260', 'CPT:93000', 'CPT:99213', 'CPT:43239']
    medications = ['RxNorm:583214', 'RxNorm:197361', 'RxNorm:311671', 'RxNorm:83367']
    labs = ['LOINC:2160-0', 'LOINC:718-7', 'LOINC:2093-3', 'LOINC:1742-6']

    for i in range(n_patients):
        pid = f'patient_{i:03d}'
        n_events = np.random.randint(5, 12)

        # Generate random clinical events at month-end timestamps.
        # NOTE: 'M' is the deprecated alias for 'ME' in pandas >= 2.2; kept
        # here for compatibility with older pandas.
        times = pd.date_range('2023-01-01', periods=n_events, freq='M')

        patient_data = pd.DataFrame({
            'subject_id': [pid] * n_events,
            'time': times,
            'code': np.random.choice(
                conditions + procedures + medications + labs, 
                n_events
            ),
            'table': np.random.choice(
                ['condition', 'procedure', 'medication', 'lab'],
                n_events
            ),
            # ~30% of events carry a numeric value (e.g. a lab result);
            # the rest are None.
            'value': [
                np.random.uniform(0.5, 2.0) if np.random.random() > 0.7 else None 
                for _ in range(n_events)
            ]
        })

        all_data.append(patient_data)

        # Generate random labels (reproducible via the global seed above;
        # they are independent of the events generated for the patient).
        labels_binary.append(int(np.random.random() > 0.7))
        labels_multiclass.append(np.random.randint(0, 4))
        labels_survival.append(np.random.exponential(24) + 6)  # Months, floor of 6

    df = pd.concat(all_data, ignore_index=True)

    return df, {
        'binary': np.array(labels_binary),
        'multiclass': np.array(labels_multiclass),
        'survival': np.array(labels_survival)
    }

# Create cohort
# patient_ids uses the same patient_{i:03d} scheme as create_dummy_cohort,
# so labels[...][i] belongs to patient_ids[i].
df_cohort, labels = create_dummy_cohort(n_patients=100)
patient_ids = [f'patient_{i:03d}' for i in range(100)]

print(f"Total events: {len(df_cohort)}")
print(f"Patients: {len(patient_ids)}")
print(f"Binary labels distribution: {np.bincount(labels['binary'])}")
print(f"Multiclass labels distribution: {np.bincount(labels['multiclass'])}")

Extract Embeddings for Cohort

def extract_cohort_embeddings(df, patient_ids, model, tokenizer, end_time):
    """Extract one embedding per patient from the frozen foundation model.

    Parameters
    ----------
    df : pd.DataFrame
        MEDS-style event table covering all patients in ``patient_ids``.
    patient_ids : sequence of str
        Patients to embed; output rows follow this order.
    model : causal LM (Hugging Face) producing hidden states.
    tokenizer : matching tokenizer for ``model``.
    end_time : pd.Timestamp
        Prediction cutoff; only events up to this time are serialized.

    Returns
    -------
    torch.Tensor
        ``[len(patient_ids), hidden_dim]`` tensor on CPU. Each row is the
        last-layer hidden state at the final token of the serialized record.
    """

    embeddings = []

    for pid in patient_ids:
        # Serialize the patient's events (up to end_time) into model input text.
        input_text = process_ehr_info(
            df=df,
            subject_id=pid,
            end_time=end_time
        )

        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(model.device)

        with torch.no_grad():
            # Forward the full tokenizer output (input_ids AND attention_mask)
            # so that padding, if any is ever introduced, is masked correctly.
            outputs = model(
                **inputs,
                output_hidden_states=True,
                return_dict=True
            )
            # Last hidden layer, last token position -> patient-level embedding.
            emb = outputs.hidden_states[-1][:, -1, :]
            embeddings.append(emb.cpu())

    return torch.cat(embeddings, dim=0)

# Extract embeddings (this may take a few minutes)
# One forward pass per patient; result rows are aligned with patient_ids
# (and therefore with the label arrays).
embeddings = extract_cohort_embeddings(
    df=df_cohort,
    patient_ids=patient_ids,
    model=model,
    tokenizer=tokenizer,
    end_time=pd.Timestamp("2024-01-01")
)

print(f"Embeddings shape: {embeddings.shape}")  # [100, hidden_dim]

Train/Test Split

from sklearn.model_selection import train_test_split

# Split indices
# Split patient *indices* (not the tensors directly) so the same split can
# index the embedding tensor and every label array consistently.
train_idx, test_idx = train_test_split(
    range(len(patient_ids)),
    test_size=0.2,
    random_state=42
)

X_train = embeddings[train_idx]
X_test = embeddings[test_idx]

print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")

Binary Classification

Predict 30-day hospital readmission:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report

# Prepare labels
# Index the label array with the same split indices used for the
# embeddings so X and y stay aligned.
y_train = labels['binary'][train_idx]
y_test = labels['binary'][test_idx]

# Train logistic regression
clf_binary = LogisticRegression(max_iter=1000, random_state=42)
clf_binary.fit(X_train.numpy(), y_train)

# Evaluate
y_pred = clf_binary.predict(X_test.numpy())
y_prob = clf_binary.predict_proba(X_test.numpy())[:, 1]  # P(class 1) = P(readmission)

print("Binary Classification (Readmission Prediction)")
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f"AUC-ROC: {roc_auc_score(y_test, y_prob):.3f}")
print(classification_report(y_test, y_pred, target_names=['No Readmit', 'Readmit']))

PyTorch Version

For more control, use a PyTorch linear layer:
class LinearProbe(nn.Module):
    """Single affine layer mapping frozen embeddings to class logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # No hidden layers: the probe only measures how linearly
        # separable the embedding space already is.
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits

# Setup
hidden_dim = embeddings.shape[1]           # embedding dimensionality from the model
probe_binary = LinearProbe(hidden_dim, 2)  # 2 logits: no-readmit / readmit

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(probe_binary.parameters(), lr=1e-3)

# Create dataloaders
train_dataset = TensorDataset(X_train, torch.tensor(y_train))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Training loop
probe_binary.train()
for epoch in range(50):
    total_loss = 0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        logits = probe_binary(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # Report average batch loss every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

# Evaluate
probe_binary.eval()
with torch.no_grad():
    logits = probe_binary(X_test)
    preds = logits.argmax(dim=1)  # predicted class = argmax over the 2 logits
    acc = (preds == torch.tensor(y_test)).float().mean()
    print(f"Test Accuracy: {acc:.3f}")

Multi-Class Classification

Predict disease phenotype (4 classes):
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Prepare labels
y_train_mc = labels['multiclass'][train_idx]
y_test_mc = labels['multiclass'][test_idx]

# Train. NOTE: multi_class='multinomial' is already the behavior for the
# default lbfgs solver, and the parameter is deprecated in
# scikit-learn >= 1.5 — omitting it leaves behavior unchanged.
clf_multiclass = LogisticRegression(
    max_iter=1000,
    random_state=42
)
clf_multiclass.fit(X_train.numpy(), y_train_mc)

# Evaluate
y_pred_mc = clf_multiclass.predict(X_test.numpy())

print("Multi-Class Classification (Disease Phenotype)")
print(f"Accuracy: {accuracy_score(y_test_mc, y_pred_mc):.3f}")
print(classification_report(
    y_test_mc, 
    y_pred_mc,
    target_names=['Lung Cancer', 'CAD', 'Diabetes', 'COPD']
))

# Confusion matrix
cm = confusion_matrix(y_test_mc, y_pred_mc)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Lung', 'CAD', 'Diabetes', 'COPD'],
            yticklabels=['Lung', 'CAD', 'Diabetes', 'COPD'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Disease Phenotype Classification')
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.close()  # release the figure so repeated runs don't accumulate open figures

PyTorch Version

# Multi-class probe
# Reuses the LinearProbe class and the CrossEntropyLoss criterion defined
# above, now with 4 output classes.
probe_multiclass = LinearProbe(hidden_dim, 4)
optimizer = torch.optim.Adam(probe_multiclass.parameters(), lr=1e-3)

train_dataset_mc = TensorDataset(X_train, torch.tensor(y_train_mc))
train_loader_mc = DataLoader(train_dataset_mc, batch_size=16, shuffle=True)

probe_multiclass.train()
for epoch in range(50):
    for batch_x, batch_y in train_loader_mc:
        optimizer.zero_grad()
        logits = probe_multiclass(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()

# Evaluate
probe_multiclass.eval()
with torch.no_grad():
    logits = probe_multiclass(X_test)
    preds = logits.argmax(dim=1)  # predicted phenotype = argmax over 4 logits
    acc = (preds == torch.tensor(y_test_mc)).float().mean()
    print(f"Multi-class Test Accuracy: {acc:.3f}")

Overall Survival (OS) Prediction

Predict survival time using Cox proportional hazards or simple regression:

Linear Regression Approach

from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

# Prepare labels (survival in months)
y_train_os = labels['survival'][train_idx]
y_test_os = labels['survival'][test_idx]

# Train ridge regression
# L2-regularized linear regression of survival time on the embeddings.
reg_survival = Ridge(alpha=1.0)
reg_survival.fit(X_train.numpy(), y_train_os)

# Predict
y_pred_os = reg_survival.predict(X_test.numpy())

# Metrics
mse = mean_squared_error(y_test_os, y_pred_os)
mae = mean_absolute_error(y_test_os, y_pred_os)
corr, _ = pearsonr(y_test_os, y_pred_os)  # correlation only; p-value discarded

print("Overall Survival Prediction")
print(f"MSE: {mse:.2f}")
print(f"MAE: {mae:.2f} months")
print(f"Pearson r: {corr:.3f}")

# Plot
plt.figure(figsize=(8, 6))
plt.scatter(y_test_os, y_pred_os, alpha=0.6)
# Identity line: points on it would be perfect predictions.
plt.plot([0, max(y_test_os)], [0, max(y_test_os)], 'r--', label='Perfect')
plt.xlabel('Actual Survival (months)')
plt.ylabel('Predicted Survival (months)')
plt.title('Overall Survival Prediction')
plt.legend()
plt.savefig('survival_scatter.png', dpi=150, bbox_inches='tight')

Cox Proportional Hazards

For proper survival analysis with censoring:
# pip install lifelines
from lifelines import CoxPHFitter

# Build the training frame for the Cox model: one column per embedding
# dimension plus duration/event columns. The dummy labels carry no real
# censoring information, so mark a random ~20% of training patients as
# censored for demonstration.
censored = np.random.random(len(train_idx)) > 0.8

cox_train = pd.DataFrame(X_train.numpy())
cox_train['duration'] = y_train_os
# event = True means the event (death) was observed; False means censored.
# (The original redundant slice censored[:len(train_idx)] is dropped —
# the array already has exactly len(train_idx) entries.)
cox_train['event'] = ~censored

# Fit a penalized Cox model; the penalizer stabilizes the fit when the
# number of features (embedding dims) is large relative to the cohort size.
cph = CoxPHFitter(penalizer=0.1)
cph.fit(cox_train, duration_col='duration', event_col='event')

# Get risk scores for test set (higher partial hazard = higher predicted risk)
cox_test = pd.DataFrame(X_test.numpy())
risk_scores = cph.predict_partial_hazard(cox_test)

print("Cox Proportional Hazards Model")
print(f"Concordance Index: {cph.concordance_index_:.3f}")

PyTorch Survival Regression

class SurvivalProbe(nn.Module):
    """Linear regression head: embedding -> scalar survival estimate."""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        out = self.linear(x)
        # Drop the trailing size-1 output dimension so predictions match
        # the shape of a 1-D target vector.
        return out.squeeze(-1)

probe_survival = SurvivalProbe(hidden_dim)
optimizer = torch.optim.Adam(probe_survival.parameters(), lr=1e-3)
criterion_mse = nn.MSELoss()

# Targets must be float32 to match the probe's output dtype.
train_dataset_os = TensorDataset(
    X_train, 
    torch.tensor(y_train_os, dtype=torch.float32)
)
train_loader_os = DataLoader(train_dataset_os, batch_size=16, shuffle=True)

probe_survival.train()
for epoch in range(100):
    for batch_x, batch_y in train_loader_os:
        optimizer.zero_grad()
        pred = probe_survival(batch_x)
        loss = criterion_mse(pred, batch_y)
        loss.backward()
        optimizer.step()

# Evaluate
probe_survival.eval()
with torch.no_grad():
    preds = probe_survival(X_test)
    mse = criterion_mse(preds, torch.tensor(y_test_os, dtype=torch.float32))
    print(f"Test MSE: {mse:.2f}")

Complete Example

Full pipeline from data to evaluation:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, roc_auc_score, mean_absolute_error
from transformers import AutoModelForCausalLM, AutoTokenizer
from smb_biopan_utils import process_ehr_info

# 1. Load model
# Half-precision weights (float16) reduce memory use; device_map="auto"
# places the model on available devices automatically.
model_id = "standardmodelbio/SMB-v1-1.7B-Structure"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

# 2. Create/load your cohort data
# Replace create_dummy_cohort with your own MEDS-format cohort loader.
df_cohort, labels = create_dummy_cohort(n_patients=100)
patient_ids = [f'patient_{i:03d}' for i in range(100)]

# 3. Extract embeddings
# end_time is the prediction cutoff: only events before it are used.
embeddings = extract_cohort_embeddings(
    df=df_cohort,
    patient_ids=patient_ids,
    model=model,
    tokenizer=tokenizer,
    end_time=pd.Timestamp("2024-01-01")
)

# 4. Train/test split
# Split indices so embeddings and all label arrays share one split.
train_idx, test_idx = train_test_split(range(100), test_size=0.2, random_state=42)
X_train, X_test = embeddings[train_idx], embeddings[test_idx]

# 5. Binary classification
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train.numpy(), labels['binary'][train_idx])
y_prob = clf.predict_proba(X_test.numpy())[:, 1]  # P(readmission)
print(f"Binary AUC: {roc_auc_score(labels['binary'][test_idx], y_prob):.3f}")

# 6. Multi-class classification
# (multinomial is already the behavior for the default lbfgs solver; the
# deprecated multi_class parameter is omitted — scikit-learn >= 1.5 warns on it)
clf_mc = LogisticRegression(max_iter=1000)
clf_mc.fit(X_train.numpy(), labels['multiclass'][train_idx])
y_pred_mc = clf_mc.predict(X_test.numpy())
print(f"Multiclass Acc: {accuracy_score(labels['multiclass'][test_idx], y_pred_mc):.3f}")

# 7. Survival regression
reg = Ridge(alpha=1.0)
reg.fit(X_train.numpy(), labels['survival'][train_idx])
y_pred_os = reg.predict(X_test.numpy())
print(f"Survival MAE: {mean_absolute_error(labels['survival'][test_idx], y_pred_os):.2f} months")

Tips

Embeddings from foundation models are typically well-scaled, but standardization can help:
from sklearn.preprocessing import StandardScaler

# Fit the scaler on the training split only, then apply the same transform
# to the test split — fitting on the full data would leak test statistics.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.numpy())
X_test_scaled = scaler.transform(X_test.numpy())
Use regularization to prevent overfitting, especially with small cohorts:
# Stronger regularization shrinks probe weights and reduces variance,
# which matters most when the cohort is small.
# L2 regularization
clf = LogisticRegression(C=0.1, max_iter=1000)  # Lower C = more regularization

# Or use Ridge for regression
reg = Ridge(alpha=10.0)  # Higher alpha = more regularization
For robust evaluation, use cross-validation:
from sklearn.model_selection import cross_val_score

# 5-fold CV over the full cohort yields a mean and spread for the AUC
# rather than a single train/test estimate.
scores = cross_val_score(
    LogisticRegression(max_iter=1000),
    embeddings.numpy(),
    labels['binary'],
    cv=5,
    scoring='roc_auc'
)
print(f"CV AUC: {scores.mean():.3f} ± {scores.std():.3f}")
Handle imbalanced classes with class weights:
# 'balanced' reweights samples inversely to class frequency so the
# minority class is not ignored by the loss.
clf = LogisticRegression(
    class_weight='balanced',
    max_iter=1000
)

Next Steps