tripso package
Module contents
Modules for tripso method
- class tripso.gpEval(gpdb_path: str | None = None, output_dir: str = '/path/to/output/', dataset_path: str | None = None, tissue: str | None = 'test', model_type: str | None = 'Base', batch_size: int | None = 128, path_to_trained_model: str | None = None, seed: int | None = 0, hparam_save: str | None = 'all', cond_to_shift: Dict | None = None, return_classification_report: bool | None = False, gene_format: str | None = 'symbol', gp_inputs: list | None = None, gpmean_fm_encoder_pkg: str | None = 'geneformer', gpmean_fm_encoder_name: str | None = 'gf-6L-30M-i2048')
Bases: object
Main class for running downstream evaluation tasks on trained models.
- Parameters:
dataset_path (str) – Path to folder containing tokenized dataset
gpdb_path (str) – Path to gene program database
output_dir (str) – Path to the directory where outputs will be saved and where model checkpoints are stored
batch_size (int) – Batch size for evaluation step
n_blocks (int) – Number of transformer blocks
gene_format (str) – Format in which gene names are stored in GPDB. One of ‘symbol’ or ‘ensembl’.
tissue (str) – Tissue name for logging the experiment in wandb. This name also appears in the model checkpoint name.
model_type (str) – Base (GP blocks only) or Global (with cell token)
n_heads (int) – Number of heads for multi-head attention
gp_latent_size (int) – Size of latent space for GP tokens
gp_inputs (list) – Which GPs from GPDB to include in the model. If None, defaults to all GPs.
supervised_labels (dict) – Dict {label: num_classes} mapping each label to its number of classes for supervised classification
global_attn_heads (int) – Number of heads for learning the cell token in the global attention model
global_loss – Loss used to train the global attention model (for compatibility with gpGlobal init)
- evaluate_supervised_model(precision=32)
- generate_attention_matrix(gp_for_forward, gp_for_downstream, genes_to_keep=None, do_ensembl_conversion=True, split='test', precision=32)
Get attention weights from gpTransformer
- generate_embeddings(split='train', precision=32, return_mean_non_padding=False)
Save embeddings as Dataset
- generate_gene_embeddings(gp_for_forward: str | None, gp_for_downstream: str, split='train', obs_key=None, obs_value=None, data_frac=1, genes_to_keep=None, output_tag=None, do_ensembl_conversion=True, precision=32, return_gene_cosim=None)
Save gene embeddings as Dataset
- Parameters:
split (str) – Data split to use for generating embeddings
obs_key (str) – Key in adata.obs to filter on
obs_value (str) – value of obs_key to keep
data_frac (float) – Fraction of data to use for generating embeddings
gp_for_forward (str or None) – Pathway to use for the model forward pass. This is helpful if you only need to run the forward pass on one GP rather than all of them.
gp_for_downstream (str) – Pathway to use for focus of downstream analysis
genes_to_keep (list) – Genes to generate embeddings for. If None, embeddings are generated for all genes.
return_gene_cosim – If None, get gene embeddings. If ‘gene_to_gp’, get an AnnData object with (gene, GP) cosine similarities. If ‘gene_to_gene’, get a matrix of mean (gene, gene) cosine similarities.
For GP selection, use find_genes_in_multiple_gp or get_genes_in_single_gp from Utils.utils.
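The two return_gene_cosim modes rest on plain cosine similarity between embedding vectors. The sketch below is a minimal, hypothetical stand-in, not tripso's implementation: it assumes each cell's gene embeddings are available as a dict of gene name to vector, and the helper names (cosine, gene_to_gp_cosim, mean_gene_to_gene_cosim) are illustrative only.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def gene_to_gp_cosim(gene_emb: Dict[str, Vector], gp_emb: Vector) -> Dict[str, float]:
    """'gene_to_gp'-style output for one cell: similarity of each gene to the GP embedding."""
    return {gene: cosine(vec, gp_emb) for gene, vec in gene_emb.items()}


def mean_gene_to_gene_cosim(cells: List[Dict[str, Vector]], genes: List[str]) -> List[List[float]]:
    """'gene_to_gene'-style output: (gene, gene) cosine similarity averaged over cells.

    Only cells in which both genes have an embedding contribute; pairs that are
    never observed together are reported as NaN.
    """
    n = len(genes)
    sums = [[0.0] * n for _ in range(n)]
    counts = [[0] * n for _ in range(n)]
    for cell in cells:
        for i, gi in enumerate(genes):
            for j, gj in enumerate(genes):
                if gi in cell and gj in cell:
                    sums[i][j] += cosine(cell[gi], cell[gj])
                    counts[i][j] += 1
    return [
        [sums[i][j] / counts[i][j] if counts[i][j] else math.nan for j in range(n)]
        for i in range(n)
    ]
```

Averaging per pair over only the cells that contain both genes is a design choice of this sketch; tripso may aggregate differently.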
- visualize(label_to_plot, data_to_plot='test', gp_to_plot=None, subsample=None, method: Literal['umap', 'pca'] = 'umap')
Plot a UMAP or PCA projection (selected via method) of GP embeddings
- visualize_gene_embeddings(cell_label_to_plot, genes_to_plot, gene_label_to_plot, gene_label_df, gene_embedding_dir, output_dir, pathway=None, frac=1, gene_col_name='gene')
- tripso.pp_and_tokenize(root_dir: str, adata_path: str | None = None, input_size: int = 2048, vars_to_keep: Dict | List = ['cell_type'], subsample_by: List | None = ['cell_type'], n_cells_per_class: int = 20000, chunk_size: int = 50000, name_tag: str | None = 'Reactome', cov_to_encode: List[str] | str = ['cell_type', 'condition'], batch_keys: List[str] | None = None, tissue: str | None = None, hvg_batch_key: str | None = None, save_gp_genes_object: bool | None = False, calculate_hvg: bool | None = True, do_tokenization: bool | None = True, use_gp_tokenizer: bool | None = False, do_ensembl_conversion: bool | None = True, gp_genes_union: List[str] | None = None, output_data_name: str | None = None)
Preprocess and tokenize scRNA-seq data for GPformer training.
This function performs the complete preprocessing pipeline, including:
- Loading and optionally subsampling AnnData objects
- Optionally calculating highly variable genes (HVGs)
- Splitting data into chunks (for reasonable RAM usage)
- Tokenizing data using Geneformer or custom tokenizer
- Encoding categorical covariates
- Optionally saving GP genes subset AnnData object
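The subsampling and chunking steps can be sketched in plain Python. This is an illustrative approximation, not tripso's code: cell metadata is modeled as a list of dicts standing in for adata.obs, multiple subsample_by columns are assumed to be combined by joining their values with '_', and the helper names are hypothetical. Each combined class is capped at n_cells_per_class cells (classes with fewer cells keep all of them), and row indices are then split into consecutive chunks of at most chunk_size.

```python
import random
from collections import defaultdict
from typing import Dict, List, Optional, Sequence


def balanced_subsample(
    obs: List[Dict[str, str]],
    subsample_by: Optional[Sequence[str]],
    n_cells_per_class: int,
    seed: int = 0,
) -> List[int]:
    """Return sorted row indices keeping at most n_cells_per_class rows per class."""
    if subsample_by is None:  # None skips subsampling entirely
        return list(range(len(obs)))
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, row in enumerate(obs):
        # Assumption: multiple columns are combined by joining their values with "_".
        key = "_".join(str(row[col]) for col in subsample_by)
        groups[key].append(idx)
    rng = random.Random(seed)
    keep: List[int] = []
    for key in sorted(groups):
        idxs = groups[key]
        if len(idxs) <= n_cells_per_class:
            keep.extend(idxs)  # small classes keep every cell
        else:
            keep.extend(rng.sample(idxs, n_cells_per_class))
    return sorted(keep)


def chunk_indices(n_cells: int, chunk_size: int) -> List[range]:
    """Split row indices 0..n_cells-1 into consecutive chunks of at most chunk_size."""
    return [range(start, min(start + chunk_size, n_cells)) for start in range(0, n_cells, chunk_size)]
```

Using a seeded random.Random keeps the subsample reproducible across runs without touching global random state.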
- Parameters:
root_dir (str) – Output directory where processed h5ad and tokenized data will be saved. Directory structure will be created as: root_dir/data/processed/
adata_path (str, optional) – Path to input AnnData h5ad file to preprocess and tokenize.
input_size (int, default=2048) – Maximum input sequence length for tokenization. Determines which Geneformer token dictionary to use (2048 or 4096).
vars_to_keep (dict or list, default=['cell_type']) – Metadata column names from adata.obs to retain in tokenized dataset. If dict, maps obs column names to output column names.
subsample_by (list of str, optional, default=['cell_type']) – Metadata columns to use for balanced downsampling. Set to None to skip subsampling. Multiple columns will be combined.
n_cells_per_class (int, default=20000) – Maximum number of cells to keep per class during balanced subsampling. Classes with fewer cells keep all available cells.
chunk_size (int, default=50000) – Number of cells per chunk when splitting large datasets for tokenization.
name_tag (str, default='Reactome') – Identifier tag for gene program database filename (gpdb_{name_tag}.csv).
cov_to_encode (str or list of str, default=['cell_type', 'condition']) – Metadata columns to encode as integer IDs (creates {column}_id columns).
batch_keys (list of str, optional) – Metadata columns to combine into a ‘batch_key’ column (joined with ‘_’). Used for HVG calculation and in the decoder of the count reconstruction step.
tissue (str, optional) – Tissue type identifier for naming output files. If None, uses last component of root_dir path.
hvg_batch_key (str, optional) – Column name to use as batch key for highly variable gene calculation. If None and batch_keys provided, uses ‘batch_key’.
save_gp_genes_object (bool, default=False) – Whether to save a separate h5ad file containing only genes from gene programs in gpdb_{name_tag}.csv.
calculate_hvg (bool, default=True) – Whether to calculate highly variable genes using Seurat v3 method (top 2000 genes) and subset to HVGs.
do_tokenization (bool, default=True) – Whether to perform tokenization step. Set to False to only preprocess.
use_gp_tokenizer (bool, default=False) – Whether to use GPTokenizer (True) or standard TranscriptomeTokenizer (False) for tokenization.
do_ensembl_conversion (bool, default=True) – Whether to convert gene names to Ensembl IDs during tokenization.
gp_genes_union (list of str, optional) – Union of all GP genes to be used in tokenizer. If None and use_gp_tokenizer=True, will be loaded from gpdb_{name_tag}.csv.
output_data_name (str, optional) – Custom name for output dataset directory. If None, uses ‘input_dataset’.
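The batch_keys and cov_to_encode options can be illustrated as follows. This minimal sketch models metadata rows as dicts (a stand-in for adata.obs) and assigns integer IDs in sorted order of category values; the category ordering tripso actually uses is not documented here, and the helper names add_batch_key and encode_covariates are hypothetical.

```python
from typing import Dict, List, Sequence, Union


def add_batch_key(obs: List[Dict[str, object]], batch_keys: Sequence[str]) -> None:
    """Add a 'batch_key' column joining the batch_keys values with '_' (in place)."""
    for row in obs:
        row["batch_key"] = "_".join(str(row[k]) for k in batch_keys)


def encode_covariates(
    obs: List[Dict[str, object]], cov_to_encode: Union[str, Sequence[str]]
) -> Dict[str, Dict[str, int]]:
    """Add '{column}_id' integer columns (in place) and return the value-to-ID mappings."""
    if isinstance(cov_to_encode, str):  # a single column name is also accepted
        cov_to_encode = [cov_to_encode]
    mappings: Dict[str, Dict[str, int]] = {}
    for col in cov_to_encode:
        # Assumption: IDs follow the sorted order of the category values.
        values = sorted({str(row[col]) for row in obs})
        mapping = {value: i for i, value in enumerate(values)}
        for row in obs:
            row[f"{col}_id"] = mapping[str(row[col])]
        mappings[col] = mapping
    return mappings
```

Returning the mappings makes it possible to decode predicted IDs back to labels later.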
- Raises:
ValueError – If adata_path is not provided and no existing h5ad file is found in root_dir; if hvg_batch_key cannot be determined when calculate_hvg=True; if no GP genes are found in the dataset when save_gp_genes_object=True; or if the gpdb_{name_tag}.csv file is not found in root_dir.
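The parameter descriptions imply a fallback order for the HVG batch key, which the hypothetical helper below spells out: an explicit hvg_batch_key is used as given, otherwise the combined ‘batch_key’ column when batch_keys is set, and a ValueError is raised when calculate_hvg=True leaves no key to use. This is a reading of the documented behavior, not tripso's code.

```python
from typing import Optional, Sequence


def resolve_hvg_batch_key(
    hvg_batch_key: Optional[str],
    batch_keys: Optional[Sequence[str]],
    calculate_hvg: bool,
) -> Optional[str]:
    """Pick the obs column used as batch key for HVG selection."""
    if hvg_batch_key is not None:
        return hvg_batch_key  # an explicit column always wins
    if batch_keys:
        return "batch_key"  # the combined column built from batch_keys
    if calculate_hvg:
        raise ValueError("Cannot determine hvg_batch_key: pass hvg_batch_key or batch_keys.")
    return None  # HVGs are not computed, so no batch key is needed
```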
Notes
Expected gene program database format: CSV file where each column represents a gene program and contains gene identifiers (one per row).
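A gene program database in this format can be read with the standard csv module. The sketch below is hypothetical (read_gpdb, select_gps and union_of_gp_genes are not tripso functions); it assumes the header row holds the GP names, the file has no index column, and shorter programs leave empty cells at the bottom of their column. select_gps mirrors the documented gp_inputs behavior (None keeps all GPs), and union_of_gp_genes corresponds to what the gp_genes_union parameter describes.

```python
import csv
import io
from typing import Dict, List, Optional, TextIO


def read_gpdb(handle: TextIO) -> Dict[str, List[str]]:
    """Parse a GP database CSV: header row = GP names, each column = that GP's genes."""
    reader = csv.reader(handle)
    header = next(reader)
    gpdb: Dict[str, List[str]] = {name: [] for name in header}
    for row in reader:
        for name, gene in zip(header, row):
            if gene.strip():  # shorter programs leave empty cells at the bottom
                gpdb[name].append(gene.strip())
    return gpdb


def select_gps(gpdb: Dict[str, List[str]], gp_inputs: Optional[List[str]] = None) -> Dict[str, List[str]]:
    """Keep only the requested GPs; None keeps all of them."""
    if gp_inputs is None:
        return dict(gpdb)
    return {name: gpdb[name] for name in gp_inputs}


def union_of_gp_genes(gpdb: Dict[str, List[str]]) -> List[str]:
    """Sorted union of genes across all GPs."""
    return sorted(set().union(*gpdb.values()))


if __name__ == "__main__":
    demo = io.StringIO("GP_A,GP_B\nCD3E,MS4A1\nCD3D,CD79A\n,CD3E\n")
    print(union_of_gp_genes(read_gpdb(demo)))  # ['CD3D', 'CD3E', 'CD79A', 'MS4A1']
```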
- Output directory structure:
  - root_dir/
    - data/processed/
      - input_h5ad/ - Preprocessed h5ad files
      - tokenized/ - Tokenized datasets
      - input_dataset/ (or {output_data_name}/) - Final encoded dataset
    - gpdb_{name_tag}.csv - Gene program database (must exist)
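The layout above can be expressed as a small path helper that also applies two documented defaults: tissue falls back to the last component of root_dir, and the final dataset folder is input_dataset unless output_data_name is given. The helper names default_tissue and expected_paths are hypothetical, and the sketch only computes paths; it does not create or check them.

```python
from pathlib import Path
from typing import Dict, Optional


def default_tissue(root_dir: str, tissue: Optional[str] = None) -> str:
    """Tissue tag for output file names: explicit value, else last component of root_dir."""
    return tissue if tissue is not None else Path(root_dir).name


def expected_paths(
    root_dir: str, name_tag: str = "Reactome", output_data_name: Optional[str] = None
) -> Dict[str, Path]:
    """Folder locations in the documented layout (file names inside them may differ)."""
    root = Path(root_dir)
    processed = root / "data" / "processed"
    return {
        "input_h5ad": processed / "input_h5ad",
        "tokenized": processed / "tokenized",
        "final_dataset": processed / (output_data_name or "input_dataset"),
        "gpdb": root / f"gpdb_{name_tag}.csv",  # must already exist before running
    }
```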
- tripso.train(dataset_path: str, gpdb_path: str, output_dir: str, batch_size: int = 32, mgm: float = 0.15, tissue: str | None = None, n_heads: int = 8, n_blocks: int = 1, lr_scheduler: Literal['CosineLRwithWarmUp', 'ReduceLROnPlateau'] = 'ReduceLROnPlateau', n_epochs: int = 20, gene_format: Literal['symbol', 'ensembl'] = 'symbol', model_type: str = 'Base', strategy: str = 'ddp_find_unused_parameters_true', attn_dropout: float = 0.0, lr: float = 0.001, resume_training: bool | None = False, gp_inputs: list | None = None, frac_for_training: float | None = 1.0, global_loss: str = 'supervised', classification_labels: list | None = None, global_attn_heads: int | None = 8, supervised_labels: dict | None = None, global_masking_rate: float | None = 0.15, global_attn_dropout: float | None = 0.0, global_training: str = 'simultaneous', path_to_base_model: str | None = None, learn_new_gp: bool | None = False, gp_to_learn: list = ['novel_gp'], global_n_blocks: int = 1, reconstruction_loss: str | None = 'nb', adata_path: str | None = None, use_flash: bool | None = False, weight_decay: float = 0.0, sampler: str | None = None, sample_by: str | None = None, fm_encoder_name: str = 'gf-6L-30M-i2048', fm_encoder_pkg: str = 'geneformer', peft_config_path: str | None = None, seed: int | None = 0, data_seed: int | None = None, supervised_rem_var: str | None = None, num_nodes: int = 1, prbm_path: str | None = None, use_l2_norm: bool | None = False, gp_latent_size: int | None = None, all_genes: list | None = None, init_sparsity: float | None = 0.0, limit_train_batches: float | None = 1.0, limit_val_batches: float | None = 1.0, val_check_interval: float | None = 1.0, use_pos_emb: str | None = 'sin_cos', global_pos_emb: str | None = 'sin_cos', vocab_gene_names: list | None = None, precision=32, bert_config: Dict = {}, use_gene_embeddings: bool | None = False, calc_gp_loss: bool | None = True, calc_gene_loss: bool | None = True, lora_config_args: dict | None = None, warmup: int | None = 0, 
accumulate_grad_batches: int | None = 1)
Wrapper function for training a Tripso model
- Parameters:
dataset_path (str) – Path to input tokenized dataset
gpdb_path (str) – Path to input GP database: a CSV file (e.g., written with pandas) in which each column is a GP, with GP names as column names
output_dir (str) – Directory where checkpoints and results will be saved
batch_size (int, default=32) – Batch size for training
mgm (float, default=0.15) – Masking ratio for masked gene modeling, i.e., the proportion of genes masked during training
tissue (Optional[str], default=None) – Tissue name for logging experiment in wandb
n_heads (int, default=8) – Number of heads for multi-head attention in GP encoder
n_blocks (int, default=1) – Number of transformer blocks in GP encoder
lr_scheduler (Literal['CosineLRwithWarmUp', 'ReduceLROnPlateau'], default='ReduceLROnPlateau') – Learning rate scheduler for the optimizer
n_epochs (int, default=20) – Number of epochs to train for
gene_format (Literal['symbol', 'ensembl'], default='symbol') – Format in which gene names are stored in GPDB
model_type (str, default='Base') – One of ‘Base’, ‘Global’, ‘Global_LoRA’, or ‘Mean’. ‘Base’ trains only the GP encoder. ‘Global’ adds a cell-level transformer. ‘Global_LoRA’ uses LoRA for parameter-efficient training. ‘Mean’ uses gene program mean embeddings.
strategy (str, default='ddp_find_unused_parameters_true') – Strategy for multi-GPU PyTorch Lightning trainer
attn_dropout (float, default=0.0) – Dropout rate for attention layers
lr (float, default=1e-3) – Learning rate for optimizer
resume_training (Optional[bool], default=False) – Set to True to resume training from checkpoint
gp_inputs (Optional[list], default=None) – List of GP names from GPDB to include in the model. If None, uses all GPs
frac_for_training (Optional[float], default=1.0) – Fraction of the dataset to use for training (for development/testing)
global_loss (str, default='supervised') – Loss function for global model: ‘supervised’, ‘masking’, or ‘reconstruction’
classification_labels (Optional[list], default=None) – List of labels for supervised classification (deprecated, use supervised_labels)
global_attn_heads (Optional[int], default=8) – Number of attention heads for learning cell token in global model
supervised_labels (Optional[dict], default=None) – Dict mapping label names to number of classes for supervised classification
global_masking_rate (Optional[float], default=0.15) – Masking rate for global model when using masking loss
global_attn_dropout (Optional[float], default=0.0) – Dropout rate for attention layers in global model
global_training (str, default='simultaneous') – Training mode: ‘simultaneous’, ‘sequential’, ‘finetune’, ‘finetune_global’, or ‘finetune_gene_encoder’. Controls how base and global models are trained
path_to_base_model (Optional[str], default=None) – Path to pre-trained model checkpoint for sequential/finetuning training
learn_new_gp (Optional[bool], default=False) – If True, load pretrained model, freeze most parameters, and learn new GP
gp_to_learn (list, default=['novel_gp']) – List of GP names to learn when learn_new_gp is True
global_n_blocks (int, default=1) – Number of transformer blocks in global model
reconstruction_loss (Optional[str], default='nb') – Loss function for reconstruction: ‘nb’ (negative binomial) or ‘mse’
adata_path (Optional[str], default=None) – Path to an AnnData object with gene expression; required for the reconstruction loss
use_flash (Optional[bool], default=False) – Whether to use flash attention in transformer blocks
weight_decay (float, default=0.0) – Weight decay for optimizer
sampler (Optional[str], default=None) – Sampling strategy for data loading. Options: ‘weighted’ for WeightedRandomSampler, ‘length’ for LengthGroupedSampler, or None.
sample_by (Optional[str], default=None) – Column name in AnnData to sample by (used with sampler)
fm_encoder_name (str, default='gf-6L-30M-i2048') – Name of foundation model encoder to use
fm_encoder_pkg (str, default='geneformer') – Package for foundation model encoder: ‘geneformer’ or ‘from_scratch’
peft_config_path (Optional[str], default=None) – Path to PEFT (Parameter-Efficient Fine-Tuning) configuration file
seed (Optional[int], default=0) – Random seed for reproducibility
data_seed (Optional[int], default=None) – Random seed for data loading. If None, uses the value of seed
supervised_rem_var (Optional[str], default=None) – Variable to remove from supervised labels (currently unused)
num_nodes (int, default=1) – Number of nodes for distributed training
prbm_path (Optional[str], default=None) – Path to PRBM model (currently unused)
use_l2_norm (Optional[bool], default=False) – Whether to use L2 normalization in model
gp_latent_size (Optional[int], default=None) – Size of GP latent representation. If None, uses default from model
all_genes (Optional[list], default=None) – List of all genes to consider. If provided, masks GP genes in gene encoder
init_sparsity (Optional[float], default=0.0) – Initial sparsity level for sparse models
limit_train_batches (Optional[float], default=1.0) – Fraction or number of training batches to use per epoch
limit_val_batches (Optional[float], default=1.0) – Fraction or number of validation batches to use
val_check_interval (Optional[float], default=1.0) – How often to check validation set. Float for fraction of epoch, int for number of batches
use_pos_emb (Optional[str], default='sin_cos') – Type of positional embedding for gene encoder
global_pos_emb (Optional[str], default='sin_cos') – Type of positional embedding for global model
vocab_gene_names (Optional[list], default=None) – List of gene names in vocabulary for one-hot encoding
precision (int or str, default=32) – Training precision: 32, 16, or ‘bf16-mixed’
bert_config (Dict, default={}) – Configuration dict for BERT model when training from scratch
use_gene_embeddings (bool or str, default=False) – Model name (e.g., ‘gf-12L-95M-i4096’) or path to a gene embeddings file used to initialize gene embeddings, or False to initialize them randomly
calc_gp_loss (Optional[bool], default=True) – Whether to calculate GP prediction loss
calc_gene_loss (Optional[bool], default=True) – Whether to calculate gene-level loss
lora_config_args (Optional[dict], default=None) – Configuration arguments for LoRA when using Global_LoRA model
warmup (Optional[int], default=0) – Number of warmup steps for learning rate scheduler
accumulate_grad_batches (Optional[int], default=1) – Number of batches to accumulate gradients over before updating weights
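Two of the training parameters lend themselves to a quick sketch. mgm sets the fraction of gene tokens masked for masked gene modeling, and under data-parallel training (such as the default DDP strategy) the number of samples contributing to each optimizer step is batch_size × accumulate_grad_batches × devices per node × num_nodes. The code below is a minimal illustration, not tripso's masking routine: it assumes masking picks round(mgm × seq_len) positions uniformly at random, and the function names and the devices_per_node argument are hypothetical (train itself takes no devices parameter).

```python
import random
from typing import List


def mgm_mask_positions(seq_len: int, mgm: float = 0.15, seed: int = 0) -> List[int]:
    """Pick round(mgm * seq_len) distinct token positions to mask, uniformly at random."""
    if not 0.0 <= mgm <= 1.0:
        raise ValueError("mgm must be a fraction between 0 and 1")
    n_mask = round(mgm * seq_len)
    return sorted(random.Random(seed).sample(range(seq_len), n_mask))


def effective_batch_size(
    batch_size: int, accumulate_grad_batches: int = 1, devices_per_node: int = 1, num_nodes: int = 1
) -> int:
    """Samples per optimizer step when each data-parallel process loads batch_size samples."""
    return batch_size * accumulate_grad_batches * devices_per_node * num_nodes
```

With the defaults batch_size=32 and accumulate_grad_batches=1 on a single GPU, each optimizer step sees 32 samples; setting accumulate_grad_batches=4 raises this to 128 without increasing per-step memory.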