In-Depth Configuration
This page gives an overview of the various settings you can use to customize a training run.
We use Draccus for configuration. Draccus is yet another YAML-to-dataclass library that uses dataclasses to generate YAML and argparse to parse command-line arguments.
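Typically, your config dataclass will look something like this:
```python
from dataclasses import dataclass, field
from levanter.data.text import LMDatasetConfig
from levanter.models.gpt2 import Gpt2Config
from levanter.models.lm_model import LmConfig
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import TrainerConfig

@dataclass
class TrainLmConfig:
    data: LMDatasetConfig = field(default_factory=LMDatasetConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    model: LmConfig = field(default_factory=Gpt2Config)
    optimizer: OptimizerConfig = field(default_factory=AdamConfig)
```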
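Conceptually, loading a config means turning a nested mapping (such as the parsed YAML file shown next) into these nested dataclasses, where each key present overrides the matching field and missing keys keep their defaults. The stdlib-only sketch below illustrates that idea. It is not Draccus's implementation, and `from_dict`, `DataSketch`, `TrainerSketch`, and `RunSketch` are hypothetical names.

```python
import dataclasses
import typing
from dataclasses import dataclass, field


@dataclass
class DataSketch:  # hypothetical stand-in for a data config
    cache_dir: str = "cache/"
    train_urls: list = field(default_factory=list)


@dataclass
class TrainerSketch:  # hypothetical stand-in for TrainerConfig
    num_train_steps: int = 400_000
    train_batch_size: int = 512


@dataclass
class RunSketch:  # hypothetical stand-in for TrainLmConfig
    data: DataSketch = field(default_factory=DataSketch)
    trainer: TrainerSketch = field(default_factory=TrainerSketch)


def from_dict(cls, raw: dict):
    """Recursively build dataclass `cls` from a nested dict; missing keys keep their defaults."""
    names = {f.name for f in dataclasses.fields(cls)}
    unknown = set(raw) - names
    if unknown:
        raise ValueError(f"unknown keys for {cls.__name__}: {sorted(unknown)}")
    hints = typing.get_type_hints(cls)  # resolves annotations to real types
    kwargs = {}
    for name, value in raw.items():
        field_type = hints[name]
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = from_dict(field_type, value)  # recurse into nested sections
        kwargs[name] = value
    return cls(**kwargs)


cfg = from_dict(RunSketch, {"trainer": {"num_train_steps": 1000}})
print(cfg.trainer)  # TrainerSketch(num_train_steps=1000, train_batch_size=512)
```

Rejecting unknown keys is a choice made in this sketch; it surfaces typos such as `num_train_step` immediately instead of silently ignoring them.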
Your training run will typically be associated with a single config file. For instance, you might have a file
my-run.yaml that looks like this:
```yaml
data:
  train_urls:
    - "gs://my_bucket/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://my_bucket/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://my_bucket/tokenized/openwebtext_2/"
model:
  type: gpt2
  hidden_dim: 768
  num_heads: 12
  num_layers: 12
  seq_len: 1024
  gradient_checkpointing: true
  scale_attn_by_inverse_layer_idx: true
trainer:
  tracker:
    type: wandb
    project: "levanter"
    tags: ["openwebtext", "gpt2"]
  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: 4
  train_batch_size: 512
optimizer:
  learning_rate: 6E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
```
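The `mp` value above is a jmp mixed-precision policy string: `p` sets the parameter dtype and `c` the compute dtype. The stdlib-only sketch below mirrors our understanding of jmp's loose grammar (unspecified params default to float32, compute follows params, output follows compute). `parse_policy`, `PolicySketch`, and the alias table are illustrative assumptions, not jmp's code; in real code, call `jmp.get_policy`.

```python
from typing import NamedTuple

# Illustrative alias table (an assumption, not jmp's exact list of accepted names).
_DTYPE_ALIASES = {
    "f32": "float32", "float32": "float32",
    "bf16": "bfloat16", "bfloat16": "bfloat16",
    "f16": "float16", "float16": "float16",
}


class PolicySketch(NamedTuple):
    param_dtype: str
    compute_dtype: str
    output_dtype: str


def parse_dtype(name: str) -> str:
    try:
        return _DTYPE_ALIASES[name.strip()]
    except KeyError:
        raise ValueError(f"unknown dtype: {name!r}") from None


def parse_policy(spec: str) -> PolicySketch:
    """Parse strings like 'p=f32,c=bfloat16', or a bare dtype like 'f32'."""
    if "=" not in spec:
        dtype = parse_dtype(spec)  # a single dtype applies to all three roles
        return PolicySketch(dtype, dtype, dtype)
    params, compute, output = "float32", None, None
    for part in spec.split(","):
        key, value = part.split("=", 1)
        key = key.strip()
        if key in ("p", "params"):
            params = parse_dtype(value)
        elif key in ("c", "compute"):
            compute = parse_dtype(value)
        elif key in ("o", "output"):
            output = parse_dtype(value)
        else:
            raise ValueError(f"unknown policy key: {key!r}")
    compute = compute or params  # compute follows params when unspecified
    output = output or compute   # output follows compute when unspecified
    return PolicySketch(params, compute, output)


print(parse_policy("p=f32,c=bfloat16"))
# PolicySketch(param_dtype='float32', compute_dtype='bfloat16', output_dtype='bfloat16')
```

Keeping parameters in float32 while computing in bfloat16 is the usual motivation for this policy: the master weights stay precise, while the matmuls run at the accelerator's faster precision.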
Including Other Config Files
Draccus supports inclusion of config files via the special `!include` syntax. For instance, this:
```yaml
# my-run.yaml
data: !include data.yaml
trainer:
  num_train_steps: 1000000
```
```yaml
# data.yaml
train_urls:
  - "gs://my_bucket/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
validation_urls:
  - "gs://my_bucket/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
cache_dir: "gs://my_bucket/tokenized/openwebtext_2/"
```
will expand to:
```yaml
data:
  train_urls:
    - "gs://my_bucket/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://my_bucket/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://my_bucket/tokenized/openwebtext_2/"
trainer:
  num_train_steps: 1000000
```
The inclusion path is always relative to the config file. Unfortunately, we don't (can't) support inclusion at the top level.
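To make the expansion rule concrete, here is a stdlib-only sketch that works on already-parsed documents. It assumes an include is represented by a hypothetical `Include(path)` marker and that a `load` callable returns the parsed contents of a file; neither is Draccus's API. Paths are resolved against the directory of the file containing the include, and a top-level include is rejected, matching the rules above.

```python
from dataclasses import dataclass
from pathlib import PurePosixPath
from typing import Any, Callable


@dataclass(frozen=True)
class Include:
    """Hypothetical marker standing in for a `!include some/file.yaml` node."""
    path: str


def expand(node: Any, current_file: str, load: Callable[[str], Any]) -> Any:
    """Recursively replace Include markers with the contents of the referenced file."""
    if isinstance(node, Include):
        # Resolve relative to the including file's directory, not the working directory.
        target = str(PurePosixPath(current_file).parent / node.path)
        return expand(load(target), target, load)
    if isinstance(node, dict):
        return {k: expand(v, current_file, load) for k, v in node.items()}
    if isinstance(node, list):
        return [expand(v, current_file, load) for v in node]
    return node


def load_config(path: str, load: Callable[[str], Any]) -> dict:
    root = load(path)
    if isinstance(root, Include):
        raise ValueError("!include is not supported at the top level")
    return expand(root, path, load)


# A fake file system standing in for YAML files on disk.
files = {
    "configs/my-run.yaml": {"data": Include("data.yaml"), "trainer": {"num_train_steps": 1000000}},
    "configs/data.yaml": {"cache_dir": "gs://my_bucket/tokenized/openwebtext_2/"},
}
print(load_config("configs/my-run.yaml", files.__getitem__))
# {'data': {'cache_dir': 'gs://my_bucket/tokenized/openwebtext_2/'}, 'trainer': {'num_train_steps': 1000000}}
```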
Trainer and TrainerConfig
The levanter.trainer.Trainer class is governed by the levanter.trainer.TrainerConfig dataclass.
Trainer has many options. The following sections highlight some of them, with tables of the parameters you are most likely to want to change.
Core Training Loop Configuration
| Parameter | Description | Default |
|---|---|---|
| seed | The random seed | 0 |
| num_train_steps | The number of training steps to run | 400,000 |
| train_batch_size | The batch size | 512 |
| per_device_parallelism | Number of examples to process on each device during training | -1 (meaning train_batch_size / num_devices) |
| per_device_eval_parallelism | Number of examples to process on each device during eval | -1 (meaning the same as per_device_parallelism) |
| steps_per_eval | How often (in steps) to evaluate the model during training | 1,000 |
| max_eval_batches | How many batches to evaluate during each evaluation | None (meaning all) |
| mp | Mixed precision policy using jmp | f32 (full precision) |
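How train_batch_size and per_device_parallelism interact is easiest to see with some arithmetic. The sketch below assumes that one microbatch is per_device_parallelism examples on each replica of the data axis (num_devices / model_axis_size of them), and that gradients are accumulated over as many microbatches as it takes to reach train_batch_size. `accumulation_plan` is a hypothetical helper, not Levanter's code.

```python
def accumulation_plan(train_batch_size: int, num_devices: int,
                      per_device_parallelism: int = -1, model_axis_size: int = 1):
    """Return (microbatch_size, accumulation_steps) under the assumptions stated above."""
    if num_devices % model_axis_size != 0:
        raise ValueError("num_devices must be divisible by model_axis_size")
    data_axis_size = num_devices // model_axis_size
    if per_device_parallelism == -1:
        # -1: process the whole batch in a single pass (no accumulation)
        per_device_parallelism = train_batch_size // data_axis_size
    microbatch_size = per_device_parallelism * data_axis_size
    if train_batch_size % microbatch_size != 0:
        raise ValueError("train_batch_size must be a multiple of the microbatch size")
    return microbatch_size, train_batch_size // microbatch_size


# The my-run.yaml example above on a hypothetical 64-device slice:
print(accumulation_plan(512, 64, per_device_parallelism=4))  # (256, 2)
```

Under these assumptions, each optimizer step in that example takes two forward/backward passes; lowering per_device_parallelism trades speed for memory.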
Logging and Reporting
| Parameter | Description | Default |
|---|---|---|
| log_dir | Where to save logs (python logger). $run_id will be appended | logs/ |
Partitioning / FSDP
Sharding in Levanter is done with axis mappings, which specify how to map logical axes (e.g. "batch") to physical axes in the JAX device mesh.
(See the Haliax Scaling Tutorial
for a more detailed explanation of axis mappings.) Levanter's Trainer uses two axis mappings: parameter_axis_resources and compute_axis_resources.
parameter_axis_resources specifies how to shard the model parameters and optimizer state: basically how the model is sharded "at rest",
while the compute_axis_resources specifies how to shard the model during computation.
TrainerConfig allows you to specify these axis mappings in two ways, with a "basic" mode that has
reasonable defaults and an "advanced" mode that gives you more control.
Basic Mode
| Parameter | Description | Default |
|---|---|---|
| batch_axis | The axis to shard the batch over, for distributed data parallelism | "batch" |
| fsdp_axis | The axis or axes to shard the model over, for Fully Sharded Data Parallelism | "embed" |
| tensor_parallel_axes | The axis or axes to shard the model over, for Tensor Parallelism | None |
| model_axis_size | How many devices for tensor parallelism | 1 |
Advanced Mode
| Parameter | Description | Default |
|---|---|---|
| axis_resources | Mapping from logical axis to physical axis shared by both mappings | -- |
| parameter_axis_resources | Mapping from logical axis to physical axis for the parameter mapping | -- |
| compute_axis_resources | Mapping from logical axis to physical axis for the compute mapping | -- |
| model_axis_size | How many devices for tensor parallelism | 1 |
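A rough sketch of how the basic-mode options could translate into the two mappings follows. It assumes physical mesh axes named "data", "replica", and "model", and that parameter-only overrides and the FSDP axis are layered on top of the compute mapping. `build_axis_mappings` is hypothetical and is not Levanter's implementation of TrainerConfig.compute_axis_mapping or parameter_axis_mapping.

```python
from typing import Dict, List, Optional, Tuple, Union

Resource = Union[str, Tuple[str, ...]]


def build_axis_mappings(
    batch_axis: str = "batch",
    fsdp_axis: Optional[Union[str, List[str]]] = "embed",
    tensor_parallel_axes: Optional[List[str]] = None,
    axis_resources: Optional[Dict[str, Resource]] = None,
    parameter_axis_resources: Optional[Dict[str, Resource]] = None,
) -> Tuple[Dict[str, Resource], Dict[str, Resource]]:
    """Return (compute_mapping, parameter_mapping) from logical to physical axes."""
    compute: Dict[str, Resource] = dict(axis_resources or {})  # advanced mode: shared by both
    for axis in tensor_parallel_axes or []:
        compute[axis] = "model"  # tensor parallelism splits these axes across the model axis
    compute[batch_axis] = ("replica", "data")  # data parallelism over the batch

    params = dict(compute)
    params.update(parameter_axis_resources or {})  # advanced mode: parameter-only overrides
    fsdp_axes = [fsdp_axis] if isinstance(fsdp_axis, str) else (fsdp_axis or [])
    for axis in fsdp_axes:
        params[axis] = "data"  # FSDP: shard parameters/optimizer state "at rest"
    return compute, params


compute, params = build_axis_mappings(tensor_parallel_axes=["mlp", "heads"])
print(compute)  # {'mlp': 'model', 'heads': 'model', 'batch': ('replica', 'data')}
print(params["embed"])  # data
```

Note how "embed" is sharded only in the parameter mapping: at rest the weights are split across the data axis, and during computation they are gathered as needed.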
Checkpointing and Initialization
See also Checkpointer.
| Parameter | Description | Default |
|---|---|---|
| load_checkpoint | Whether to load checkpoint from base_path | None: load if possible, but don't error. |
| initialize_from | Initialize training state from this path. May be a parent dir. Useful for continued training. | None |
| checkpointer.base_path | Base path to save checkpoints to | checkpoints/${run_id} |
| checkpointer.save_interval | How often to save checkpoints (time) | 15 minutes |
| checkpointer.keep | How often to keep checkpoints (steps). See below. | 10000 steps |
Checkpointer Save Policy
The checkpointer logic has two kinds of checkpoints:
- time-based checkpoints: temporary checkpoints that are saved every save_interval (15 minutes by default). The previous time-based checkpoint is deleted when a new one is saved.
- step-based checkpoints: permanent checkpoints that are saved according to a policy. These checkpoints are never deleted.
Step-based checkpoint configuration looks like this:
```yaml
checkpointer:
  keep:
    - every: 1000  # steps
      until: 10000  # step
    - every: 5000  # steps
      until: 40000  # step
    - every: 10000
```
This policy will save permanent checkpoints every 1,000 steps until 10,000 steps, then every 5,000 steps until 40,000 steps, then every 10,000 steps. The default step-based checkpoint policy is to save a checkpoint every 10,000 steps.
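One way to read such a policy is to walk the list in order, use the first entry whose until has not been passed (an entry without until applies forever), and keep the step if it is a multiple of that entry's every. The sketch below implements that reading. `should_keep` is a hypothetical helper, and Levanter's Checkpointer may differ in details such as boundary handling.

```python
from typing import List


def should_keep(step: int, keep: List[dict]) -> bool:
    """Return True if `step` gets a permanent (step-based) checkpoint under `keep`."""
    for rule in keep:
        until = rule.get("until")
        if until is None or step <= until:
            return step % rule["every"] == 0  # the first rule that still applies decides
    return False  # every rule has expired


keep = [
    {"every": 1000, "until": 10000},
    {"every": 5000, "until": 40000},
    {"every": 10000},
]
kept = [step for step in range(1, 60_001) if should_keep(step, keep)]
print(kept[9:])  # [10000, 15000, 20000, 25000, 30000, 35000, 40000, 50000, 60000]
```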
JAX Compilation Cache Configuration
Levanter allows you to configure JAX's persistent compilation cache. This can significantly speed up startup times by caching compiled JAX functions.
The primary way to specify the cache directory is via the jax_compilation_cache_dir field in the TrainerConfig.
| Parameter | Description | Type | Default |
|---|---|---|---|
| jax_compilation_cache_dir | Path to a directory to store the persistent compilation cache. Can be a local path or a GCS path. | Optional[str] | None (JAX default, usually ~/.cache/jax or platform specific) |
Other JAX compilation cache settings (like jax_persistent_cache_min_compile_time_secs, jax_persistent_cache_min_entry_size_bytes, jax_persistent_cache_enable_xla_caches, etc.)
can be configured by including them in the trainer.jax_config dictionary. This dictionary allows you to pass arbitrary JAX configuration options.
For more details on all available JAX compilation cache options and how JAX's compilation cache works, please refer to the official JAX documentation.
Here's an example of how to configure these options in your YAML file:
```yaml
trainer:
  # ... other trainer configs
  jax_compilation_cache_dir: "/path/to/your/jax_cache"  # Or "gs://your-bucket/jax_cache"
  # To set other JAX compilation cache options or any other JAX global flag:
  jax_config:
    jax_persistent_cache_min_compile_time_secs: 5.0
    jax_persistent_cache_min_entry_size_bytes: 1024
    jax_persistent_cache_enable_xla_caches: "all"  # or "xla_gpu_kernel_cache_file", etc.
    # ... other jax settings like jax_threefry_partitionable
```
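These settings are ordinary JAX configuration flags. The sketch below shows one plausible way a trainer could apply them, by passing each one to `jax.config.update`. `apply_jax_settings` is hypothetical, not Levanter's code, and its `update` parameter exists only so the sketch can be exercised without JAX installed.

```python
from typing import Any, Callable, Mapping, Optional


def apply_jax_settings(
    cache_dir: Optional[str],
    jax_config: Mapping[str, Any],
    update: Optional[Callable[[str, Any], None]] = None,
) -> None:
    """Apply the compilation cache dir and arbitrary JAX flags via jax.config.update."""
    if update is None:
        import jax  # imported lazily so the sketch can run without JAX

        update = jax.config.update
    if cache_dir is not None:
        update("jax_compilation_cache_dir", cache_dir)
    for name, value in jax_config.items():
        update(name, value)


# Record the calls instead of touching a real JAX installation.
calls = []
apply_jax_settings(
    "gs://your-bucket/jax_cache",
    {"jax_persistent_cache_min_compile_time_secs": 5.0},
    update=lambda name, value: calls.append((name, value)),
)
print(calls)
# [('jax_compilation_cache_dir', 'gs://your-bucket/jax_cache'), ('jax_persistent_cache_min_compile_time_secs', 5.0)]
```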
Alternatively, JAX's compilation cache directory can be set using the JAX_COMPILATION_CACHE_DIR environment variable.
This method is particularly useful for workflows involving launch.py on TPUs, as environment variables can be specified in the .levanter.yaml configuration file used by launch.py.
For more details on using launch.py, see the Using launch.py section in the TPU VM guide.
Example .levanter.yaml snippet:
```yaml
env:
  JAX_COMPILATION_CACHE_DIR: "gs://your-compile-cache-bucket/path"
  # ... other environment variables
```
Trackers and Logging
We mostly use W&B for tracking values and other metadata about a run. However, we also support TensorBoard and a few other trackers. You can also use multiple trackers at once, or even write your own. See Trackers for more information.
W&B
Wandb is the default tracker and is installed by default. To use it, you can configure it in your config file:
```yaml
trainer:
  tracker:
    type: wandb
    project: my-project
    entity: my-entity
```
Because wandb is the default, you can also just do:
```yaml
trainer:
  tracker:
    project: my-project
    entity: my-entity
```
| Parameter | Description | Default |
|---|---|---|
| entity | The wandb entity to use. | your default entity |
| project | The wandb project to use. | wandb's default |
| tags | Tags to add to the run. | [] |
| id | Unique run id | wandb's autogenerated id |
| name | The name of the run. | wandb's autogenerated name |
| save_code | Whether to save the code to wandb. | True |
| save_xla_dumps | Whether to save XLA compiler outputs to wandb. | False |
Notes:
- W&B's code-saving logic isn't very good for our use case, so we have our own: we automatically sniff out the git repo of your main script.
- save_xla_dumps is useful for debugging XLA compilation issues. It tends to dump a lot of stuff, so we don't save it by default. To use it, you must also set the right environment variables, something like `XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*"`. We automatically parse out the env variable.
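To illustrate what parsing the variable involves, this hypothetical helper extracts the dump directory from XLA_FLAGS using only the standard library. It is a sketch, not the tracker's actual code.

```python
import os
import shlex
from typing import Mapping, Optional


def xla_dump_dir(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the --xla_dump_to directory from XLA_FLAGS, or None if it is not set."""
    for flag in shlex.split(env.get("XLA_FLAGS", "")):
        if flag.startswith("--xla_dump_to="):
            return flag.split("=", 1)[1]
    return None


example_env = {"XLA_FLAGS": "--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*"}
print(xla_dump_dir(example_env))  # /tmp/output_folder/xla_dumps
```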
TensorBoard
TensorBoard is also supported. To use it, you can configure it in your config file:
```yaml
trainer:
  tracker:
    type: tensorboard
    logdir: logs
```
Install the optional dependencies for TensorBoard support with one of:
- `pip install "levanter[profiling]"`
- `uv sync --extra profiling`
Viewing profiles: when profiling is enabled, JAX writes traces under `<logdir>/plugins/profile/<timestamp>`.
Launch the UI with `tensorboard --logdir <logdir>` and open http://localhost:6006/#profile.
If running remotely, forward the port: `ssh -L 6006:localhost:6006 <host>`.
Multiple Trackers
In some cases, you may want to use multiple trackers at once. For example, you may want to use both W&B and Tensorboard.
To do this, you can use the levanter.tracker.tracker.CompositeTracker class, or, if using a config file, you can specify multiple trackers:
```yaml
trainer:
  tracker:
    - type: wandb
      project: my-project
      entity: my-entity
    - type: tensorboard
      logdir: logs
```
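The idea behind combining trackers is simple fan-out: every logging call is forwarded to each underlying tracker. The sketch below shows that pattern with hypothetical `RecordingTracker` and `FanOutTracker` classes; it does not reproduce the interface of levanter.tracker.tracker.CompositeTracker.

```python
from typing import Any, Dict, List, Optional, Protocol


class TrackerLike(Protocol):
    def log(self, metrics: Dict[str, Any], *, step: Optional[int]) -> None: ...


class RecordingTracker:
    """Hypothetical tracker that just remembers what it was asked to log."""

    def __init__(self, name: str):
        self.name = name
        self.records: List[tuple] = []

    def log(self, metrics: Dict[str, Any], *, step: Optional[int]) -> None:
        self.records.append((step, dict(metrics)))


class FanOutTracker:
    """Forwards every call to all wrapped trackers, in order."""

    def __init__(self, trackers: List[TrackerLike]):
        self.trackers = list(trackers)

    def log(self, metrics: Dict[str, Any], *, step: Optional[int]) -> None:
        for tracker in self.trackers:
            tracker.log(metrics, step=step)


wandb_like, tensorboard_like = RecordingTracker("wandb"), RecordingTracker("tensorboard")
combined = FanOutTracker([wandb_like, tensorboard_like])
combined.log({"train/loss": 2.5}, step=10)
print(wandb_like.records, tensorboard_like.records)
# [(10, {'train/loss': 2.5})] [(10, {'train/loss': 2.5})]
```

A fuller implementation would forward the other tracker operations (hyperparameters, artifacts, finishing the run) in the same way.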
Ray Config
Levanter will by default automatically start a Ray cluster with all
the machines being used for training. This is useful for distributed
preprocessing. You can disable this behavior by setting ray.auto_start_cluster: false in the trainer config.
| Parameter | Description | Default |
|---|---|---|
| address | The address of the Ray cluster to connect to. | None |
| start_workers | Whether to start Ray workers. If False, you must start them yourself. | True |
| auto_start_cluster | Whether to start a Ray cluster automatically. | True |
Distributed Config
JAX can automatically sniff out clusters in SLURM and TPU environments. If you're not using SLURM or TPUs, you can specify the cluster manually using this config.
Don't use this on TPU, and possibly not on SLURM either.
| Parameter | Description | Default |
|---|---|---|
| coordinator_address | The address of the coordinator. If None, we'll use the default address. | None |
| num_processes | The number of processes in the cluster. | None |
| process_id | The process id of this process. | None |
| local_device_ids | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} |
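These fields line up with the arguments of `jax.distributed.initialize`. The sketch below shows how the keyword arguments could be assembled, assuming (as the table suggests) that an unset local_device_ids falls back to the comma-separated, integer-valued CUDA_VISIBLE_DEVICES list. `parse_visible_devices` and `distributed_kwargs` are hypothetical helpers; passing the result to `jax.distributed.initialize(**kwargs)` is left to the caller.

```python
import os
from typing import Dict, List, Mapping, Optional


def parse_visible_devices(value: Optional[str]) -> Optional[List[int]]:
    """Turn '0,1,2,3' into [0, 1, 2, 3]; an unset or empty value yields None."""
    if not value:
        return None
    # Assumes integer indices; GPU UUIDs in CUDA_VISIBLE_DEVICES are not handled here.
    return [int(part) for part in value.split(",") if part.strip()]


def distributed_kwargs(
    coordinator_address: Optional[str] = None,
    num_processes: Optional[int] = None,
    process_id: Optional[int] = None,
    local_device_ids: Optional[List[int]] = None,
    env: Mapping[str, str] = os.environ,
) -> Dict[str, object]:
    """Build keyword arguments for jax.distributed.initialize (a sketch)."""
    if local_device_ids is None:
        local_device_ids = parse_visible_devices(env.get("CUDA_VISIBLE_DEVICES"))
    return {
        "coordinator_address": coordinator_address,
        "num_processes": num_processes,
        "process_id": process_id,
        "local_device_ids": local_device_ids,
    }


kwargs = distributed_kwargs("10.0.0.1:1234", num_processes=2, process_id=0,
                            env={"CUDA_VISIBLE_DEVICES": "0,1,2,3"})
print(kwargs["local_device_ids"])  # [0, 1, 2, 3]
```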
Optimizer
Standard Options
All optimizers in Levanter are based on the levanter.optim.OptimizerConfig dataclass. This class has the following fields, which are common to all optimizers (and most have to do with learning rate scheduling):
| Parameter | Description | Default |
|---|---|---|
| weight_decay | The weight decay. | 0.0 |
| learning_rate | The learning rate. | 1e-4 |
| lr_schedule | The type of learning rate schedule for decay. See below. | cosine |
| min_lr_ratio | The minimum learning rate ratio. | 0.1 |
| warmup | Warmup fraction or number of steps | 0.01 |
| decay | Decay fraction or number of steps | None |
| rewarmup | The learning rate re-warmup, if using cycles. | 0.0 |
| cycles | The number of cycles for the learning rate, or steps where cycles end | None |
| cycle_length | How long the cycles should be (as an int or fraction), or a list of cycle lengths | None |
By default, Levanter uses a cosine learning rate decay with warmup. The learning rate is decayed to
min_lr_ratio * learning_rate over the course of the training run. This is a fairly standard default for LLM training.
Learning Rate Schedules
The lr_schedule parameter specifies the learning rate schedule. The following schedules are supported:
- constant: Constant learning rate.
- linear: Linear decay.
- cosine: Cosine decay.
- inv_sqrt: Inverse square root decay.
- inv: Inverse decay.
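To make the shapes concrete, here is a sketch of the decay phase for three of these schedules, going from the peak learning_rate down to min_lr_ratio * learning_rate. The constant, linear, and cosine formulas are the standard ones; inv_sqrt and inv are omitted because their exact parameterization is not described here. `decayed_lr` is a hypothetical helper, not Levanter's code.

```python
import math


def decayed_lr(kind: str, progress: float, learning_rate: float, min_lr_ratio: float) -> float:
    """LR during the decay phase; `progress` runs from 0.0 (start) to 1.0 (end)."""
    progress = min(max(progress, 0.0), 1.0)
    if kind == "constant":
        shape = 1.0  # no decay at all
    elif kind == "linear":
        shape = 1.0 - progress
    elif kind == "cosine":
        shape = 0.5 * (1.0 + math.cos(math.pi * progress))
    else:
        raise ValueError(f"schedule not covered by this sketch: {kind!r}")
    min_lr = min_lr_ratio * learning_rate
    return min_lr + (learning_rate - min_lr) * shape


for kind in ("constant", "linear", "cosine"):
    print(kind, [round(decayed_lr(kind, p, 6e-4, 0.1), 6) for p in (0.0, 0.5, 1.0)])
```

Linear and cosine meet at the midpoint and at both ends; cosine stays near the peak longer early on and flattens out as it approaches the minimum.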
Cycles
By default, there is only one cycle, and Levanter's LR schedule looks like this:
`[warmup] -> [stable] -> [decay]`
If you want a learning rate schedule with multiple cycles, you can specify them with either the cycles or the cycle_length parameter. The LR will be decayed to min_lr_ratio * learning_rate at the end of each cycle.
With cycles, Levanter's LR schedule looks like this:
`[warmup] -> [stable] -> [decay] -> {[rewarmup] -> [stable] -> [decay]} x (cycles - 1)`
or more compactly:
`{[(re)?warmup] -> [stable] -> [decay]} x cycles`
Here's what the phases mean:
- warmup: The first warmup in training, which is part of the first cycle. The LR will start at 0 and linearly increase to the learning rate over this period.
- stable: The stable period. The LR will stay at the learning rate for this period.
- decay: The decay period. The LR will decay to min_lr_ratio * learning_rate over this period.
- rewarmup: The re-warmup period. If using cycles, the LR will be re-warmed from the final value of the previous cycle back to the peak value of the next cycle.
Also note that if rewarmup is 0, there will be no rewarmup period, meaning the LR will jump back to the max LR. This is the default, and it works surprisingly well. In addition, the stable and decay phases of the first cycle will generally differ from those of the other cycles, since rewarmup and warmup are typically different.
stable cannot be specified directly. It is the period between warmup and decay in the first cycle, and the period
between rewarmup and decay in subsequent cycles. By default, there is no stable period.
All of these parameters can be specified in terms of a fraction of the total number of steps of a cycle or as an absolute number of steps.
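Putting the phases together, the LR within a single cycle can be sketched as a piecewise function of the step offset into that cycle. This sketch assumes warmup and decay have already been converted to absolute step counts, uses a linear ramp for (re)warmup and a cosine decay, and treats whatever is left of the cycle as the stable phase. `lr_in_cycle` and its arguments are hypothetical.

```python
import math


def lr_in_cycle(offset: int, cycle_len: int, warmup: int, decay: int,
                peak_lr: float, min_lr: float, start_lr: float = 0.0) -> float:
    """LR at `offset` steps into a cycle: [(re)warmup] -> [stable] -> [decay]."""
    if warmup + decay > cycle_len:
        raise ValueError("warmup + decay must fit in the cycle")
    stable = cycle_len - warmup - decay  # stable is whatever is left over
    if offset < warmup:
        # linear ramp from start_lr (0 for the first warmup) up to the peak
        return start_lr + (peak_lr - start_lr) * offset / warmup
    if offset < warmup + stable:
        return peak_lr
    progress = (offset - warmup - stable) / decay if decay else 1.0
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


# First cycle: warm up from 0 over 100 steps, hold, then decay over the last 400 of 1000 steps.
first = [lr_in_cycle(s, 1000, 100, 400, 6e-4, 6e-5) for s in (0, 50, 100, 599, 1000)]
# A later cycle with rewarmup = 0: the LR jumps straight back to the peak.
later = lr_in_cycle(0, 1000, 0, 400, 6e-4, 6e-5, start_lr=6e-5)
```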
Here is what the cycles and cycle_length parameters mean:
- cycle_length: If you specify an int or float for cycle_length, the learning rate will cycle through the schedule with the specified length. This is equivalent to specifying cycles as num_train_steps / cycle_length. If cycle_length is a float < 1.0, it is interpreted as a fraction of the total number of steps. If you specify a list of ints, the learning rate will cycle through the schedule with the specified cycle lengths.
- cycles: If you specify an int for cycles, the learning rate will cycle through the schedule cycles times. If you specify a list of ints, the learning rate will cycle through the schedule with the specified steps as the minima of the cycles.
It is an error to specify both cycles and cycle_length.
You can also specify cycles as a list, e.g. [10000, 25000, 50000]. In this case,
cycles is interpreted as the minima for the cycles, with the first and final steps being cycle minima as well.
cycles as an int is equivalent to list cycles with the low points evenly spaced at
[num_train_steps / (c + 1)].
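The bookkeeping for cycle boundaries can be sketched as follows. The sketch covers cycle_length (an int, a fraction, or a list) and cycles given as a list of minima, and returns the step where each cycle starts plus the final step. The int form of cycles is left out rather than guessing its exact rounding, and closing a trailing partial cycle at num_train_steps is an assumption. `cycle_boundaries` is hypothetical.

```python
from typing import List, Optional, Sequence, Union

Length = Union[int, float]


def _to_steps(x: Length, num_train_steps: int) -> int:
    # floats below 1.0 are fractions of the run; everything else is an absolute step count
    return int(x * num_train_steps) if isinstance(x, float) and x < 1.0 else int(x)


def cycle_boundaries(num_train_steps: int,
                     cycles: Optional[Sequence[int]] = None,
                     cycle_length: Optional[Union[Length, Sequence[Length]]] = None) -> List[int]:
    """Return [0, end_of_cycle_1, ..., num_train_steps]; `cycles` must be a list here."""
    if cycles is not None and cycle_length is not None:
        raise ValueError("it is an error to specify both cycles and cycle_length")
    if cycles is not None:
        # a list of cycle minima; the first and final steps are minima as well
        points = [0, *cycles, num_train_steps]
    elif cycle_length is not None:
        if isinstance(cycle_length, (int, float)):
            length = _to_steps(cycle_length, num_train_steps)
            lengths = [length] * (num_train_steps // length)
        else:
            lengths = [_to_steps(x, num_train_steps) for x in cycle_length]
        points = [0]
        for length in lengths:
            points.append(points[-1] + length)
        if points[-1] < num_train_steps:
            points.append(num_train_steps)  # assumption: close a trailing partial cycle
    else:
        points = [0, num_train_steps]  # the default: a single cycle
    return sorted(set(points))


print(cycle_boundaries(100_000, cycles=[10_000, 25_000, 50_000]))  # [0, 10000, 25000, 50000, 100000]
print(cycle_boundaries(100_000, cycle_length=0.25))                 # [0, 25000, 50000, 75000, 100000]
```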
See our paper on WSD-S for more information on cyclic LR schedules for training LLMs with short or no rewarmup.
AdamConfig
Additionally, levanter.optim.AdamConfig has the following fields:
| Parameter | Description | Default |
|---|---|---|
| beta1 | The beta1 parameter for Adam. | 0.9 |
| beta2 | The beta2 parameter for Adam. | 0.95 |
| epsilon | The epsilon parameter for Adam. | 1e-8 |
| max_grad_norm | The maximum gradient norm (for clipping). | 1.0 |
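For reference, the update these fields parameterize is the standard AdamW step with global-norm gradient clipping. The pure-Python sketch below works on flat lists of floats and assumes decoupled weight decay and bias-corrected moments, as in the usual AdamW formulation. It is not Levanter's or optax's implementation; `adamw_step` and `clip_by_global_norm` are hypothetical names, with defaults taken from the tables above.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Scale all gradients down together so their global L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grads]


def adamw_step(params, grads, m, v, t, *, learning_rate=1e-4, beta1=0.9, beta2=0.95,
               epsilon=1e-8, weight_decay=0.0, max_grad_norm=1.0):
    """One AdamW update at step t (1-based); returns (new_params, new_m, new_v)."""
    grads = clip_by_global_norm(grads, max_grad_norm)
    new_params, new_m, new_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        mi = beta1 * mi + (1 - beta1) * g      # first moment (momentum)
        vi = beta2 * vi + (1 - beta2) * g * g  # second moment
        m_hat = mi / (1 - beta1 ** t)          # bias correction
        v_hat = vi / (1 - beta2 ** t)
        update = m_hat / (math.sqrt(v_hat) + epsilon) + weight_decay * p  # decoupled decay
        new_params.append(p - learning_rate * update)
        new_m.append(mi)
        new_v.append(vi)
    return new_params, new_m, new_v


params, m, v = [1.0, -2.0], [0.0, 0.0], [0.0, 0.0]
grads = [3.0, 4.0]  # global norm 5.0, clipped down to 1.0
params, m, v = adamw_step(params, grads, m, v, t=1, learning_rate=6e-4, weight_decay=0.1)
```

On the first step the bias-corrected ratio is roughly the sign of each gradient, so each parameter moves by about learning_rate, plus the weight-decay pull toward zero.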
LM Model Config
levanter.models.lm_model.LmConfig is a Draccus "choice class" that acts as a base class for all autoregressive
language models in Levanter. You typically will specify a kind of model by using the type field, which is a string
that specifies the kind of model. For instance, type: gpt2 will use the levanter.models.gpt2.Gpt2Config class,
while type: llama will use the levanter.models.llama.LlamaConfig class.
We won't go into detail here. You can see the auto-generated docs below.
Auto-generated Documentation
Trainer
TrainerConfig(seed: int = 0, mp: jmp.Policy = jmp.get_policy('f32'), quantization: Optional[QuantizationConfig] = None, model_averaging: ModelAveragingConfig | None = None, wandb: Optional[tracker.wandb.WandbConfig] = None, log_dir: Path = Path('logs/'), id: Optional[str] = None, tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=(tracker.wandb.WandbConfig)), watch: WatchConfig = WatchConfig(), profiler: bool = False, profiler_start_step: int = 5, profiler_num_steps: int = 100, profiler_perfetto_link: bool = False, log_jaxprs: bool = True, log_xla_hlo: bool = True, crash_on_nan: bool = True, crash_on_inf: bool = True, batch_axis: str = 'batch', fsdp_axis: Optional[Union[str, List[str]]] = 'embed', tensor_parallel_axes: Optional[List[str]] = None, axis_resources: Mapping[str, Union[Tuple[str], str]] = field(default_factory=dict), parameter_axis_resources: Mapping[str, Union[Tuple[str], str]] = field(default_factory=dict), replica_ici_axis_size: int = 1, model_axis_size: int = 1, replica_dcn_axis_size: int = 1, train_batch_size: int | IntSchedule = 512, per_device_parallelism: int = -1, per_device_eval_parallelism: int = -1, allow_nondivisible_batch_size: bool = False, num_train_steps: int = 400000, steps_per_eval: int = 1000, max_eval_batches: Optional[int] = None, checkpointer: CheckpointerConfig = field(default_factory=CheckpointerConfig), load_checkpoint: Optional[bool] = None, load_checkpoint_path: Optional[str] = None, initialize_from: Optional[str] = None, allow_partial_checkpoint: bool = False, jax_config: Mapping[str, JsonAtom] = field(default_factory=(lambda: copy.deepcopy(DEFAULT_JAX_CONFIG))), jax_compilation_cache_dir: Optional[str] = None, distributed: DistributedConfig = DistributedConfig(), ray: RayConfig = field(default_factory=RayConfig), require_accelerator: Optional[bool] = None, shutdown_at_exit: Union[bool, float] = False)
dataclass
Fields (class and instance attributes):
- `seed: int = 0`
- `mp: jmp.Policy = jmp.get_policy('f32')`
- `quantization: Optional[QuantizationConfig] = None`
- `model_averaging: ModelAveragingConfig | None = None`
- `wandb: Optional[tracker.wandb.WandbConfig] = None`
- `log_dir: Path = Path('logs/')`
- `id: Optional[str] = None`
- `tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=(tracker.wandb.WandbConfig))`
- `watch: WatchConfig = WatchConfig()`
- `profiler: bool = False`
- `profiler_start_step: int = 5`
- `profiler_num_steps: int = 100`
- `profiler_perfetto_link: bool = False`
- `log_jaxprs: bool = True`: Whether to log the jaxpr of the training step. This is useful for debugging and understanding the model.
- `log_xla_hlo: bool = True`: Whether to log the XLA HLO of the training step. This is useful for debugging and understanding the model.
- `crash_on_nan: bool = True`
- `crash_on_inf: bool = True`
- `batch_axis: str = 'batch'`
- `fsdp_axis: Optional[Union[str, List[str]]] = 'embed'`
- `tensor_parallel_axes: Optional[List[str]] = None`
- `axis_resources: Mapping[str, Union[Tuple[str], str]] = field(default_factory=dict)`: mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred
- `parameter_axis_resources: Mapping[str, Union[Tuple[str], str]] = field(default_factory=dict)`: logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred
- `replica_ici_axis_size: int = 1`
- `model_axis_size: int = 1`: how many devices within each slice for sharding with DP. Fix TP=1, the rest of the devices is for FSDP.
- `replica_dcn_axis_size: int = 1`: how many slices in the multislice scheme for sharding with DP and TP. The rest of the devices is for FSDP.
- `train_batch_size: int | IntSchedule = 512`
- `per_device_parallelism: int = -1`: how many examples to process in parallel on each device. -1 (default) means train_batch_size/num_devices
- `per_device_eval_parallelism: int = -1`: how many examples to process in parallel on each device. -1 (default) means same as per_device_parallelism
- `allow_nondivisible_batch_size: bool = False`: Allow batch sizes to be non-divisible by the number of devices (or data axis size). This is typically used when you want a specific batch size but have a weird number of devices.
- `num_train_steps: int = 400000`
- `steps_per_eval: int = 1000`
- `max_eval_batches: Optional[int] = None`
- `checkpointer: CheckpointerConfig = field(default_factory=CheckpointerConfig)`
- `load_checkpoint: Optional[bool] = None`: if None (default), we'll load a checkpoint if it exists. If true, we must load a checkpoint
- `load_checkpoint_path: Optional[str] = None`: can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path.
- `initialize_from: Optional[str] = None`: Load and continue training from a checkpoint. If None, will initialize from model_init.
- `allow_partial_checkpoint: bool = False`: If True, we allow loading a checkpoint that doesn't have all the parameters in the model. Missing parameters are initialized from the model_init function.
- `jax_config: Mapping[str, JsonAtom] = field(default_factory=(lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)))`
- `jax_compilation_cache_dir: Optional[str] = None`
- `distributed: DistributedConfig = DistributedConfig()`
- `ray: RayConfig = field(default_factory=RayConfig)`
- `require_accelerator: Optional[bool] = None`
- `shutdown_at_exit: Union[bool, float] = False`

Properties:
- `TrainBatch` (property)
- `batch_schedule` (cached property)
- `EvalBatch` (property)
- `microbatch_size: int | None` (property)
- `device_mesh: Mesh` (property)
- `eval_batch_size` (property)
- `num_slices` (cached property): number of nodes
- `num_devices_per_slice` (property): number of devices within a slice
- `data_ici_axis_size` (property): size of the FSDP axis within slices
- `data_dcn_axis_size` (property): size of the FSDP axis across slices
- `data_axis_size` (property): size of the data parallel/batch parallel axis.
- `replica_axis_size` (property): size of the data parallel/batch parallel axis.
- `compute_axis_mapping: ResourceMapping` (cached property): Mapping from logical axis to physical axis for compute.
- `parameter_axis_mapping: ResourceMapping` (cached property)
Methods:
- `batch_axis_at_step(step: int) -> Axis`
- `initialize()`: Initializes jax, logging, setting the run name/id in the process
- `use_device_mesh() -> ContextManager[None]`: Context manager that sets the device mesh for jax, using Haliax's wrapper. In recent jax, this is the same as `jax.set_mesh(self.device_mesh)`, but we use Haliax's wrapper for compatibility with older jax versions.
Trainer(config: TrainerConfig, optimizer: GradientTransformation, loss_fn: ComputeLossFunction, *, add_default_hooks: bool = True)
Parameters:
- optimizer: the optimizer, e.g. `optax.adam(1e-3)`, or one produced by `levanter.optim.OptimizerConfig`
- loss_fn (Callable): the loss function. This should be a function that takes a model and some inputs and returns either a scalar loss or a tuple of (scalar loss, metrics_dict). The metrics dict will be automatically logged to the tracker with appropriate prefixes (e.g., "train/accuracy", "eval/perplexity"). The function should be jit-able and should not have any side effects.
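Because loss_fn may return either a bare scalar or a (loss, metrics) tuple, it helps to normalize the result before logging. The sketch below shows that normalization and the prefixing of metric names. `normalize_loss_output` and `prefixed` are hypothetical helpers, not the Trainer's WrappedLossFunction, and the prefix strings follow the examples in the docstring above.

```python
from typing import Any, Dict, Tuple


def normalize_loss_output(output: Any) -> Tuple[Any, Dict[str, Any]]:
    """Always return (loss, metrics_dict), whichever form loss_fn used."""
    if isinstance(output, tuple):
        loss, metrics = output
        return loss, dict(metrics)
    return output, {}


def prefixed(metrics: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Namespace metric names for the tracker, e.g. 'accuracy' -> 'train/accuracy'."""
    return {f"{prefix}/{name}": value for name, value in metrics.items()}


loss, metrics = normalize_loss_output((2.5, {"accuracy": 0.5}))
print(loss, prefixed(metrics, "train"))  # 2.5 {'train/accuracy': 0.5}
```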
Attributes:
- `tracker: levanter.tracker.Tracker` (instance attribute)
- `is_trainable_param: PyTree[FilterSpec]` (instance attribute)
- `hooks: TrainerHooks = TrainerHooks()` (instance attribute)
- `config: TrainerConfig = config` (instance attribute)
- `optimizer: GradientTransformation = optimizer` (instance attribute)
- `loss_fn: WrappedLossFunction` (cached property): Wrapped loss function that always returns (loss, metrics_dict). Casts the model to compute precision and sets the context axis mapping to compute.
- `run_id: str` (property): Returns the run id
- `mp: jmp.Policy` (property): Returns the mixed precision policy
- `num_train_steps: int` (property)
- `parameter_axis_mapping: ResourceMapping` (property)
- `compute_axis_mapping: ResourceMapping` (property)
- `device_mesh: Mesh` (property)
- `TrainBatch` (property)
- `EvalBatch` (property)
- `checkpoint_path: str` (property)
add_hook(fn: Optional[Callable[[StepInfo], Any] | Callback | JitCallback] = None, *, every: int = 1)
¤
add_hook(fn: Callable[[StepInfo], Any], *, every: int = 1)
add_hook(fn: JitCallback, *, every: int = 1)
add_hook(fn: Callback, *, every: int = 1)
add_hook(*, every: int = 1)
run_hooks(info: StepInfo, force: bool = False)
¤
initial_state(training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None, *, is_trainable: PyTree[FilterSpec] = True) -> TrainerState[M]
Either loads a checkpoint or initializes a fresh trainer state. This is the recommended way to initialize a trainer state.
This method handles subclasses of TrainerState. If you want to extend TrainerState, override _initialize_state_from_scratch.
Args:
- is_trainable: optional filter spec for the trainable parameters. It is used to exclude non-trainable parameters from the optimizer state and from gradient computation. Non-trainable parameters are also not checkpointed. If you don't specify it, all parameters are assumed to be trainable.
Returns:
- TrainerState (TrainerState[M]) – the initial state.
train_step(state: S, *batch: X, **batch_kwargs) -> StepInfo[S]
Performs a single training step.
training_steps(state: S, train_loader) -> typing.Iterator[StepInfo[S]]
Generator that yields training steps and runs hooks.
train(state: S, train_loader: Iterable[X]) -> StepInfo[S]
Performs training until num_train_steps is reached.
add_eval_hook(eval_dataset, name: Optional[str] = None)
data_loader(dataset: AsyncDataset[X], batch: Optional[hax.Axis] = None) -> DataLoader[X]
Creates a data loader for the given dataset and batch axis.
Parameters:
- dataset (AsyncDataset) – the dataset to load.
- batch (Optional[Axis], default: None) – the batch axis. If None, uses the trainer's batch axis (and schedule, if applicable).
Returns:
- DataLoader (DataLoader[X]) – the data loader.
write_artifact(name: str, artifact: Any, type: Optional[str] = None)
Saves an artifact to disk (in the run directory) and logs it to the tracker.
Checkpointer
CheckpointerConfig(base_path: str = 'checkpoints/', save_interval: timedelta = timedelta(minutes=15), keep: List[dict] = field(default_factory=(lambda: [dict(every=10000)])), append_run_id_to_base_path: bool = True, delete_old_temp_checkpoints: bool = True) (dataclass)
Methods:
- expanded_path
- create
Attributes:
- base_path (str)
- save_interval (timedelta)
- keep (List[dict])
- append_run_id_to_base_path (bool)
- delete_old_temp_checkpoints (bool) – If True, delete old checkpoints from prior attempts at this run. If False, keep them.
base_path: str = 'checkpoints/'
save_interval: timedelta = timedelta(minutes=15)
keep: List[dict] = field(default_factory=(lambda: [dict(every=10000)]))
append_run_id_to_base_path: bool = True
delete_old_temp_checkpoints: bool = True
If True, delete old checkpoints from prior attempts at this run. If False, keep them.
Setting this to False is useful if the run is being preempted and restarted and you want to keep the old checkpoints.
expanded_path(run_id) -> str
create(run_id) -> Checkpointer
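The keep entries are plain dicts that describe step policies, and append_run_id_to_base_path decides whether the run id is added to base_path. A hypothetical sketch of both conversions, assuming each dict is unpacked into an (every, until) policy and that the run id is joined onto base_path as a sub-directory; toy_expanded_path, toy_step_policies, and Interval are illustrative names, not Levanter's API.

```python
import os
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Interval:
    # Hypothetical stand-in for a step policy built from one `keep` dict.
    every: int
    until: Optional[int] = None


def toy_expanded_path(base_path: str, run_id: str, append_run_id_to_base_path: bool = True) -> str:
    # Assumption: the run id becomes a sub-directory of base_path when enabled.
    if append_run_id_to_base_path:
        return os.path.join(base_path, run_id)
    return base_path


def toy_step_policies(keep: List[dict]) -> List[Interval]:
    # Each dict such as {"every": 10000} is unpacked into a policy.
    return [Interval(**k) for k in keep]
```

For example, keep=[dict(every=1000, until=10000), dict(every=10000)] would describe "every 1,000 steps up to step 10,000, then every 10,000 steps".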
Checkpointer(base_path: PathLike, save_interval: Optional[datetime.timedelta], step_policies: Sequence[CheckpointInterval], *, keep_params: PyTree[FilterSpec] = True, dt_now_injection: Optional[Callable[[], datetime.datetime]] = None, delete_old_temp_checkpoints: bool = True)
A checkpointer that saves checkpoints under two different but overlapping policies: time and step.
Note that this class is stateful: it keeps track of the last time a checkpoint was saved and the last step at which one was saved.
Time policy: we save a checkpoint whenever at least save_interval has elapsed since the last one.
Step policy: each step policy saves a checkpoint every "every" steps, until "until" steps have been reached.
Time checkpoints are deleted after the next checkpoint is saved. Step checkpoints are never deleted.
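To make the interaction of the two policies concrete, here is a minimal sketch of the save decision in plain Python. It is not Levanter's implementation: ToyCheckpointPolicy and Interval are hypothetical names, step policies are assumed to be checked in order with the first one whose until bound still covers the current step taking effect, and the sketch only decides whether to save and whether the checkpoint is time-based (temporary) or step-based (permanent).

```python
import datetime
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class Interval:
    # Hypothetical mirror of a step policy: save every `every` steps until step `until`.
    every: int
    until: Optional[int] = None


class ToyCheckpointPolicy:
    def __init__(self, save_interval: Optional[datetime.timedelta],
                 step_policies: Sequence[Interval], start: datetime.datetime):
        self.save_interval = save_interval
        self.step_policies = list(step_policies)
        self.last_save_time = start  # stateful, like the checkpointer described above

    def decide(self, step: int, now: datetime.datetime) -> Optional[str]:
        """Return "step" (permanent), "time" (temporary), or None (don't save)."""
        # Assumption: the first policy whose `until` has not passed governs this step.
        for policy in self.step_policies:
            if policy.until is None or step <= policy.until:
                if step % policy.every == 0:
                    self.last_save_time = now
                    return "step"
                break
        # Time policy: save once save_interval has elapsed since the last save.
        if self.save_interval is not None and now - self.last_save_time >= self.save_interval:
            self.last_save_time = now
            return "time"
        return None


t0 = datetime.datetime(2024, 1, 1)
policy = ToyCheckpointPolicy(datetime.timedelta(minutes=15), [Interval(every=1000)], start=t0)
print(policy.decide(500, t0 + datetime.timedelta(minutes=5)))    # None: too early, not a multiple of 1000
print(policy.decide(700, t0 + datetime.timedelta(minutes=16)))   # time
print(policy.decide(1000, t0 + datetime.timedelta(minutes=20)))  # step
```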
Parameters:
- base_path (PathLike) – the base path to save checkpoints to. May be GCS, local, or anything that TensorStore supports.
- save_interval (Optional[timedelta]) – the minimum amount of time between checkpoints (for the time policy).
- step_policies (Sequence[CheckpointInterval]) – the step policies to use.
- keep_params (PyTree[FilterSpec], default: True) – a PyTree of FilterSpecs that specifies which parameters to keep in the checkpoint.
- dt_now_injection (Optional[Callable[[], datetime]], default: None) – a function that returns the current time. Useful for testing.
- delete_old_temp_checkpoints (bool, default: True) – if True, delete old checkpoints when saving a new one.
Methods:
- load_checkpoint
- load_model – Convenience method/holdover from the previous API for loading checkpoints.
- on_step
- wait_until_finished
- save_checkpoint
Attributes:
- base_path (str)
- save_interval (Optional[timedelta])
- step_policies (Sequence[CheckpointInterval])
- keep_params
base_path: str = str(base_path)
save_interval: Optional[datetime.timedelta] = save_interval
step_policies: Sequence[CheckpointInterval] = list(step_policies)
keep_params = keep_params
load_checkpoint(state: M, path: Optional[PathLike] = None, *, discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None) -> Optional[M]
load_model(model: M, path: Optional[str] = None, *, discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None) -> Optional[M]
Convenience method/holdover from the previous API for loading checkpoints.
Loads just the model, assuming it is stored in the model subdirectory of the discovered checkpoint.
on_step(info, force: bool = False)
wait_until_finished()
save_checkpoint(info, destination: str, commit_callback: Optional[Callable[[], None]] = None, *, is_temporary: bool = False)
Trackers and Metrics
See also Trackers for more information. Basic configuration is shown below.
Single Tracker
trainer:
  tracker:
    type: wandb
    project: my-project
    entity: my-entity
Distributed and Ray
DistributedConfig(coordinator_address: Optional[str] = None, num_processes: Optional[int] = None, process_id: Optional[int] = None, local_device_ids: Optional[Union[int, List[int]]] = None, initialize_jax_distributed: bool = True) (dataclass)
Methods:
- initialize
Attributes:
- coordinator_address (Optional[str])
- num_processes (Optional[int])
- process_id (Optional[int])
- local_device_ids (Optional[Union[int, List[int]]])
- initialize_jax_distributed (bool)
coordinator_address: Optional[str] = None
num_processes: Optional[int] = None
process_id: Optional[int] = None
local_device_ids: Optional[Union[int, List[int]]] = None
initialize_jax_distributed: bool = True
initialize()
RayConfig(address: Optional[str] = None, start_workers: bool = True, auto_start_cluster: bool = True) (dataclass)
Model Averaging
Levanter can average model weights during training. Specify one of the registered strategies in trainer.model_averaging:
trainer:
  model_averaging:
    type: ema  # or 'ema_decay_sqrt'
- ema – classic exponential moving average with parameter beta.
- ema_decay_sqrt – EMA until switch_step, then decays with $1 - \sqrt{x}$ over decay_steps.
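For intuition, a classic EMA keeps a running average that is updated after every step as avg = beta * avg + (1 - beta) * params. The plain-Python sketch below applies that update to a flat list of floats; the list representation and the ema_update name are illustrative assumptions (Levanter works on model pytrees), and it does not cover the ema_decay_sqrt schedule.

```python
from typing import List


def ema_update(avg: List[float], params: List[float], beta: float = 0.999) -> List[float]:
    # Elementwise: avg_new = beta * avg + (1 - beta) * params
    return [beta * a + (1.0 - beta) * p for a, p in zip(avg, params)]


avg = [0.0, 0.0]
for _ in range(3):
    avg = ema_update(avg, [1.0, 2.0], beta=0.5)
print(avg)  # [0.875, 1.75]
```

A small beta (0.5 here) makes the average track the parameters quickly; with the default beta of 0.999 the average moves only 0.1% of the way toward the current weights each step.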
EmaModelAveragingConfig(beta: float = 0.999) (dataclass)
EmaDecaySqrtConfig(beta: float = 0.999, switch_step: int = 100000, decay_steps: int = 100000) (dataclass)
Bases: ModelAveragingConfig[M]
EMA followed by $1 - \sqrt{x}$ decay.
Methods:
- create
Attributes:
- beta (float)
- switch_step (int)
- decay_steps (int)
Optimizer
OptimizerConfig(learning_rate: float = 0.0006, weight_decay: float = 0.1, min_lr_ratio: float = 0.1, warmup: int | float = 0.01, decay: int | float | None = None, rewarmup: int | float = 0.0, cooldown: Optional[float] = None, cycle_length: int | float | None | list[int] = None, cycles: int | list[int] | None = None, lr_schedule: LrSchedule | str = 'cosine', haps: Optional[list[int]] = None, weight_decay_modules: Optional[list[str] | str] = None, default_weight_decay_mask: Optional[bool] = None) (dataclass)
Bases: ChoiceRegistry, ABC
Methods:
- default_choice_name
- build
- build_weight_decay_mask
- lr_scheduler
Attributes:
- learning_rate (float)
- weight_decay (float)
- min_lr_ratio (float) – The lr scheduler operates in 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]
- warmup (int | float) – fraction of training steps to use as warmup, or number of warmup steps. 0.0 means no warmup.
- decay (int | float | None) – fraction of training steps to use as decay, or number of decay steps. None means full decay.
- rewarmup (int | float) – if using cycles, how much of each cycle to use as re-warmup. 0.0 means no re-warmup.
- cooldown (Optional[float]) – Deprecated, as its semantics are confusing.
- cycle_length (int | float | None | list[int]) – length of a cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0.
- cycles (int | list[int] | None) – number of cycles or a list of cycle endpoints. At most one of cycle_length, cycles, or haps may be set.
- lr_schedule (LrSchedule | str)
- haps (Optional[list[int]]) – Deprecated.
- weight_decay_modules (Optional[list[str] | str]) – a regex or a list of strings identifying the parameters that receive weight decay.
- default_weight_decay_mask (Optional[bool]) – whether to apply a reasonable default weight decay to modules not explicitly masked. None means it will if no weight_decay_modules are set.
learning_rate: float = 0.0006
weight_decay: float = 0.1
min_lr_ratio: float = 0.1
The lr scheduler operates in 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]
warmup: int | float = 0.01
Fraction of training steps to use as warmup, or number of warmup steps. 0.0 means no warmup.
decay: int | float | None = None
Fraction of training steps to use as decay, or number of decay steps. None means full decay.
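Because warmup and decay accept either a fraction of training or an absolute step count, it can help to see one simplified schedule end to end. The sketch below is not Levanter's scheduler: steps_from and toy_lr are hypothetical helpers, values below 1.0 are assumed to be fractions and larger values step counts, warmup is assumed to ramp linearly from 0, and cycles, rewarmup, and the non-cosine lr_schedule options are ignored. The learning rate floors at learning_rate * min_lr_ratio.

```python
import math
from typing import Optional, Union


def steps_from(value: Union[int, float], num_train_steps: int) -> int:
    # Assumption: values below 1.0 are fractions of training; others are step counts.
    if value < 1.0:
        return int(value * num_train_steps)
    return int(value)


def toy_lr(step: int, num_train_steps: int, learning_rate: float = 6e-4,
           min_lr_ratio: float = 0.1, warmup: Union[int, float] = 0.01,
           decay: Optional[Union[int, float]] = None) -> float:
    warmup_steps = steps_from(warmup, num_train_steps)
    # decay=None: decay over everything after warmup (no stable phase).
    if decay is None:
        decay_steps = num_train_steps - warmup_steps
    else:
        decay_steps = steps_from(decay, num_train_steps)
    stable_end = num_train_steps - decay_steps
    min_lr = learning_rate * min_lr_ratio

    if step < warmup_steps:
        # Linear warmup from 0 toward learning_rate.
        return learning_rate * step / warmup_steps
    if step < stable_end:
        # Stable phase at the peak learning rate.
        return learning_rate
    # Cosine decay from learning_rate down to min_lr over the decay phase.
    progress = min(1.0, (step - stable_end) / max(1, decay_steps))
    return min_lr + 0.5 * (learning_rate - min_lr) * (1.0 + math.cos(math.pi * progress))
```

With num_train_steps=1000 and warmup=0.1, the rate climbs to 6e-4 over the first 100 steps and decays to 6e-5 by step 1000; adding decay=0.2 holds it at 6e-4 until step 800 before decaying.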
rewarmup: int | float = 0.0
If using cycles, how much of each cycle to use as re-warmup. 0.0 means no re-warmup.
cooldown: Optional[float] = None
Deprecated, as its semantics are confusing.
cycle_length: int | float | None | list[int] = None
Length of a cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0.
cycles: int | list[int] | None = None
Number of cycles or a list of cycle endpoints. At most one of cycle_length, cycles, or haps may be set.
lr_schedule: LrSchedule | str = 'cosine'
haps: Optional[list[int]] = None
Deprecated.
weight_decay_modules: Optional[list[str] | str] = None
A regex or a list of strings identifying the parameters that receive weight decay.
For nano-GPT, this field can be set to r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"
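To check what a weight_decay_modules regex selects, you can match it against dotted parameter paths. The paths below are hypothetical nano-GPT-style names (real Levanter module paths may differ), and the sketch assumes the pattern is matched from the start of each path with re.match.

```python
import re

# Example pattern from the description above.
pattern = r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"

# Hypothetical parameter paths, for illustration only.
paths = [
    "transformer.blocks.0.attn.c_attn.weight",
    "transformer.blocks.0.attn.c_attn.bias",
    "transformer.blocks.0.mlp.c_fc.weight",
    "transformer.blocks.0.ln_1.weight",
    "embeddings.token_embeddings",
]

# Parameters whose path matches the pattern would receive weight decay.
decayed = [p for p in paths if re.match(pattern, p)]
```

Here the attention and MLP weights and the token embeddings match, while the bias and the layer-norm weight do not.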
default_weight_decay_mask: Optional[bool] = None
Whether to apply a reasonable default weight decay to modules not explicitly masked. None means it will if no weight_decay_modules are set. False means it will not. True means it will regardless of weight_decay_modules.
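The three-way rule for default_weight_decay_mask fits in a few lines. This helper is hypothetical and only encodes the documented rule (whether the default mask is used); it does not build the mask itself.

```python
from typing import List, Optional, Union


def uses_default_mask(default_weight_decay_mask: Optional[bool],
                      weight_decay_modules: Optional[Union[List[str], str]]) -> bool:
    # True: always use the default mask; False: never use it.
    if default_weight_decay_mask is not None:
        return default_weight_decay_mask
    # None: use the default only when no weight_decay_modules are set.
    return weight_decay_modules is None
```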
default_choice_name() -> Optional[str] (classmethod)
build(num_train_steps: int) (abstract method)
build_weight_decay_mask()
lr_scheduler(num_train_steps, override_lr=None)
AdamConfig(learning_rate: float = 0.0006, weight_decay: float = 0.1, min_lr_ratio: float = 0.1, warmup: int | float = 0.01, decay: int | float | None = None, rewarmup: int | float = 0.0, cooldown: Optional[float] = None, cycle_length: int | float | None | list[int] = None, cycles: int | list[int] | None = None, lr_schedule: LrSchedule | str = 'cosine', haps: Optional[list[int]] = None, weight_decay_modules: Optional[list[str] | str] = None, default_weight_decay_mask: Optional[bool] = None, beta1: float = 0.9, beta2: float = 0.95, epsilon: float = 1e-08, max_grad_norm: Optional[float] = 1.0, nesterov: bool = False, update_rms_clipping: Optional[float] = None, clip_update_norm: Optional[ClipUpdateNormConfig] = None, skip_bad_steps: SkipStepConfig | int | bool = False, adamc_weight_decay: bool = False) (dataclass)
Bases: OptimizerConfig
Methods:
- default_choice_name
- build_weight_decay_mask
- lr_scheduler
- build – Creates the optimizer.
Attributes:
- learning_rate (float)
- weight_decay (float)
- min_lr_ratio (float) – The lr scheduler operates in 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]
- warmup (int | float) – fraction of training steps to use as warmup, or number of warmup steps. 0.0 means no warmup.
- decay (int | float | None) – fraction of training steps to use as decay, or number of decay steps. None means full decay.
- rewarmup (int | float) – if using cycles, how much of each cycle to use as re-warmup. 0.0 means no re-warmup.
- cooldown (Optional[float]) – Deprecated, as its semantics are confusing.
- cycle_length (int | float | None | list[int]) – length of a cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0.
- cycles (int | list[int] | None) – number of cycles or a list of cycle endpoints. At most one of cycle_length, cycles, or haps may be set.
- lr_schedule (LrSchedule | str)
- haps (Optional[list[int]]) – Deprecated.
- weight_decay_modules (Optional[list[str] | str]) – a regex or a list of strings identifying the parameters that receive weight decay.
- default_weight_decay_mask (Optional[bool]) – whether to apply a reasonable default weight decay to modules not explicitly masked. None means it will if no weight_decay_modules are set.
- beta1 (float)
- beta2 (float)
- epsilon (float)
- max_grad_norm (Optional[float])
- nesterov (bool)
- update_rms_clipping (Optional[float]) – if set, use RMS clipping on the update, à la Adafactor or StableAdamW (https://arxiv.org/pdf/2304.13013).
- clip_update_norm (Optional[ClipUpdateNormConfig]) – if set, clip the update norm based on the historical mean and standard deviation of update norms. A less extreme version of skip_bad_steps.
- skip_bad_steps (SkipStepConfig | int | bool) – if set, defines the configuration for skipping steps when gradients are too large.
- adamc_weight_decay (bool) – if set, use the AdamC corrected weight decay, which keeps weight_decay / lr constant across training.
learning_rate: float = 0.0006
weight_decay: float = 0.1
min_lr_ratio: float = 0.1
The lr scheduler operates in 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]
warmup: int | float = 0.01
Fraction of training steps to use as warmup, or number of warmup steps. 0.0 means no warmup.
decay: int | float | None = None
Fraction of training steps to use as decay, or number of decay steps. None means full decay.
rewarmup: int | float = 0.0
If using cycles, how much of each cycle to use as re-warmup. 0.0 means no re-warmup.
cooldown: Optional[float] = None
Deprecated, as its semantics are confusing.
cycle_length: int | float | None | list[int] = None
Length of a cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0.
cycles: int | list[int] | None = None
Number of cycles or a list of cycle endpoints. At most one of cycle_length, cycles, or haps may be set.
lr_schedule: LrSchedule | str = 'cosine'
haps: Optional[list[int]] = None
Deprecated.
weight_decay_modules: Optional[list[str] | str] = None
A regex or a list of strings identifying the parameters that receive weight decay.
For nano-GPT, this field can be set to r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"
default_weight_decay_mask: Optional[bool] = None
Whether to apply a reasonable default weight decay to modules not explicitly masked. None means it will if no weight_decay_modules are set. False means it will not. True means it will regardless of weight_decay_modules.
beta1: float = 0.9
beta2: float = 0.95
epsilon: float = 1e-08
max_grad_norm: Optional[float] = 1.0
nesterov: bool = False
update_rms_clipping: Optional[float] = None
If set, this will use RMS clipping on the update, à la Adafactor or StableAdamW (https://arxiv.org/pdf/2304.13013).
(Note that this is distinct from StableAdamW because we clip on RMS(m/sqrt(v)) rather than RMS(g/sqrt(v)).)
A value of 1.0 is recommended for most models, but you can set it to None to disable RMS clipping.
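A minimal sketch of RMS clipping, assuming the update is rescaled so that its root-mean-square does not exceed the threshold, while updates already below it pass through unchanged. Per the note above, Levanter applies this to the Adam update m/sqrt(v); here a flat list of floats stands in for one tensor's update, and clip_by_rms is a hypothetical name.

```python
import math
from typing import List


def clip_by_rms(update: List[float], threshold: float = 1.0) -> List[float]:
    rms = math.sqrt(sum(u * u for u in update) / len(update))
    # Scale down only when the RMS exceeds the threshold; zero updates are left alone.
    scale = min(1.0, threshold / rms) if rms > 0 else 1.0
    return [u * scale for u in update]


clipped = clip_by_rms([3.0, 4.0])  # RMS was about 3.54; after clipping it is 1.0
```

Unlike global-norm clipping (max_grad_norm), this bounds the per-element average size of the update, so its effect does not grow with the number of parameters.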
clip_update_norm: Optional[ClipUpdateNormConfig] = None
If set, this will clip the update norm based on the historical mean and standard deviation of update norms. A less extreme version of skip_bad_steps.
skip_bad_steps: SkipStepConfig | int | bool = False
If set, defines the configuration for skipping steps when gradients are too large.
An int sets the history length, True uses the default config, and False disables skipping.
"Bad" here means that either the loss or the gradient norm is much larger than the average over the last rolling_interval_length steps (by default 128 steps, with a sigma factor of 6.0).
See https://github.com/allenai/OLMo-core/blob/main/src/olmo_core/optim/skip_step_optimizer.py
adamc_weight_decay: bool = False
If set, use the AdamC corrected weight decay, which keeps weight_decay / lr constant across training.
This follows Defazio, On the Correct Treatment of Weight Decay in Adam (2025, https://arxiv.org/abs/2506.02285v2).
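One way to read "keeps weight_decay / lr constant": the decay coefficient used at each step follows the learning-rate schedule, normalized by its peak, so wd_t = weight_decay * lr_t / lr_max and wd_t / lr_t stays fixed. The sketch below assumes exactly this scaling; adamc_weight_decay here is a hypothetical helper that only computes the per-step coefficients, not Levanter's optimizer code.

```python
from typing import List


def adamc_weight_decay(weight_decay: float, lrs: List[float]) -> List[float]:
    # Assumption: the decay coefficient tracks the lr schedule, normalized by the peak lr.
    lr_max = max(lrs)
    return [weight_decay * lr / lr_max for lr in lrs]


lrs = [1e-3, 5e-4, 1e-4]          # a decaying schedule
wds = adamc_weight_decay(0.1, lrs)
ratios = [wd / lr for wd, lr in zip(wds, lrs)]  # each approximately 0.1 / 1e-3 = 100.0
```

At the peak learning rate the coefficient equals weight_decay; as the schedule decays, the coefficient shrinks in proportion.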
default_choice_name() -> Optional[str] (classmethod)
build_weight_decay_mask()
lr_scheduler(num_train_steps, override_lr=None)
build(num_train_steps)
Creates the optimizer.
SkipStepConfig(rolling_interval_length: int = 128, sigma_factor: float = 6.0) (dataclass)
Configuration for "skip step" logic in an optimizer.
This optimizer skips steps based on the history of loss and gradient norms.
If the current loss or gradient norm is significantly higher than the historical mean plus a multiple (sigma_factor) of the standard deviation, the step is skipped. The decision is made based on a rolling window of the last rolling_interval_length steps.
If there are fewer than rolling_interval_length // 2 steps in the history, the step is always taken.
See https://github.com/allenai/OLMo-core/blob/main/src/olmo_core/optim/skip_step_optimizer.py
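The skip rule can be sketched as a rolling-window check in plain Python. This is a simplified illustration, not Levanter's or OLMo-core's implementation: ToySkipStep is a hypothetical class that tracks a single scalar (for example the loss or the gradient norm), uses the population standard deviation, and, as an assumption, records every observed value in the history, including values from skipped steps.

```python
import statistics
from collections import deque
from typing import Deque


class ToySkipStep:
    def __init__(self, rolling_interval_length: int = 128, sigma_factor: float = 6.0):
        self.rolling_interval_length = rolling_interval_length
        self.sigma_factor = sigma_factor
        # Only the most recent rolling_interval_length values are kept.
        self.history: Deque[float] = deque(maxlen=rolling_interval_length)

    def should_take_step(self, value: float) -> bool:
        if len(self.history) < self.rolling_interval_length // 2:
            take = True  # too little history: always take the step
        else:
            mean = statistics.fmean(self.history)
            std = statistics.pstdev(self.history)
            take = value <= mean + self.sigma_factor * std
        self.history.append(value)  # assumption: every observed value is recorded
        return take


s = ToySkipStep(rolling_interval_length=8, sigma_factor=3.0)
for v in [1.0, 1.1, 0.9, 1.0]:
    s.should_take_step(v)       # warm-up: fewer than 8 // 2 values, always taken
print(s.should_take_step(1.05))   # True: within mean + 3 * std
print(s.should_take_step(100.0))  # False: a spike far above the recent history
```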
Methods:
- from_bool_int_or_config – Converts a boolean, integer, or SkipStepConfig to a SkipStepConfig.
- wrap
Attributes:
- rolling_interval_length (int)
- sigma_factor (float)
rolling_interval_length: int = 128
sigma_factor: float = 6.0
from_bool_int_or_config(config: Union[bool, int, SkipStepConfig]) -> Optional[SkipStepConfig] (staticmethod)
Converts a boolean, integer, or SkipStepConfig to a SkipStepConfig. If the input is True, it returns a default SkipStepConfig. If the input is False, it returns None. If the input is an integer, it sets rolling_interval_length to that value.
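The conversion rules above are simple enough to sketch directly. ToySkipStepConfig and to_skip_config are hypothetical stand-ins with the same two fields and defaults as SkipStepConfig; they mirror the documented behavior rather than reproduce Levanter's code.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ToySkipStepConfig:
    rolling_interval_length: int = 128
    sigma_factor: float = 6.0


def to_skip_config(value: Union[bool, int, ToySkipStepConfig]) -> Optional[ToySkipStepConfig]:
    # bool must be checked before int: in Python, True and False are ints.
    if isinstance(value, bool):
        return ToySkipStepConfig() if value else None
    if isinstance(value, int):
        return ToySkipStepConfig(rolling_interval_length=value)
    return value  # already a config object
```

The ordering of the isinstance checks matters: without the bool branch first, True would be treated as a history length of 1.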
wrap(inner_optimizer: optax.GradientTransformation) -> optax.GradientTransformation
LM Model
LmConfig(cross_entropy_block_size: Optional[int] = None) (dataclass)
Bases: PluginRegistry, ABC, Generic[LmT]
Methods:
- flops_per_token
- total_trainable_params
- build
Attributes:
- model_type (Type[LmT])
- KeyPos (Axis)
- Pos (Axis)
- Embed (Axis)
- cross_entropy_block_size (Optional[int]) – the block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block.
model_type: Type[LmT] (abstract property)
KeyPos: Axis (abstract property)
Pos: Axis (abstract property)
Embed: Axis (abstract property)
cross_entropy_block_size: Optional[int] = None
The block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block. It can be adjusted to fit within memory constraints. It's deliberately set to a large value because it is usually faster to compute the loss in larger blocks.
flops_per_token(vocab_size: int) -> Optional[float]
total_trainable_params() -> Optional[float]
build(Vocab: Axis, *, key: PRNGKeyArray) -> LmT
Gpt2Config(cross_entropy_block_size: Optional[int] = None, seq_len: int = 1024, hidden_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_scale: int = 4, initializer_range: float = 0.02, embed_pdrop: float = 0.0, resid_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, activation_function: ActivationFunctionEnum = ActivationFunctionEnum.gelu_new, scale_attn_by_inverse_layer_idx: bool = False, upcast_attn: bool = False, gradient_checkpointing: bool = True, use_bias: bool = True, use_flash_attention: Optional[bool] = None, attn_backend: Optional[AttentionBackend] = None, flash_attention_block_size: Optional[int] = None) (dataclass)
Bases: HFCompatConfig
Methods:
- hf_checkpoint_converter
- to_hf_config
- from_hf_config
- flops_per_token
- total_trainable_params
- build
Attributes:
- seq_len (int)
- hidden_dim (int)
- num_layers (int)
- num_heads (int)
- mlp_scale (int)
- initializer_range (float)
- embed_pdrop (float)
- resid_pdrop (float)
- attn_pdrop (float)
- layer_norm_epsilon (float)
- activation_function (ActivationFunctionEnum)
- scale_attn_by_inverse_layer_idx (bool)
- upcast_attn (bool)
- gradient_checkpointing (bool)
- use_bias (bool)
- use_flash_attention (Optional[bool])
- attn_backend (Optional[AttentionBackend])
- flash_attention_block_size (Optional[int])
- Pos
- KeyPos
- Embed
- Heads
- Layers
- Mlp
- HeadSize
- model_type (Type[Gpt2LMHeadModel])
- cross_entropy_block_size (Optional[int]) – the block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block.
seq_len: int = 1024
hidden_dim: int = 768
num_layers: int = 12
num_heads: int = 12
mlp_scale: int = 4
initializer_range: float = 0.02
embed_pdrop: float = 0.0
resid_pdrop: float = 0.0
attn_pdrop: float = 0.0
layer_norm_epsilon: float = 1e-05
activation_function: ActivationFunctionEnum = ActivationFunctionEnum.gelu_new
scale_attn_by_inverse_layer_idx: bool = False
upcast_attn: bool = False
gradient_checkpointing: bool = True
use_bias: bool = True
use_flash_attention: Optional[bool] = None
attn_backend: Optional[AttentionBackend] = None
flash_attention_block_size: Optional[int] = None
Pos = property(lambda self: Axis(name='position', size=(self.seq_len)))
KeyPos = property(lambda self: self.Pos.alias('key_position'))
Embed = property(lambda self: Axis(name='embed', size=(self.hidden_dim)))
Heads = property(lambda self: Axis(name='heads', size=(self.num_heads)))
Layers = property(lambda self: Axis(name='layers', size=(self.num_layers)))
Mlp = property(lambda self: Axis(name='mlp', size=(self.hidden_dim * self.mlp_scale)))
HeadSize = property(lambda self: Axis(name='head_size', size=(self.hidden_dim // self.num_heads)))
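The derived axes above are simple functions of the config fields. The hypothetical helper below mirrors those property lambdas but returns plain integer sizes instead of haliax Axis objects, which makes it easy to sanity-check a config such as the defaults or the my-run.yaml example.

```python
def gpt2_axis_sizes(seq_len: int = 1024, hidden_dim: int = 768, num_heads: int = 12,
                    num_layers: int = 12, mlp_scale: int = 4) -> dict:
    # Same arithmetic as the Gpt2Config axis properties, sizes only.
    return {
        "position": seq_len,
        "embed": hidden_dim,
        "heads": num_heads,
        "layers": num_layers,
        "mlp": hidden_dim * mlp_scale,
        "head_size": hidden_dim // num_heads,
    }


sizes = gpt2_axis_sizes()  # mlp = 3072, head_size = 64 for the defaults
```

Because HeadSize uses floor division, hidden_dim should be divisible by num_heads; otherwise the heads would not cover the full embedding width.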
model_type: Type[Gpt2LMHeadModel] (property)
cross_entropy_block_size: Optional[int] = None
The block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block. It can be adjusted to fit within memory constraints. It's deliberately set to a large value because it is usually faster to compute the loss in larger blocks.
hf_checkpoint_converter(ref_checkpoint: Optional[str] = None) -> HFCheckpointConverter[Gpt2Config]
to_hf_config(vocab_size, config_overrides=None) -> HfGpt2Config
from_hf_config(hf_config: HfConfig) (classmethod)
flops_per_token(vocab_size: int) -> Optional[float]
total_trainable_params() -> Optional[float]
build(Vocab: Axis, *, key: PRNGKeyArray) -> LmT
LlamaConfig(cross_entropy_block_size: Optional[int] = None, seq_len: int = 2048, hidden_dim: int = 4096, intermediate_dim: int = 11008, num_layers: int = 32, num_heads: int = 32, head_dim: int | None = None, num_kv_heads: int = 32, activation_function: ActivationFunctionEnum = ActivationFunctionEnum.silu, initializer_range: float = 0.02, layer_norm_epsilon: float = 1e-05, tie_word_embeddings: bool = False, hybrid_norm: bool = False, use_qk_norm: bool = False, input_embedding_norm: bool = False, upcast_attn: bool = False, attn_backend: Optional[AttentionBackend] = None, flash_attention_block_size: Optional[int] = None, gradient_checkpointing: bool | ScanCheckpointPolicy | str = True, scan_layers: bool = True, use_bias: bool = False, use_layer_norm_weight: bool = True, rope: RotaryEmbeddingsConfig = DefaultRotaryEmbeddingsConfig(), reference_checkpoint: str = 'NousResearch/Llama-2-7b-hf', tokenizer: Optional[str] = None) (dataclass)
Bases: HFCompatConfig
Config for LlamaModel.
Parameters:
- seq_len (int, default: 2048) – maximum length of the input sequence.
- hidden_dim (int, default: 4096) – dimension of the hidden state.
- intermediate_dim (int, default: 11008) – dimension of the intermediate (MLP) state.
- num_layers (int, default: 32) – number of hidden layers in the Transformer decoder.
- num_heads (int, default: 32) – number of attention heads for each attention layer.
- num_kv_heads (int, default: 32) – number of attention heads for keys and values in each attention layer. Setting it to 1 gives MQA, setting it to num_heads gives MHA, and any other value gives GQA. Note that num_heads must be divisible by this number.
- activation_function (str, default: silu) – activation function for the hidden layer.
- hybrid_norm (bool, default: False) – whether to use hybrid normalization with additional layer norms after attention and MLP.
- input_embedding_norm (bool, default: False) – whether to apply layer normalization after the input embeddings.
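The relationship between num_heads and num_kv_heads, and the head-size fallback described under actual_head_size, can be sketched as follows. Both helpers are hypothetical; head_size assumes that when head_dim is None the head size is hidden_dim // num_heads. In GQA, each key/value head is shared by num_heads // num_kv_heads query heads.

```python
from typing import Optional


def attention_kind(num_heads: int, num_kv_heads: int) -> str:
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if num_kv_heads == 1:
        return "MQA"  # one shared key/value head
    if num_kv_heads == num_heads:
        return "MHA"  # one key/value head per query head
    return "GQA"      # groups of num_heads // num_kv_heads query heads share a key/value head


def head_size(hidden_dim: int, num_heads: int, head_dim: Optional[int] = None) -> int:
    # An explicit head_dim wins; otherwise split hidden_dim evenly across heads.
    return head_dim if head_dim is not None else hidden_dim // num_heads
```

For the defaults (hidden_dim=4096, num_heads=32, num_kv_heads=32) this gives MHA with a head size of 128; setting num_kv_heads=8 would give GQA with 4 query heads per key/value head.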
Methods:
- hf_checkpoint_converter
- from_hf_config
- to_hf_config – Convert to HuggingFace's LlamaConfig.
- mk_LayerNorm
- flops_per_token
- total_trainable_params
- attention_config – Convert this LlamaConfig to an AttentionConfig for use with Attention.
- build
Attributes:
- seq_len (int)
- hidden_dim (int)
- intermediate_dim (int)
- num_layers (int)
- num_heads (int)
- head_dim (int | None)
- num_kv_heads (int)
- activation_function (ActivationFunctionEnum)
- initializer_range (float)
- layer_norm_epsilon (float)
- tie_word_embeddings (bool)
- hybrid_norm (bool)
- use_qk_norm (bool)
- input_embedding_norm (bool)
- upcast_attn (bool)
- attn_backend (Optional[AttentionBackend])
- flash_attention_block_size (Optional[int])
- gradient_checkpointing (bool | ScanCheckpointPolicy | str)
- scan_layers (bool)
- use_bias (bool)
- use_layer_norm_weight (bool)
- rope (RotaryEmbeddingsConfig)
- reference_checkpoint (str)
- tokenizer (Optional[str])
- Pos
- KeyPos
- Embed
- Layers
- Mlp
- model_type (Type[LlamaLMHeadModel])
- norm_config (LayerNormConfigBase)
- actual_head_size – Returns the actual head size, based on head_dim or calculated from hidden_dim and num_heads.
- cross_entropy_block_size (Optional[int]) – the block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block.
seq_len: int = 2048
hidden_dim: int = 4096
intermediate_dim: int = 11008
num_layers: int = 32
num_heads: int = 32
head_dim: int | None = None
num_kv_heads: int = 32
activation_function: ActivationFunctionEnum = ActivationFunctionEnum.silu
initializer_range: float = 0.02
layer_norm_epsilon: float = 1e-05
tie_word_embeddings: bool = False
hybrid_norm: bool = False
use_qk_norm: bool = False
input_embedding_norm: bool = False
upcast_attn: bool = False
attn_backend: Optional[AttentionBackend] = None
flash_attention_block_size: Optional[int] = None
gradient_checkpointing: bool | ScanCheckpointPolicy | str = True
scan_layers: bool = True
use_bias: bool = False
use_layer_norm_weight: bool = True
rope: RotaryEmbeddingsConfig = dataclasses.field(default_factory=DefaultRotaryEmbeddingsConfig)
reference_checkpoint: str = 'NousResearch/Llama-2-7b-hf'
tokenizer: Optional[str] = None
Pos = property(lambda self: Axis(name='position', size=(self.seq_len)))
KeyPos = property(lambda self: self.Pos.alias('key_position'))
Embed = property(lambda self: Axis(name='embed', size=(self.hidden_dim)))
Layers = property(lambda self: Axis(name='layer', size=(self.num_layers)))
Mlp = property(lambda self: Axis(name='mlp', size=(self.intermediate_dim)))
model_type: Type[LlamaLMHeadModel] (property)
norm_config: LayerNormConfigBase (property)
actual_head_size (property)
Returns the actual head size, based on head_dim or calculated from hidden_dim and num_heads.
cross_entropy_block_size: Optional[int] = None
The block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block. It can be adjusted to fit within memory constraints. It's deliberately set to a large value because it is usually faster to compute the loss in larger blocks.
hf_checkpoint_converter(ref_checkpoint: Optional[str] = None) -> HFCheckpointConverter[LlamaConfig]
from_hf_config(hf_config: HfConfig) (classmethod)
to_hf_config(vocab_size: int, config_overrides: Optional[Dict] = None) -> HfLlamaConfig
Convert to HuggingFace's LlamaConfig.
Parameters:
- vocab_size (int) – vocabulary size of the tokenizer.
- config_overrides (dict, default: None) – overrides for the config.
Returns:
- HfLlamaConfig (LlamaConfig) – HuggingFace's LlamaConfig.
Raises:
- ValueError – if hybrid_norm or input_embedding_norm is enabled, as these features are not supported in the HuggingFace config format.
mk_LayerNorm(axis: AxisSpec)
flops_per_token(vocab_size: int)
total_trainable_params(vocab_size)
attention_config() -> AttentionConfig
Convert this LlamaConfig to an AttentionConfig for use with Attention.