
train

Functions:

  • get_datasets

    Build TensorFlow datasets for training and validation.

  • main

    Command-line interface for training an equilibrium profile reconstruction model.

  • parse_yaml_config

    Parse a YAML run configuration file with custom tags.

  • train

    Train a Keras model using the provided configuration and datasets.

get_datasets

get_datasets(train_radial_res: int, val_radial_res: int, batch_size: int, rfp_only: bool) -> tuple[Dataset, Dataset]

Build TensorFlow datasets for training and validation.

This function instantiates fpga_profile_reco.data.dataset.EQDataset objects for the train and val splits under cfg.DATA_DIR, converts them into tf.data.Dataset pipelines, and applies caching, shuffling (train only), batching, and prefetching.

Parameters:

  • train_radial_res

    (int) –

    Radial resolution used when loading the training split.

  • val_radial_res

    (int) –

    Radial resolution used when loading the validation split.

  • batch_size

    (int) –

    Batch size used for both training and validation pipelines.

  • rfp_only

    (bool) –

    If True, restrict the dataset to samples in the RFP regime (as defined by the dataset implementation).

Returns:

  • train_ds ( Dataset ) –

    Prepared training dataset producing batched samples.

  • val_ds ( Dataset ) –

    Prepared validation dataset producing batched samples.

Source code in src/fpga_profile_reco/core/train.py
def get_datasets(train_radial_res: int, val_radial_res: int, batch_size: int, rfp_only: bool) -> tuple[tf.data.Dataset, tf.data.Dataset]:
    """
    Build TensorFlow datasets for training and validation.

    This function instantiates :class:`fpga_profile_reco.data.dataset.EQDataset`
    objects for the ``train`` and ``val`` splits under :data:`cfg.DATA_DIR`,
    converts them into `tf.data.Dataset` pipelines and applies caching,
    shuffling (train only), batching and prefetching.

    Parameters
    ----------
    train_radial_res : int
        Radial resolution used when loading the training split.
    val_radial_res : int
        Radial resolution used when loading the validation split.
    batch_size : int
        Batch size used for both training and validation pipelines.
    rfp_only : bool
        If True, restrict the dataset to samples in the RFP regime (as defined
        by the dataset implementation).

    Returns
    -------
    train_ds : tf.data.Dataset
        Prepared training dataset producing batched samples.
    val_ds : tf.data.Dataset
        Prepared validation dataset producing batched samples.
    """
    base_path = cfg.DATA_DIR
    print("Training set:")
    train_dataset = EQDataset(data_path=base_path / "train", radial_res=train_radial_res, rfp_only=rfp_only)
    print("Validation set:")
    val_dataset = EQDataset(data_path=base_path / "val", radial_res=val_radial_res, uniform_sampling=True, rfp_only=rfp_only)

    train_ds = tf.data.Dataset.from_tensor_slices(train_dataset.get_data(scale_data=True))
    train_ds = train_ds.cache()
    train_ds = train_ds.shuffle(buffer_size=10 * batch_size, reshuffle_each_iteration=True).batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices(val_dataset.get_data(scale_data=True))
    val_ds = val_ds.cache()
    val_ds = val_ds.batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds
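The pipeline shape above (cache, shuffle for training only, parallel batch, prefetch) can be reproduced on synthetic arrays. EQDataset and cfg.DATA_DIR are project-specific, so random tensors stand in for `get_data(scale_data=True)` here, and the radial resolution of 101 is an illustrative placeholder:

```python
import numpy as np
import tensorflow as tf

batch_size = 32
# Stand-ins for EQDataset.get_data(scale_data=True): 5 input features
# (matching model.build(input_shape=(None, 5))) and a hypothetical
# 101-point radial profile target.
inputs = np.random.rand(256, 5).astype(np.float32)
targets = np.random.rand(256, 101).astype(np.float32)

train_ds = tf.data.Dataset.from_tensor_slices((inputs, targets))
train_ds = train_ds.cache()  # cache before shuffling so only the raw data is cached
train_ds = (train_ds
            .shuffle(buffer_size=10 * batch_size, reshuffle_each_iteration=True)
            .batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))

x, y = next(iter(train_ds))
print(x.shape, y.shape)  # (32, 5) (32, 101)
```

Caching before `shuffle` keeps the shuffle order fresh each epoch while still avoiding repeated data loading; the validation pipeline simply drops the `shuffle` step.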

main

main()

Command-line interface for training an equilibrium profile reconstruction model.

This function parses command-line arguments, loads a YAML run configuration, builds the training/validation datasets, instantiates the model, and runs training while writing logs and checkpoints to the configured output directories.

Command Line Parameters

  • config

    (Path) –

    Path to the YAML run configuration file.
Source code in src/fpga_profile_reco/core/train.py
def main():
    """
    Command-line interface for training an equilibrium profile reconstruction model.

    This function parses command-line arguments, loads a YAML run configuration,
    builds the training/validation datasets, instantiates the model, and runs
    training while writing logs and checkpoints to the configured output
    directories.

    Command Line Parameters
    -----------------------
    - `config` : pathlib.Path
        Path to the YAML run configuration file.
    """
    import argparse
    import datetime
    import time

    from fpga_profile_reco.utils.helpers import format_time

    parser = argparse.ArgumentParser(description="Train a model for equilibrium profile reconstruction.")
    parser.add_argument('config', type=Path, help="Path to the YAML run configuration file.")

    args = parser.parse_args()

    # set memory growth for GPUs
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # read run configuration
    config = parse_yaml_config(args.config)

    train_ds, val_ds = get_datasets(train_radial_res=config['dataset']['train_radial_res'],
                                    val_radial_res=config['dataset']['val_radial_res'],
                                    batch_size=config['dataset']['batch_size'],
                                    rfp_only=config['dataset']['rfp_only'])

    # instantiate model
    model = HardNN(architecture=config['architecture'])
    # trigger model build to print summary
    model.build(input_shape=(None, 5))

    model.summary()

    start = time.time()
    print("\n\n===============================\n\n")
    print("Starting training at " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    history = train(model=model, config=config, train_ds=train_ds, val_ds=val_ds)

    print("\n\n===============================\n\n")
    print("Training terminated after", len(history['loss']), "epochs.")
    print("Training finished at " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print(f"Total training time: {format_time(time.time() - start)}")

parse_yaml_config

parse_yaml_config(yaml_config_path: Path) -> dict

Parse a YAML run configuration file with custom tags.

In addition to standard YAML types, this parser registers constructors on yaml.SafeLoader for a few custom tags used by this project:

  • !tuple: build Python tuples
  • !CosineAnnealingScheduler: instantiate fpga_profile_reco.utils.schedulers.CosineAnnealingScheduler
  • !ReduceLROnPlateau: instantiate keras.callbacks.ReduceLROnPlateau
  • !EarlyStopping: instantiate keras.callbacks.EarlyStopping

Parameters:

  • yaml_config_path

    (Path) –

    Path to the YAML configuration file.

Returns:

  • config ( dict ) –

    Parsed configuration dictionary.

Notes

This function registers YAML constructors globally via yaml.add_constructor (for yaml.SafeLoader).

Source code in src/fpga_profile_reco/core/train.py
def parse_yaml_config(yaml_config_path: Path) -> dict:
    """
    Parse a YAML run configuration file with custom tags.

    In addition to standard YAML types, this parser registers constructors on
    :class:`yaml.SafeLoader` for a few custom tags used by this project:

    - ``!tuple``: build Python tuples
    - ``!CosineAnnealingScheduler``: instantiate
      :class:`fpga_profile_reco.utils.schedulers.CosineAnnealingScheduler`
    - ``!ReduceLROnPlateau``: instantiate :class:`keras.callbacks.ReduceLROnPlateau`
    - ``!EarlyStopping``: instantiate :class:`keras.callbacks.EarlyStopping`

    Parameters
    ----------
    yaml_config_path : pathlib.Path
        Path to the YAML configuration file.

    Returns
    -------
    config : dict
        Parsed configuration dictionary.

    Notes
    -----
    This function registers YAML constructors globally via
    :func:`yaml.add_constructor` (for ``yaml.SafeLoader``).
    """
    # define constructor for custom object tags
    def construct_tuple(loader, node):
        return tuple(loader.construct_sequence(node))

    def construct_cosine_annealing_scheduler(loader, node):
        mapping = loader.construct_mapping(node)
        return CosineAnnealingScheduler(max_T=mapping['max_T'], min_lr=mapping['min_lr'])

    def construct_reduce_lr_on_plateau(loader, node):
        mapping = loader.construct_mapping(node)
        return keras.callbacks.ReduceLROnPlateau(**mapping)

    def construct_early_stopping(loader, node):
        mapping = loader.construct_mapping(node)
        return keras.callbacks.EarlyStopping(**mapping)

    # register constructors
    yaml.add_constructor('!tuple', construct_tuple, Loader=yaml.SafeLoader)
    yaml.add_constructor('!CosineAnnealingScheduler', construct_cosine_annealing_scheduler, Loader=yaml.SafeLoader)
    yaml.add_constructor('!ReduceLROnPlateau', construct_reduce_lr_on_plateau, Loader=yaml.SafeLoader)
    yaml.add_constructor('!EarlyStopping', construct_early_stopping, Loader=yaml.SafeLoader)

    # read yaml config file
    with open(yaml_config_path, 'r') as f:
        config = yaml.safe_load(f)

    return config
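The custom-tag mechanism can be exercised standalone with PyYAML. The sketch below shows the !tuple tag only (the callback tags follow the same pattern but require keras to be importable); the `hidden_units` key is illustrative, not part of the project's actual config schema:

```python
import yaml

def construct_tuple(loader, node):
    # Build a Python tuple from a YAML sequence node.
    return tuple(loader.construct_sequence(node))

# Register on SafeLoader so yaml.safe_load understands the tag.
yaml.add_constructor('!tuple', construct_tuple, Loader=yaml.SafeLoader)

config_text = """
architecture:
  hidden_units: !tuple [64, 64]
"""

config = yaml.safe_load(config_text)
print(config['architecture']['hidden_units'])  # a real tuple, not a list
```

Because registration goes through yaml.add_constructor, the tag stays attached to yaml.SafeLoader for the rest of the process, which is the global side effect the Notes section warns about.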

train

train(model: Model, config: dict, train_ds: Dataset, val_ds: Dataset) -> dict

Train a Keras model using the provided configuration and datasets.

The model is compiled with an Adam optimizer using training.initial_lr. Training behavior is controlled by callback objects specified in the config, including optional learning-rate scheduling and early stopping, plus TensorBoard logging, CSV history logging, and best-checkpoint saving.

Parameters:

  • model

    (Model) –

    Model to train.

  • config

    (dict) –

Run configuration dictionary as returned by parse_yaml_config. Expected keys include run_config and training.

  • train_ds

    (Dataset) –

    Training dataset.

  • val_ds

    (Dataset) –

    Validation dataset.

Returns:

  • history ( dict ) –

History dictionary (i.e., history.history) returned by keras.Model.fit, mapping metric names to lists of epoch values.

Source code in src/fpga_profile_reco/core/train.py
def train(model: keras.Model, config: dict, train_ds: tf.data.Dataset, val_ds: tf.data.Dataset) -> dict:
    """
    Train a Keras model using the provided configuration and datasets.

    The model is compiled with an Adam optimizer using ``training.initial_lr``.
    Training behavior is controlled by callback objects specified in the config,
    including optional learning-rate scheduling and early stopping, plus
    TensorBoard logging, CSV history logging, and best-checkpoint saving.

    Parameters
    ----------
    model : keras.Model
        Model to train.
    config : dict
        Run configuration dictionary as returned by :func:`parse_yaml_config`.
        Expected keys include ``run_config`` and ``training``.
    train_ds : tf.data.Dataset
        Training dataset.
    val_ds : tf.data.Dataset
        Validation dataset.

    Returns
    -------
    history : dict
        History dictionary (i.e., ``history.history``) returned by
        :meth:`keras.Model.fit`, mapping metric names to lists of epoch values.
    """
    run_config = config['run_config']
    training_config = config['training']

    # only compile with optimizer, loss and metrics are handled in the custom training loop
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=training_config['initial_lr']))

    # setup various callbacks
    callbacks = []

    callbacks.append(keras.callbacks.TerminateOnNaN())
    if training_config.get('lr_scheduler', None):
        if isinstance(training_config['lr_scheduler'], keras.callbacks.ReduceLROnPlateau):
            callbacks.append(training_config['lr_scheduler'])
        else:
            callbacks.append(keras.callbacks.LearningRateScheduler(training_config['lr_scheduler'], verbose=1))
    if training_config.get('early_stopping', None):
        callbacks.append(training_config['early_stopping'])
    tb_path = cfg.TENSORBOARD_LOGS_DIR / run_config['name']
    tb_path.mkdir(parents=True, exist_ok=True)
    callbacks.append(keras.callbacks.TensorBoard(log_dir=tb_path, histogram_freq=10, update_freq='epoch'))
    csv_path = cfg.HISTORY_DIR
    csv_path.mkdir(parents=True, exist_ok=True)
    callbacks.append(keras.callbacks.CSVLogger(filename=csv_path / (run_config['name'] + '.csv'), append=False))
    model_save_path = cfg.MODELS_DIR
    model_save_path.mkdir(parents=True, exist_ok=True)
    callbacks.append(keras.callbacks.ModelCheckpoint(filepath=model_save_path / (run_config['name'] + '.keras'), monitor='val_loss', save_best_only=True))

    # run training
    history = model.fit(train_ds,
                        validation_data=val_ds,
                        callbacks=callbacks,
                        verbose=1,
                        epochs=training_config['epochs'])

    return history.history
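The compile-with-optimizer / fit-with-callbacks flow can be reproduced minimally. A toy Sequential model stands in for the project's HardNN, and only two of the callbacks from train() are wired up; all shapes and hyperparameters below are illustrative:

```python
import numpy as np
import keras
import tensorflow as tf

# Toy stand-in for the project's model (5 input features, scalar output).
model = keras.Sequential([
    keras.Input(shape=(5,)),
    keras.layers.Dense(8, activation='relu'),
    keras.layers.Dense(1),
])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3), loss='mse')

x = np.random.rand(64, 5).astype(np.float32)
y = np.random.rand(64, 1).astype(np.float32)
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

# Reduced callback list: NaN guard plus early stopping, as in train().
callbacks = [
    keras.callbacks.TerminateOnNaN(),
    keras.callbacks.EarlyStopping(monitor='loss', patience=2),
]

history = model.fit(ds, callbacks=callbacks, epochs=2, verbose=0)
print(sorted(history.history.keys()))
```

As in train(), the function's return value would be `history.history`, a plain dict mapping each metric name to one value per completed epoch, which is what main() uses to report the epoch count.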