This guide explains how to leverage the Standard Model on your own EHR data. The workflow consists of three stages:
- Format data: Convert MEDS (Medical Event Data Standard) formatted EHR data into serial tokens using smb-biopan-utils.
- Represent data as embeddings: Pass the serialized tokens through the Standard Model to produce high-dimensional patient-level embeddings that capture the full causal history.
- Train clinical predictors: Use embeddings to power downstream tasks including Readmission Risk, Disease Phenotyping, and Survival Analysis.
Don’t forget to activate your virtual environment!
source standard_model/bin/activate
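If the dependencies aren't installed yet, a typical setup looks like this (the smb-biopan-utils package name is an assumption based on the imports used below; adjust for your environment):
pip install smb-biopan-utils pandas torch transformers accelerate tqdm scikit-learn lifelines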
Format data
To apply the Standard Model, you must bridge the gap between your storage schema (MEDS) and the model’s input modality (Text).
The Standard Model was built to operate on multiple modalities, as we believe patient biology is far too complex to represent with text alone. For simplicity, however, this tutorial is scoped to EHR text data; a more advanced tutorial on leveraging multiple modalities is forthcoming.
- The schema (MEDS): First, organize your data into MEDS format, which provides a strict logical structure (patient, time, code) for your ETL pipeline. The MEDS repo has tools to help you do so; a minimal example of the schema follows this list.
- The modality (Text): We provide smb-biopan-utils to deterministically convert your MEDS tables into a tokenizable XML-like stream for model input. The conversion handles date grouping, modality tagging (e.g., <conditions>), and causal masking.
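For reference, here is a minimal sketch of the MEDS schema in pandas. The two event rows are hypothetical, and the file name matches the one used in the ETL example below:
import pandas as pd

# One row per clinical event, keyed by subject, time, and code.
events = pd.DataFrame({
    "subject_id": ["patient_5521", "patient_5521"],
    "time": pd.to_datetime(["2023-11-15", "2023-11-15"]),
    "code": ["ICD10:C34.90", "RxNorm:583214"],
})
events.to_parquet("internal_cohort_meds.parquet", index=False)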
ETL Example
We assume your data is already saved in MEDS format on disk as internal_cohort_meds.parquet; use pandas to load it and verify the schema.
import pandas as pd
from smb_biopan_utils import process_ehr_info
# Load your cohort in MEDS format
df = pd.read_parquet("internal_cohort_meds.parquet")
# Verify schema
assert {'subject_id', 'time', 'code'}.issubset(df.columns)
# Example: Serialize a single patient history
# 'end_time' enforces causal masking (the model cannot see future data)
input_text = process_ehr_info(
    df,
    subject_id="patient_5521",
    end_time=pd.Timestamp("2024-01-01")
)
print(input_text)
# Output:
# [2023-11-15]
# <conditions>
# ICD10:C34.90
# </conditions>
# <medications>
# RxNorm:583214
# </medications>
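To see the causal masking in action, you can serialize the same patient with an earlier end_time and confirm that later events are excluded. This is an illustrative sanity check, assuming the example dates shown above:
# Events dated after end_time should not appear in the serialized stream.
early_text = process_ehr_info(
    df,
    subject_id="patient_5521",
    end_time=pd.Timestamp("2023-11-01")
)
assert "ICD10:C34.90" not in early_text  # recorded 2023-11-15, after the cutoff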
Represent data as embeddings
Pass the serialized text through the model. We use Last Token Pooling to extract the final hidden state, which represents the patient’s entire causal trajectory up to the end_time.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
MODEL_ID = "standardmodelbio/SMB-v1-1.7B"
# 1. Load Model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map="auto"
)
model.eval()
# 2. Batch Extraction
def get_embeddings(df, pids, end_time):
    embeddings = []
    for pid in tqdm(pids):
        # A. Serialize (MEDS -> Text)
        text = process_ehr_info(df, subject_id=pid, end_time=end_time)
        # B. Tokenize
        # Note: truncation drops tokens from the end by default; set
        # tokenizer.truncation_side = "left" to keep the most recent history.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=4096
        ).to(model.device)
        # C. Inference (Hidden States)
        with torch.no_grad():
            outputs = model(inputs.input_ids, output_hidden_states=True)
        # Extract last token vector
        vec = outputs.hidden_states[-1][:, -1, :].cpu()
        embeddings.append(vec)
    return torch.cat(embeddings, dim=0).numpy()
# Execute
pids = df["subject_id"].unique()
X = get_embeddings(df, pids, pd.Timestamp("2024-01-01"))
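Embedding extraction is the expensive step, so it is worth caching the results to disk. A simple sketch (the file names here are arbitrary):
import numpy as np

# Cache the embedding matrix and the patient IDs that index its rows.
np.save("smb_embeddings.npy", X)
pd.Series(pids, name="subject_id").to_csv("smb_embedding_ids.csv", index=False)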
Train clinical predictors
The embeddings (X) serve as a universal substrate for various clinical tasks. Below is an example of how to train and assess binary classifiers, multiclass phenotypers, regressors, and survival models on the embeddings.
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score, accuracy_score, mean_absolute_error
from lifelines import CoxPHFitter
# --- Setup ---
# Assume 'X' is your (N_samples, 1536) embedding matrix
# Assume 'labels' is a DataFrame aligned to X with columns:
# ['binary_outcome', 'phenotype_class', 'survival_months', 'event_observed']
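# (Illustrative only) Fabricate random labels for a quick smoke test;
# replace with your real outcomes, row-aligned to X.
rng = np.random.default_rng(42)
labels = pd.DataFrame({
    "binary_outcome": rng.integers(0, 2, size=len(X)),
    "phenotype_class": rng.integers(0, 4, size=len(X)),
    "survival_months": rng.exponential(24.0, size=len(X)),
    "event_observed": rng.integers(0, 2, size=len(X)),
})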
# Split 80/20
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, random_state=42
)
print(f"Training on {len(X_train)} samples, Testing on {len(X_test)}")
# --- Task A: Binary Classification (e.g., Readmission Risk) ---
print("\n--- Task A: Binary Classification ---")
clf_bin = LogisticRegression(max_iter=1000)
clf_bin.fit(X_train, y_train["binary_outcome"])
y_prob = clf_bin.predict_proba(X_test)[:, 1]
auc = roc_auc_score(y_test["binary_outcome"], y_prob)
print(f"-> ROC-AUC: {auc:.3f}")
# --- Task B: Multiclass Classification (e.g., Disease Staging) ---
print("\n--- Task B: Multiclass Phenotyping ---")
clf_multi = LogisticRegression(max_iter=1000)  # multinomial by default with the lbfgs solver
clf_multi.fit(X_train, y_train["phenotype_class"])
y_pred_class = clf_multi.predict(X_test)
acc = accuracy_score(y_test["phenotype_class"], y_pred_class)
print(f"-> Accuracy: {acc:.3f}")
# --- Task C: Regression (e.g., Length of Stay / Survival Time) ---
print("\n--- Task C: Regression ---")
reg = Ridge(alpha=1.0)
reg.fit(X_train, y_train["survival_months"])
y_pred_reg = reg.predict(X_test)
mae = mean_absolute_error(y_test["survival_months"], y_pred_reg)
print(f"-> MAE: {mae:.2f}")
# --- Task D: Survival Analysis (Cox Proportional Hazards) ---
print("\n--- Task D: Survival Analysis ---")
# Note: Project to lower dim (PCA) to ensure stability of CoxPH
pca = PCA(n_components=10)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
# Prepare dataframe for lifelines (requires covariates + duration + event in one df)
cox_df = pd.DataFrame(X_train_pca, columns=[f"PC{i}" for i in range(10)])
cox_df["T"] = y_train["survival_months"].values
cox_df["E"] = y_train["event_observed"].values
cph = CoxPHFitter()
cph.fit(cox_df, duration_col="T", event_col="E")
# Evaluate on test set
test_cox_df = pd.DataFrame(X_test_pca, columns=[f"PC{i}" for i in range(10)])
test_cox_df["T"] = y_test["survival_months"].values
test_cox_df["E"] = y_test["event_observed"].values
c_index = cph.score(test_cox_df, scoring_method="concordance_index")
print(f"-> C-Index: {c_index:.3f}")