The workflow consists of three stages:
  1. Format and tokenize data: Organize your data into the required MEDS format to record medical events, plus a labels table containing the ground truth for your task targets. Convert these into serial tokens using smb_utils.
  2. Represent data as embeddings: Ingest MEDS events tokens into the Standard Model to produce high-dimensional patient-level embeddings.
  3. Train clinical predictors: Use embeddings plus labels to train task heads (e.g., readmission, phenotyping, survival).
Make sure you’ve completed the quickstart setup. Commands below assume quickstart/ as the working directory and use uv run.

Format and tokenize data

To apply the Standard Model, you must first have an events table in MEDS (Medical Event Data Standard) format (one row per clinical event). A collection of ETLs from common data formats, including OMOP, MIMIC-IV, and MEDS Unsorted can be found here. The end-to-end example further describes the input data format.
The Standard Model was built to operate on multiple modalities. For simplicity, this tutorial focuses on EHR text data. A more advanced multi-modal tutorial is forthcoming.
Events data example rows:
| subject_id | time                | code          | table      | value |
| ---------- | ------------------- | ------------- | ---------- | ----- |
| 10000032   | 2022-01-15 08:00:00 | ICD10:I10     | condition  |       |
| 10000032   | 2022-01-15 09:30:00 | LOINC:2093-3  | lab        | 145.2 |
| 10001217   | 2022-02-01 14:00:00 | RxNorm:861004 | medication |       |
You should also have a labels table with one row per subject, in the same order as your events table. Columns set the ground truth for prediction tasks, and are used for performance evaluation. Labels data example rows:
| subject_id | prediction_time     | readmission_risk | phenotype_class | overall_survival_months | event_observed |
| ---------- | ------------------- | ---------------- | --------------- | ----------------------- | -------------- |
| 10000032   | 2022-04-12 12:00:00 | 0                | 0               | 68.1                    | 1              |
| 10001217   | 2022-06-15 08:00:00 | 0                | 2               | 45.8                    | 1              |
| 10002428   | 2022-02-14 13:30:00 | 1                | 1               | 35.5                    | 1              |
Your labels should be derived from linked patient data and should reflect the data types used by your clinical outcome classifier or model type. The model tokenizer uses prediction_time in the labels table as the cutoff when considering events to build embeddings. Think of this date as the “as-of” time for the prediction, or the inclusive endpoint of data the model can use to generate embeddings. We provide smb_utils.process_ehr_info to convert the MEDS events table into a tokenizable XML-like stream (e.g., [timestamp]: <conditions></conditions>, <measurements></measurements>). Time in our setting is measured in days. All events recorded at the same timestamp are grouped by event type/category within XML-style tags, in chronological order. The timestamp appears as an explicit token in the string as well (see input_text in the example below).
Note that smb-v1-1.7B only supports a maximum token length of 4096, yet many patient histories exceed this constraint. The smb-utils package offers several strategies to manage context efficiently:
  1. filter events data by modality (i.e., code or table columns),
  2. organize events into time bins with the most recent events from an anchor date, or
  3. flexibly define custom event categories.
Support for longer context lengths is on our roadmap, and an improved solution is coming soon.
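The three strategies above can be sketched with plain pandas; the dataframe below is illustrative, and smb-utils exposes its own helpers for this:

```python
import pandas as pd

# Illustrative MEDS-style events table (synthetic values)
df = pd.DataFrame({
    "subject_id": [1, 1, 1, 1],
    "time": pd.to_datetime(
        ["2020-01-01", "2021-06-01", "2022-01-10", "2022-01-15"]
    ),
    "code": ["ICD10:I10", "LOINC:2093-3", "RxNorm:861004", "ICD10:E11.9"],
    "table": ["condition", "lab", "medication", "condition"],
})

anchor = pd.Timestamp("2022-02-01")

# Strategy 1: filter by modality via the table (or code) column
by_modality = df[df["table"].isin(["condition", "medication"])]

# Strategy 2: keep only events in a recency window before the anchor date
recent = df[
    (df["time"] <= anchor) & (df["time"] >= anchor - pd.Timedelta(days=90))
]

# Strategy 3: define custom categories, e.g. by the code's vocabulary prefix
df["category"] = df["code"].str.split(":").str[0]
```

Any of these filters can be applied before serialization to keep the resulting token stream under the 4096-token limit.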

ETL example

Assume your events are saved as internal_cohort_meds.parquet with the schema above. Load and verify, then serialize one patient:
import pandas as pd
from smb_utils import process_ehr_info

# Load events (MEDS format: subject_id, time, code, table, value)
df = pd.read_parquet("internal_cohort_meds.parquet")
assert {"subject_id", "time", "code", "table", "value"}.issubset(df.columns)

# Example: Serialize a single patient history
# 'end_time' enforces causal masking (the model cannot see future data)
input_text = process_ehr_info(
    df,
    subject_id="patient_5521",
    end_time=pd.Timestamp("2024-01-01")
)

print(input_text)
# Output:
# [2023-11-15]
# <conditions>
# ICD10:C34.90
# </conditions>
# <medications>
# RxNorm:583214
# </medications>
Labels must be aligned to the same patient order as your embedding matrix (i.e., same subject_id order as the events you passed to the model). This is important to resolve the correct end_time for each patient when creating embeddings. If you’re using a coding agent (like Claude Code, Codex, Gemini, etc.), you can derive a new labels table from your events data with the prompt below. In this example we’ve specified the kinds of outcomes we’re looking for and the data type we’ll use when making predictions downstream. We recommend chatting with the agent to explore your events data if you’re unsure of what clinical outcomes you may want to predict. The prompt below runs in ~1m.
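If you prefer to build the labels table directly rather than via an agent, a minimal pandas sketch is below; the prediction_time logic (each subject's last recorded event) and the zero-filled outcome column are illustrative placeholders, not a clinically meaningful labeling:

```python
import pandas as pd

# Illustrative events table (same MEDS-style columns as above)
events = pd.DataFrame({
    "subject_id": [10000032, 10000032, 10001217],
    "time": pd.to_datetime(
        ["2022-01-15 08:00", "2022-03-01 09:00", "2022-02-01 14:00"]
    ),
    "code": ["ICD10:I10", "LOINC:2093-3", "RxNorm:861004"],
})

# One row per subject; here prediction_time is each subject's last event
labels = (
    events.groupby("subject_id", as_index=False)["time"]
    .max()
    .rename(columns={"time": "prediction_time"})
)
# Placeholder ground truth; fill from linked outcome data in practice
labels["readmission_risk"] = 0
```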

Create a new labels table by referencing your events data. This prompt assumes that you've already downloaded the model and that the MIMIC-IV demo events data file is available for use.


Get embeddings and train classifiers with AI

From here, a coding agent is generally able to run inference, train classifiers, and give you evaluation results with a single prompt. The following example resolves in 3-4m.

Represent patients as embeddings, train 4 types of classifiers, and get prediction results. This prompt assumes you have completed all previous steps.

Represent data as embeddings

We now provide example code that loads the model and defines get_embeddings, a function to serialize MEDS data into text, tokenize the text stream, and run inference through the model for a full dataframe of patients.
We use Last Token Pooling to extract the final hidden state, which represents the patient’s entire causal trajectory up to the end_time (inclusive).
import pandas as pd
import torch
from smb_utils import process_ehr_info
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

MODEL_ID = "standardmodelbio/smb-v1-1.7b"

# 1. Load Model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map="auto"
)
model.eval()

# 2. Batch Extraction (df from Format data step above)
def get_embeddings(df, pids, end_time):
    embeddings = []
    for pid in tqdm(pids):
        # Resolve per-patient end_time if a mapping is provided
        if isinstance(end_time, (pd.Series, dict)):
            patient_end_time = pd.Timestamp(end_time[pid])
        else:
            patient_end_time = end_time

        # A. Serialize (MEDS -> Text)
        text = process_ehr_info(df, subject_id=pid, end_time=patient_end_time)

        # B. Tokenize
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=4096
        ).to(model.device)

        # C. Inference (Hidden States)
        with torch.no_grad():
            outputs = model(inputs.input_ids, output_hidden_states=True)
            # Extract the last token's hidden state as the patient embedding
            vec = outputs.hidden_states[-1][:, -1, :].cpu()
            embeddings.append(vec)

    return torch.cat(embeddings, dim=0).numpy()

# Execute (use prediction_time from your labels table if it has that column;
# otherwise pass a single cutoff timestamp for every patient)
pids = df["subject_id"].unique()
# If your labels table has prediction_time (e.g. from the quickstart demo)
end_times = labels_df.set_index("subject_id")["prediction_time"]
X = get_embeddings(df, pids, end_times)
The output X is a 2-D NumPy array with one fixed-length embedding per row, representing each patient’s medical history from df. Rows are in the same patient order as your MEDS data and labels table.
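Before training predictors, it is worth sanity-checking that the embedding rows line up with your patient IDs and labels. The toy arrays below stand in for the real outputs (the real embedding width is the model's hidden size):

```python
import numpy as np
import pandas as pd

# Toy stand-ins: 3 patients, 8-dim embeddings
pids = np.array([10000032, 10001217, 10002428])
X = np.random.default_rng(0).normal(size=(len(pids), 8))
labels_df = pd.DataFrame({"subject_id": pids, "readmission_risk": [0, 0, 1]})

# One embedding row per patient, and labels in the same subject order
assert X.shape[0] == len(pids)
assert (labels_df["subject_id"].to_numpy() == pids).all()
```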

Train clinical predictors

We can now use the embeddings (X) and the labels table to train multiple types of task heads to predict clinical outcomes. After importing the necessary packages and splitting our data into training and test sets, we complete Task A: a binary classifier to predict subject readmission to the hospital, as recorded in the readmission_risk column. ROC-AUC is given as the evaluation metric for Task A.
# --- Setup ---
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score, accuracy_score, mean_absolute_error
from lifelines import CoxPHFitter

# X, pids from embedding step above. Load your labels table and align to X row order:
labels_df = pd.read_parquet("your_labels.parquet")  # replace with your path
labels = labels_df.set_index("subject_id").loc[pids].reset_index()

X_train, X_test, labels_train, labels_test = train_test_split(
    X, labels, test_size=0.2, random_state=42
)
print(f"Training on {len(X_train)} samples, testing on {len(X_test)}")


# --- Task A: Binary (Readmission Risk) ---
print("\n--- Task A: Binary Classification ---")
clf_bin = LogisticRegression(max_iter=1000)
clf_bin.fit(X_train, labels_train["readmission_risk"])
y_prob = clf_bin.predict_proba(X_test)[:, 1]
auc = roc_auc_score(labels_test["readmission_risk"], y_prob)
print(f"-> ROC-AUC: {auc:.3f}")
Task B is a multi-class classifier to predict cancer stage 1-4, as recorded in the phenotype_class column. Accuracy (the fraction of correctly classified samples) is the evaluation metric.
# --- Task B: Multiclass (Phenotype) ---
print("\n--- Task B: Multiclass Phenotyping ---")
clf_multi = LogisticRegression(max_iter=1000)
clf_multi.fit(X_train, labels_train["phenotype_class"])
y_pred_class = clf_multi.predict(X_test)
acc = accuracy_score(labels_test["phenotype_class"], y_pred_class)
print(f"-> Accuracy: {acc:.3f}")
Task C uses the continuous variable in overall_survival_months to train a regression model and predict survival time in months for subjects who died. Mean absolute error is reported for Task C.
# --- Task C: Regression (Survival months) ---
print("\n--- Task C: Regression ---")
reg = Ridge(alpha=1.0)
reg.fit(X_train, labels_train["overall_survival_months"])
y_pred_reg = reg.predict(X_test)
mae = mean_absolute_error(labels_test["overall_survival_months"], y_pred_reg)
print(f"-> MAE: {mae:.2f}")
Finally, Task D fits a Cox proportional hazards model for risk of death, using the overall_survival_months and event_observed columns; the embeddings are first reduced with PCA to keep the fit well-conditioned. The concordance index is reported as the evaluation metric.
# --- Task D: Survival (Cox PH) ---
print("\n--- Task D: Survival Analysis ---")
pca = PCA(n_components=10)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
cox_df = pd.DataFrame(X_train_pca, columns=[f"PC{i}" for i in range(10)])
cox_df["T"] = labels_train["overall_survival_months"].values
cox_df["E"] = labels_train["event_observed"].values
cph = CoxPHFitter()
cph.fit(cox_df, duration_col="T", event_col="E")
test_cox_df = pd.DataFrame(X_test_pca, columns=[f"PC{i}" for i in range(10)])
test_cox_df["T"] = labels_test["overall_survival_months"].values
test_cox_df["E"] = labels_test["event_observed"].values
c_index = cph.score(test_cox_df, scoring_method="concordance_index")
print(f"-> C-Index: {c_index:.3f}")
You may similarly abstract these concepts to use the Standard Model embeddings as input to other classifier types, including more complex downstream models!
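As one sketch of such an abstraction, a gradient-boosted tree head can be trained on the same embedding matrix exactly like the linear heads above; the synthetic X and y below stand in for real embeddings and labels:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

# Synthetic stand-ins: 200 "patients" with 16-dim embeddings and a
# binary outcome driven by the first embedding dimension plus noise
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 16))
y = (X[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)

# Same fit/score pattern as the logistic-regression heads above
clf = GradientBoostingClassifier(random_state=42)
clf.fit(X[:150], y[:150])
acc = clf.score(X[150:], y[150:])
```

Any estimator exposing fit/predict on a 2-D feature matrix (MLPs, boosted trees, calibrated classifiers) can slot in the same way.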

Contact Us

Having trouble, or just want to talk about your project? Send an email to erik@standardmodel.bio