Dataset Format Reference¤

This document outlines the supported dataset formats in Levanter and how each format transforms raw data into model-ready tokens. These formats determine how Levanter tokenizes, structures, and masks training data. For a more directed, tutorial-like guide, see the Training Data Guide.

Overview¤

Levanter supports three canonical formats:

| Format | Intended Use | Required Fields | YAML Spec Example |
|--------|--------------|-----------------|-------------------|
| text | Language modeling pretraining | "text" → string | type: text |
| chat | Conversational fine-tuning (SFT) | "messages" → list of turns in OpenAI format | type: chat |
| supervised | Instruction tuning / seq2seq tasks | two string fields (e.g. prompt, answer) | type: supervised |

Tip

Extra fields in the JSON are ignored. All input must be valid JSONL (i.e., one JSON object per line).
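
Before building a cache, it can help to confirm that a file really is one JSON object per line. The following is a minimal sketch using only the Python standard library; the function name `find_invalid_jsonl_lines` is hypothetical and not part of Levanter, and skipping blank lines is a choice of this sketch rather than documented behavior.

```python
import json
from typing import Iterable, List, Tuple


def find_invalid_jsonl_lines(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Return (line_number, reason) for every line that is not a JSON object."""
    problems = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            # Blank lines are skipped here; that leniency is an assumption of this sketch.
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            problems.append((line_number, f"invalid JSON: {err.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((line_number, "not a JSON object"))
    return problems


sample = [
    '{"text": "The quick brown fox.", "source": "extra fields are fine"}',
    '{"text": "missing closing brace"',
    '["a", "list", "is", "not", "an", "object"]',
]
problems = find_invalid_jsonl_lines(sample)
for line_number, reason in problems:
    print(f"line {line_number}: {reason}")
```

To check a real file, pass an open file handle instead of the `sample` list, since iterating a text file yields one line at a time.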


text Format¤

This is the default format used for pretraining.

Expected Input:

{"text": "The quick brown fox jumps over the lazy dog."}

Configuration¤

Tip

For text, format is optional.

format:
  type: text
  text_key: text  # optional, default is "text"

Processing:¤

  • Tokenizes the value in text_key
  • Appends EOS token and prepends BOS token if not already present
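
The "if not already present" rule can be sketched as a small function over token ids. This is an illustration only, not Levanter's implementation: `add_special_tokens` and the ids used below are hypothetical, and the sketch assumes the BOS and EOS ids differ.

```python
from typing import List


def add_special_tokens(token_ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Prepend BOS and append EOS unless they are already in place."""
    out = list(token_ids)  # copy so the caller's list is not modified
    if not out or out[0] != bos_id:
        out.insert(0, bos_id)
    if out[-1] != eos_id:
        out.append(eos_id)
    return out


# Hypothetical ids: 1 = BOS, 2 = EOS, everything else is ordinary text.
print(add_special_tokens([5, 6], bos_id=1, eos_id=2))
print(add_special_tokens([1, 5, 6, 2], bos_id=1, eos_id=2))
```

Applying the function twice gives the same result as applying it once, which is the property the processing step relies on.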

chat Format¤

Used for multi-turn conversation datasets (e.g. ShareGPT, OpenChat, Tulu).

Expected Input:

{"messages": [
  {"role": "user", "content": "Hello!"},
  {"role": "assistant", "content": "Hi there!"}
]}

Configuration:¤

format:
  type: chat
  messages_field: messages  # optional (default)
  pack: true  # optional (default)
  mask_user_turns: true  # optional (default). See below for important details!
  chat_template: |
    {{ bos_token }}
    {%- for message in messages -%}
    {%- if message['role'] == 'assistant' -%}
        <|start_header_id|>{{ message['role'] }}<|end_header_id|>
    {% generation %}{{- message['content'] | trim }}<|eot_id|>{% endgeneration %}\n
    {% else %}
    <|start_header_id|>{{ message['role'] }}<|end_header_id|>
    {{ message['content'] | trim }}<|eot_id|>
    {% endif %}
    {%- endfor -%}
    {%- if add_generation_prompt -%}
    <|start_header_id|>assistant<|end_header_id|>\n{% endif -%}
  • pack: true will pack multiple conversations into a single example if they fit within the context length.
  • pack: false will produce a single example per conversation. This is very inefficient.
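
To see why packing matters, consider a simple greedy packer over conversation lengths. This is a sketch under stated assumptions: `pack_greedily` is a hypothetical helper, it only packs consecutive conversations, and it does not reproduce the strategy Levanter actually uses.

```python
from typing import List


def pack_greedily(lengths: List[int], max_len: int) -> List[List[int]]:
    """Group consecutive example indices so each group's total length fits in max_len."""
    packs: List[List[int]] = []
    current: List[int] = []
    used = 0
    for index, length in enumerate(lengths):
        if length > max_len:
            raise ValueError(f"example {index} has {length} tokens, more than {max_len}")
        if current and used + length > max_len:
            # The next conversation does not fit, so close the current pack.
            packs.append(current)
            current, used = [], 0
        current.append(index)
        used += length
    if current:
        packs.append(current)
    return packs


conversation_lengths = [300, 500, 200, 900, 100]
packed = pack_greedily(conversation_lengths, max_len=1024)
unpacked = [[i] for i in range(len(conversation_lengths))]
print(f"packed: {len(packed)} sequences, unpacked: {len(unpacked)} sequences")
```

With these lengths, packing fills two sequences where the unpacked layout needs five, each mostly padding.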

Processing:¤

  • Requires a chat_template:
      • If not supplied in config, will use tokenizer.chat_template
      • If neither is available, raises an error
  • Uses template to flatten messages into a single token sequence
  • Builds loss_mask so that only assistant spans are predicted
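
The flattening and masking steps can be pictured with a simplified sketch. It assumes each turn has already been tokenized and tagged with its role; in practice the trainable spans come from the chat template rather than from role tags, so treat `flatten_with_loss_mask` as a hypothetical illustration, not Levanter's code.

```python
from typing import List, Tuple


def flatten_with_loss_mask(
    segments: List[Tuple[str, List[int]]], mask_user_turns: bool = True
) -> Tuple[List[int], List[int]]:
    """Concatenate per-turn token ids and mark which positions contribute to the loss."""
    input_ids: List[int] = []
    loss_mask: List[int] = []
    for role, token_ids in segments:
        # With masking on, only assistant tokens are trained on.
        trainable = role == "assistant" or not mask_user_turns
        input_ids.extend(token_ids)
        loss_mask.extend([1 if trainable else 0] * len(token_ids))
    return input_ids, loss_mask


# Hypothetical token ids for a two-turn conversation.
turns = [("user", [10, 11]), ("assistant", [20, 21, 22])]
ids, mask = flatten_with_loss_mask(turns)
print(ids)
print(mask)
```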

Chat Templates¤

Chat templates are Jinja2 templates that format a list of messages into a single string. Hugging Face provides mostly sufficient documentation here but misses one important detail: the template must wrap each assistant message in {% generation %} and {% endgeneration %} to mark the text the model is trained to predict. (See here.) Levanter needs these tags to construct the loss_mask for training, unless mask_user_turns is set to false.

Unfortunately, almost no tokenizers use this format, so you will need to write your own.

Here is an example we use in the stanford-crfm/marin-tokenizer tokenizer:

{{ bos_token }}
{%- for message in messages -%}
{%- if message['role'] == 'assistant' -%}
    <|start_header_id|>{{ message['role'] }}<|end_header_id|>
{% generation %}{{- message['content'] | trim }}<|eot_id|>{% endgeneration %}\n
{% else %}
<|start_header_id|>{{ message['role'] }}<|end_header_id|>
{{ message['content'] | trim }}<|eot_id|>
{% endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
<|start_header_id|>assistant<|end_header_id|>\n{% endif -%}

The key points are:

  • Wrap the assistant message in {% generation %} and {% endgeneration %} to indicate what the model is responsible for predicting. Jinja's handling of white space is confusing to me, so you'll want to be careful there.
  • Use {{ bos_token }} to prepend the BOS token.
  • Ensure that the generation prompt resembles the format of the training data (e.g. the final \n).
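
Because a template without generation tags cannot mark assistant spans, a quick textual check before training can catch the problem early. The sketch below uses only the standard `re` module; `has_generation_block` is a hypothetical helper, and it checks the tags as text without rendering the template.

```python
import re

# Accept optional whitespace-control dashes and flexible spacing inside the tags.
GENERATION_OPEN = re.compile(r"\{%-?\s*generation\s*-?%\}")
GENERATION_CLOSE = re.compile(r"\{%-?\s*endgeneration\s*-?%\}")


def has_generation_block(template: str) -> bool:
    """True if the template has matching, correctly ordered generation tags."""
    opens = [m.start() for m in GENERATION_OPEN.finditer(template)]
    closes = [m.start() for m in GENERATION_CLOSE.finditer(template)]
    if not opens or len(opens) != len(closes):
        return False
    return all(o < c for o, c in zip(opens, closes))


with_tags = "{% generation %}{{ message['content'] | trim }}<|eot_id|>{% endgeneration %}"
without_tags = "{{ message['content'] | trim }}<|eot_id|>"
print(has_generation_block(with_tags))
print(has_generation_block(without_tags))
```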


supervised Format¤

Used for single-turn instruction following or sequence-to-sequence tasks.

Expected Input:

{"prompt": "Translate to French: Hello", "answer": "Bonjour"}

Configuration:¤

format:
  type: supervised
  input_field: prompt
  output_field: answer
  separate_with: "\n"  # optional separator between input and output
  pack: true  # optional, default is true
  mask_inputs: true  # optional, default is true
  • pack: true will pack multiple examples into a single example if they fit within the context length.
  • pack: false will produce a single example per input/output pair. This is very inefficient.

Processing:¤

  • Tokenizes prompt, then tokenizes answer (with optional separator)
  • Produces a single input_ids sequence
  • Computes sources_len so that loss is masked on prompt tokens (assuming mask_inputs: true)
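
The three steps above can be sketched as follows. This is an illustration under assumptions: `build_supervised_example` is hypothetical, the tokenizer is a toy character-level stand-in, and counting the separator as part of the prompt is a choice of this sketch rather than confirmed behavior.

```python
from typing import Callable, Dict, List


def build_supervised_example(
    prompt: str,
    answer: str,
    tokenize: Callable[[str], List[int]],
    separate_with: str = "",
    mask_inputs: bool = True,
) -> Dict[str, object]:
    """Join prompt and answer tokens and record how many belong to the prompt."""
    # Assumption: the separator is tokenized together with the prompt.
    prompt_ids = tokenize(prompt + separate_with)
    answer_ids = tokenize(answer)
    input_ids = prompt_ids + answer_ids
    sources_len = len(prompt_ids)
    # Positions before sources_len are excluded from the loss when mask_inputs is on.
    loss_mask = [0 if (mask_inputs and i < sources_len) else 1 for i in range(len(input_ids))]
    return {"input_ids": input_ids, "sources_len": sources_len, "loss_mask": loss_mask}


def toy_tokenize(text: str) -> List[int]:
    """Toy tokenizer: one token per character (a real tokenizer would differ)."""
    return [ord(ch) for ch in text]


example = build_supervised_example("Hi", "Yo", toy_tokenize, separate_with="\n")
print(example)
```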

API¤

Overall Configs¤

LMMixtureDatasetConfig(tokenizer: str = 'gpt2', vocab_size: Optional[int] = None, cache_dir: Optional[str] = 'cache/', cache_options: CacheOptions = field(default_factory=CacheOptions), enforce_eos: bool = True, chat_template: str | None = None, ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX, shuffle: bool | int = False, permutation_type: Literal['feistel', 'linear'] | None = None, configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict), train_weights: Union[Dict[str, float], List[Tuple[int, Dict[str, float]]]] = field(default_factory=dict), stop_strategy: str = field(default=(StopStrategy.RESTART_STRATEGY)), target_budget: Optional[int] = None, experiment_budget: Optional[int] = None, mixture_block_size: int = 2048, max_train_batches: Optional[Dict[str, int]] = None, num_validation_sequences: Optional[Dict[str, int]] = None) dataclass ¤

Bases: LMTaskConfig

A mixture of language model datasets that supports dynamic weight changes during training.

Weights can be specified either as a single dictionary for constant mixing ratios, or as a list of (step, weights) tuples to change mixing ratios during training.
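
The two accepted shapes of train_weights can be illustrated with a small resolver. This is a sketch, not Levanter's code: `weights_at_step` is hypothetical, and it assumes each (step, weights) entry applies from its step until the next entry begins.

```python
from typing import Dict, List, Tuple, Union

Weights = Dict[str, float]
Schedule = Union[Weights, List[Tuple[int, Weights]]]


def weights_at_step(train_weights: Schedule, step: int) -> Weights:
    """Return the mixing weights in effect at a given training step."""
    if isinstance(train_weights, dict):
        # A single dictionary means constant mixing ratios.
        return train_weights
    active = None
    for start_step, weights in sorted(train_weights, key=lambda item: item[0]):
        if start_step <= step:
            active = weights
    if active is None:
        raise ValueError(f"no weights defined at step {step}")
    return active


# Hypothetical dataset names and steps.
schedule = [
    (0, {"web": 0.8, "code": 0.2}),
    (1000, {"web": 0.5, "code": 0.5}),
]
print(weights_at_step(schedule, 999))
print(weights_at_step(schedule, 1000))
```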

Methods:

Attributes:

cache_dir: Optional[str] = 'cache/' class-attribute instance-attribute ¤
configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict) class-attribute instance-attribute ¤

Configuration of each dataset source (urls, hf dataset id, etc.)

train_weights: Union[Dict[str, float], List[Tuple[int, Dict[str, float]]]] = field(default_factory=dict) class-attribute instance-attribute ¤

Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples

stop_strategy: str = field(default=(StopStrategy.RESTART_STRATEGY)) class-attribute instance-attribute ¤
target_budget: Optional[int] = None class-attribute instance-attribute ¤
experiment_budget: Optional[int] = None class-attribute instance-attribute ¤
mixture_block_size: int = 2048 class-attribute instance-attribute ¤

Block size for deterministic mixing. In each block, a given dataset will have exactly the same number of samples, equal to the expected number of samples in the mixture, rounding in the expected way.
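
One way to picture fixed per-block counts is to scale the weights to the block size and round so the counts still sum to the block size. The sketch below uses largest-remainder rounding, which is an assumption; the description above does not say which rounding rule Levanter applies, and `samples_per_block` is a hypothetical name.

```python
import math
from typing import Dict


def samples_per_block(weights: Dict[str, float], block_size: int) -> Dict[str, int]:
    """Split block_size slots across datasets in proportion to their weights."""
    total = sum(weights.values())
    expected = {name: block_size * w / total for name, w in weights.items()}
    counts = {name: math.floor(value) for name, value in expected.items()}
    leftover = block_size - sum(counts.values())
    # Hand remaining slots to the largest fractional parts (largest-remainder rounding).
    by_remainder = sorted(expected, key=lambda name: expected[name] - counts[name], reverse=True)
    for name in by_remainder[:leftover]:
        counts[name] += 1
    return counts


# Hypothetical dataset names; weights need not sum to 1.
counts = samples_per_block({"web": 3.0, "code": 2.0, "math": 1.0}, block_size=2048)
print(counts)
```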

max_train_batches: Optional[Dict[str, int]] = None class-attribute instance-attribute ¤

Maximum number of batches to use from each dataset for training (using the initial batch size)

num_validation_sequences: Optional[Dict[str, int]] = None class-attribute instance-attribute ¤

Number of validation sequences to sample from the training set for each dataset

sources: Mapping[str, LmDatasetSourceConfigBase] property ¤
tokenizer: str = 'gpt2' class-attribute instance-attribute ¤
vocab_size: Optional[int] = None class-attribute instance-attribute ¤
cache_options: CacheOptions = field(default_factory=CacheOptions) class-attribute instance-attribute ¤
enforce_eos: bool = True class-attribute instance-attribute ¤
chat_template: str | None = None class-attribute instance-attribute ¤
ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX class-attribute instance-attribute ¤
shuffle: bool | int = False class-attribute instance-attribute ¤

Whether to shuffle the dataset. True shuffles the whole dataset; False disables shuffling. To shuffle in eras, set this to the era length.
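
Shuffling in eras means shuffling within consecutive fixed-size windows rather than across the whole dataset. The sketch below shows the idea with Python's `random` module; it is a conceptual illustration only, `era_shuffled_indices` is hypothetical, and Levanter's own permutations (see permutation_type) work differently.

```python
import random
from typing import List


def era_shuffled_indices(num_examples: int, era_length: int, seed: int) -> List[int]:
    """Shuffle indices within each consecutive era of era_length examples."""
    rng = random.Random(seed)
    order: List[int] = []
    for start in range(0, num_examples, era_length):
        era = list(range(start, min(start + era_length, num_examples)))
        rng.shuffle(era)  # examples never leave their own era
        order.extend(era)
    return order


order = era_shuffled_indices(num_examples=10, era_length=4, seed=0)
print(order)
```

Every example still appears exactly once, but an example from the first era can never be drawn after one from the second.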

permutation_type: Literal['feistel', 'linear'] | None = None class-attribute instance-attribute ¤

Type of permutation to use for shuffle.

If None, defaults to linear, but this will change in the future since Feistel is better.

the_tokenizer: HfTokenizer cached property ¤
build_token_datasets(caches: Mapping[str, TreeCache[dict]], Pos: Axis) ¤
train_set(Pos: Axis, batch_schedule: BatchSchedule, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> AsyncDataset[LmExample] ¤
train_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True, *, initial_batch_size: Optional[int] = None, epochs: Optional[int] = None, key: PRNGKeyArray) -> Mapping[str, AsyncDataset[LmExample]] ¤
validation_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, AsyncDataset[LmExample]] ¤
build_caches(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Dict[str, TreeCache[dict]] ¤
tagged_eval_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> list[Tuple[AsyncDataset[LmExample], List[str]]] ¤

SingleDatasetLMConfigBase(tokenizer: str = 'gpt2', vocab_size: Optional[int] = None, cache_dir: Optional[str] = 'cache/', cache_options: CacheOptions = field(default_factory=CacheOptions), enforce_eos: bool = True, chat_template: str | None = None, ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX, shuffle: bool | int = False, permutation_type: Literal['feistel', 'linear'] | None = None, tags: Optional[List[str]] = None, format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat)) dataclass ¤

Bases: LmDatasetSourceConfigBase, LMTaskConfig

This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls

Methods:

Attributes:

cache_dir: Optional[str] = 'cache/' class-attribute instance-attribute ¤
sources: Mapping[str, LmDatasetSourceConfigBase] property ¤
tokenizer: str = 'gpt2' class-attribute instance-attribute ¤
vocab_size: Optional[int] = None class-attribute instance-attribute ¤
cache_options: CacheOptions = field(default_factory=CacheOptions) class-attribute instance-attribute ¤
enforce_eos: bool = True class-attribute instance-attribute ¤
chat_template: str | None = None class-attribute instance-attribute ¤
ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX class-attribute instance-attribute ¤
shuffle: bool | int = False class-attribute instance-attribute ¤

Whether to shuffle the dataset. True shuffles the whole dataset; False disables shuffling. To shuffle in eras, set this to the era length.

permutation_type: Literal['feistel', 'linear'] | None = None class-attribute instance-attribute ¤

Type of permutation to use for shuffle.

If None, defaults to linear, but this will change in the future since Feistel is better.

the_tokenizer: HfTokenizer cached property ¤
tags: Optional[List[str]] = None class-attribute instance-attribute ¤

Tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well.

format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat) class-attribute instance-attribute ¤

Format of the dataset.

train_set(Pos: Axis, batch_schedule: BatchSchedule, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> AsyncDataset[LmExample] ¤
train_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> Mapping[str, AsyncDataset[LmExample]] ¤
validation_set(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> AsyncDataset[LmExample] | None ¤
validation_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, AsyncDataset[LmExample]] ¤
build_caches(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, TreeCache[dict]] ¤
build_or_load_cache(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Optional[TreeCache[dict]] ¤
tagged_eval_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> list[Tuple[AsyncDataset[LmExample], List[str]]] ¤
get_shard_source(split) -> Optional[ShardedDataSource[dict]] abstractmethod ¤
load_cache(split, tokenizer: HfTokenizer, override_cache_dir: str | None = None, enforce_eos=True) -> TreeCache[dict] ¤

HfSingleDatasetLMConfig(tokenizer: str = 'gpt2', vocab_size: Optional[int] = None, cache_dir: Optional[str] = 'cache/', cache_options: CacheOptions = field(default_factory=CacheOptions), enforce_eos: bool = True, chat_template: str | None = None, ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX, shuffle: bool | int = False, permutation_type: Literal['feistel', 'linear'] | None = None, tags: Optional[List[str]] = None, format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat), name: Optional[str] = None, stream: bool = True, *, id: str) dataclass ¤

Bases: SingleDatasetLMConfigBase, HfDatasetSourceConfig

Methods:

Attributes:

tokenizer: str = 'gpt2' class-attribute instance-attribute ¤
vocab_size: Optional[int] = None class-attribute instance-attribute ¤
cache_dir: Optional[str] = 'cache/' class-attribute instance-attribute ¤
cache_options: CacheOptions = field(default_factory=CacheOptions) class-attribute instance-attribute ¤
enforce_eos: bool = True class-attribute instance-attribute ¤
chat_template: str | None = None class-attribute instance-attribute ¤
ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX class-attribute instance-attribute ¤
shuffle: bool | int = False class-attribute instance-attribute ¤

Whether to shuffle the dataset. True shuffles the whole dataset; False disables shuffling. To shuffle in eras, set this to the era length.

permutation_type: Literal['feistel', 'linear'] | None = None class-attribute instance-attribute ¤

Type of permutation to use for shuffle.

If None, defaults to linear, but this will change in the future since Feistel is better.

the_tokenizer: HfTokenizer cached property ¤
sources: Mapping[str, LmDatasetSourceConfigBase] property ¤
tags: Optional[List[str]] = None class-attribute instance-attribute ¤

Tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well.

format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat) class-attribute instance-attribute ¤

Format of the dataset.

id: str = dataclasses.field(kw_only=True) class-attribute instance-attribute ¤
name: Optional[str] = None class-attribute instance-attribute ¤
stream: bool = True class-attribute instance-attribute ¤
train_set(Pos: Axis, batch_schedule: BatchSchedule, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> AsyncDataset[LmExample] ¤
train_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> Mapping[str, AsyncDataset[LmExample]] ¤
validation_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, AsyncDataset[LmExample]] ¤
build_caches(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, TreeCache[dict]] ¤
tagged_eval_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> list[Tuple[AsyncDataset[LmExample], List[str]]] ¤
get_shard_source(split) -> Optional[ShardedDataSource[dict]] ¤
load_cache(split, tokenizer: HfTokenizer, override_cache_dir: str | None = None, enforce_eos=True) -> TreeCache[dict] ¤
validation_set(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> AsyncDataset[LmExample] | None ¤
build_or_load_cache(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Optional[TreeCache[dict]] ¤

UrlSingleDatasetLMConfig(tokenizer: str = 'gpt2', vocab_size: Optional[int] = None, cache_dir: Optional[str] = 'cache/', cache_options: CacheOptions = field(default_factory=CacheOptions), enforce_eos: bool = True, chat_template: str | None = None, ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX, shuffle: bool | int = False, permutation_type: Literal['feistel', 'linear'] | None = None, tags: Optional[List[str]] = None, format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat), train_urls: list[str] = (), validation_urls: list[str] = ()) dataclass ¤

Bases: SingleDatasetLMConfigBase, UrlDatasetSourceConfig

Methods:

Attributes:

tokenizer: str = 'gpt2' class-attribute instance-attribute ¤
vocab_size: Optional[int] = None class-attribute instance-attribute ¤
cache_dir: Optional[str] = 'cache/' class-attribute instance-attribute ¤
cache_options: CacheOptions = field(default_factory=CacheOptions) class-attribute instance-attribute ¤
enforce_eos: bool = True class-attribute instance-attribute ¤
chat_template: str | None = None class-attribute instance-attribute ¤
ignore_token_id: Optional[int] = DEFAULT_IGNORE_INDEX class-attribute instance-attribute ¤
shuffle: bool | int = False class-attribute instance-attribute ¤

Whether to shuffle the dataset. True shuffles the whole dataset; False disables shuffling. To shuffle in eras, set this to the era length.

permutation_type: Literal['feistel', 'linear'] | None = None class-attribute instance-attribute ¤

Type of permutation to use for shuffle.

If None, defaults to linear, but this will change in the future since Feistel is better.

the_tokenizer: HfTokenizer cached property ¤
sources: Mapping[str, LmDatasetSourceConfigBase] property ¤
tags: Optional[List[str]] = None class-attribute instance-attribute ¤

Tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well.

format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat) class-attribute instance-attribute ¤

Format of the dataset.

train_urls: list[str] = () class-attribute instance-attribute ¤
validation_urls: list[str] = () class-attribute instance-attribute ¤
train_set(Pos: Axis, batch_schedule: BatchSchedule, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> AsyncDataset[LmExample] ¤
train_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: PRNGKeyArray, epochs: Optional[int] = None) -> Mapping[str, AsyncDataset[LmExample]] ¤
validation_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, AsyncDataset[LmExample]] ¤
build_caches(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Mapping[str, TreeCache[dict]] ¤
tagged_eval_sets(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> list[Tuple[AsyncDataset[LmExample], List[str]]] ¤
get_shard_source(split) -> Optional[ShardedDataSource[dict]] ¤
load_cache(split, tokenizer: HfTokenizer, override_cache_dir: str | None = None, enforce_eos=True) -> TreeCache[dict] ¤
urls_for_split(split) ¤
validation_set(Pos: Axis, monitors: Union[bool, List[MetricsMonitor]] = True) -> AsyncDataset[LmExample] | None ¤
build_or_load_cache(split: str, monitors: Union[bool, List[MetricsMonitor]] = True) -> Optional[TreeCache[dict]] ¤

Formats¤

LmDatasetFormatBase ¤

Bases: ABC, ChoiceRegistry

Methods:

default_choice_name() -> Optional[str] classmethod ¤

ChatLmDatasetFormat(messages_field: str = 'messages', single_turn: bool = False, chat_template: str | None = None, system_prompt: str | None = None, chat_template_kwargs: str | None = 'chat_template_kwargs', pack: bool = True, mask_user_turns: bool = True) dataclass ¤

Bases: LmDatasetFormatBase

Dataset configuration for multi-turn chat transcripts.

Attributes:

  • messages_field (str) –

    Field name containing the ordered list of chat messages.

  • single_turn (bool) –

    Treat examples as a single user/assistant exchange.

  • chat_template (str | None) –

    Overrides the tokenizer's chat template when provided.

  • system_prompt (str | None) –

    Field name carrying an optional system instruction to prepend.

  • chat_template_kwargs (str | None) –

    Field name containing optional keyword arguments passed to the chat template.

  • pack (bool) –

    Whether to allow example packing for efficient batching.

  • mask_user_turns (bool) –

    Mask user tokens from the training loss when True.

Methods:

messages_field: str = 'messages' class-attribute instance-attribute ¤
single_turn: bool = False class-attribute instance-attribute ¤
chat_template: str | None = None class-attribute instance-attribute ¤
system_prompt: str | None = None class-attribute instance-attribute ¤
chat_template_kwargs: str | None = 'chat_template_kwargs' class-attribute instance-attribute ¤
pack: bool = True class-attribute instance-attribute ¤
mask_user_turns: bool = True class-attribute instance-attribute ¤
default_choice_name() -> Optional[str] classmethod ¤

SupervisedLmDatasetFormat(input_field: str = CANONICAL_INPUT_FIELD, output_field: str = CANONICAL_OUTPUT_FIELD, separate_with: str | int | None = None, pack: bool = True, mask_inputs: bool = True) dataclass ¤

Bases: LmDatasetFormatBase

Dataset configuration for supervised input/output pairs.

Attributes:

  • input_field (str) –

    Field name with the model input text.

  • output_field (str) –

    Field name with the target response text.

  • separate_with (str | int | None) –

    Optional separator inserted between input and output.

  • pack (bool) –

    Whether to enable packing of multiple samples.

  • mask_inputs (bool) –

    Mask tokens from the input_field during loss computation.

Methods:

input_field: str = CANONICAL_INPUT_FIELD class-attribute instance-attribute ¤
output_field: str = CANONICAL_OUTPUT_FIELD class-attribute instance-attribute ¤
separate_with: str | int | None = None class-attribute instance-attribute ¤
pack: bool = True class-attribute instance-attribute ¤
mask_inputs: bool = True class-attribute instance-attribute ¤
default_choice_name() -> Optional[str] classmethod ¤

TextLmDatasetFormat(text_key: str = 'text') dataclass ¤

Bases: LmDatasetFormatBase

Dataset configuration for raw text examples.

Attributes:

  • text_key (str) –

    Field name containing the raw text or tokens.

Methods:

text_key: str = 'text' class-attribute instance-attribute ¤
default_choice_name() -> Optional[str] classmethod ¤

Datasets¤

CausalLmDataset(dataset: AsyncDataset[np.ndarray], Pos: Axis, *, ignore_index: Optional[int] = None, eos_id: Optional[int] = None) ¤

Bases: MappedAsyncDataset[ndarray, LmExample]

Methods:

Attributes:

dataset = dataset instance-attribute ¤
Pos = Pos instance-attribute ¤
ignore_id = ignore_index instance-attribute ¤
eos_id = eos_id instance-attribute ¤
fn = fn instance-attribute ¤
async_len() -> int async ¤
as_async_dataset() -> AsyncDataset[T_co] ¤
as_sync_dataset() ¤
final_length_is_known() -> bool async ¤
is_finite() -> bool ¤
current_len() -> Optional[int] async ¤
getitem_async(index: int) -> U async ¤
get_batch(indices: Sequence[int]) -> Sequence[U] async ¤
wait_until_len_at_least(length: int) -> int async ¤
map(fn: MapFunction[U], *extra_args, **extra_kwargs) -> MappedAsyncDataset[T_co, U] ¤
map_batches(fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> BatchMappedAsyncDataset[U] ¤
slice_dataset(start_index: Optional[int] = None, end_index: Optional[int] = None) ¤

Slices the dataset from start_index to end_index.

take(n: int) ¤

Alias for slice_dataset(end_index=n).

shuffle(key: PRNGKeyArray, *, perm_type: PermType = 'feistel') ¤
era_shuffle(era_length: int, key: PRNGKeyArray, *, perm_type: PermType = 'feistel') ¤

MultiturnChatDataset(cache: TreeCache[ProcessedChatDict], Pos: Axis, max_segments_per_example: int = 64, slice_strategy: Literal['left', 'right', 'raise'] = 'left', mask_user_turns: bool = True) ¤

Bases: MappedAsyncDataset[tuple[ProcessedChatDict, ProcessedChatDict], LmExample]

A dataset that yields multiturn chat examples from a cache of processed chat data.

Parameters:

  • cache (TreeCache[ProcessedChatDict]) –

    The cache of processed chat data.

  • Pos (Axis) –

    The position axis.

  • max_segments_per_example (int, default: 64) –

    The maximum number of segments to pack into a single example. Set to 1 to disable packing.

  • slice_strategy (Literal['left', 'right', 'raise'], default: 'left') –

    The strategy to use when an example is too long.

Methods:

Attributes:

packed: GreedyPrepackedDataset[ProcessedChatDict] = GreedyPrepackedDataset(cache.store.tree, Pos.size, max_segments_per_example=max_segments_per_example, slice_strategy=slice_strategy) instance-attribute ¤
Pos = Pos instance-attribute ¤
mask_user_turns = mask_user_turns instance-attribute ¤
dataset = dataset instance-attribute ¤
fn = fn instance-attribute ¤
as_async_dataset() -> AsyncDataset[T_co] ¤
as_sync_dataset() ¤
async_len() -> int async ¤
final_length_is_known() -> bool async ¤
is_finite() -> bool ¤
current_len() -> Optional[int] async ¤
getitem_async(index: int) -> U async ¤
get_batch(indices: Sequence[int]) -> Sequence[U] async ¤
wait_until_len_at_least(length: int) -> int async ¤
map(fn: MapFunction[U], *extra_args, **extra_kwargs) -> MappedAsyncDataset[T_co, U] ¤
map_batches(fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> BatchMappedAsyncDataset[U] ¤
slice_dataset(start_index: Optional[int] = None, end_index: Optional[int] = None) ¤

Slices the dataset from start_index to end_index.

take(n: int) ¤

Alias for slice_dataset(end_index=n).

shuffle(key: PRNGKeyArray, *, perm_type: PermType = 'feistel') ¤
era_shuffle(era_length: int, key: PRNGKeyArray, *, perm_type: PermType = 'feistel') ¤

SupervisedDataset(cache: TreeCache[ProcessedSupervisedDict], Pos: Axis, max_segments_per_example: int | None = 64, mask_inputs: bool = True, slice_strategy: Literal['left', 'right', 'raise'] = 'right') ¤

Bases: MappedAsyncDataset[tuple[ProcessedSupervisedDict, ProcessedSupervisedDict], LmExample]

A dataset that yields packed supervised examples from a cache of processed supervised data.

Methods:

Attributes:

mask_inputs = mask_inputs instance-attribute ¤
packed = GreedyPrepackedDataset(cache.store.tree, Pos.size, max_segments_per_example=max_segments_per_example, slice_strategy=slice_strategy) instance-attribute ¤
dataset = dataset instance-attribute ¤
fn = fn instance-attribute ¤
as_async_dataset() -> AsyncDataset[T_co] ¤
as_sync_dataset() ¤
async_len() -> int async ¤
final_length_is_known() -> bool async ¤
is_finite() -> bool ¤
current_len() -> Optional[int] async ¤
getitem_async(index: int) -> U async ¤
get_batch(indices: Sequence[int]) -> Sequence[U] async ¤
wait_until_len_at_least(length: int) -> int async ¤
map(fn: MapFunction[U], *extra_args, **extra_kwargs) -> MappedAsyncDataset[T_co, U] ¤
map_batches(fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> BatchMappedAsyncDataset[U] ¤
slice_dataset(start_index: Optional[int] = None, end_index: Optional[int] = None) ¤

Slices the dataset from start_index to end_index.

take(n: int) ¤

Alias for slice_dataset(end_index=n).

shuffle(key: PRNGKeyArray, *, perm_type: PermType = 'feistel') ¤
era_shuffle(era_length: int, key: PRNGKeyArray, *, perm_type: PermType = 'feistel') ¤