train_hgq

Functions:

  • load_pretrained_model

    Load a serialized Keras model from disk.

  • main

    Command-line interface for training a quantized (HGQ) profile reconstruction model.

  • parse_yaml_config

    Parse a YAML run configuration file with HGQ-specific custom tags.

  • set_weights

    Copy weights from a base model into another model, layer-by-layer.

  • train

    Train a quantized (HGQ) Keras model using the provided configuration and datasets.

load_pretrained_model

load_pretrained_model(model_path: Path) -> Model

Load a serialized Keras model from disk.

Parameters:

  • model_path

    (Path) –

    Path to a saved Keras model (e.g., a .keras directory/file).

Returns:

  • model ( Model ) –

    Deserialized Keras model instance.

Source code in src/fpga_profile_reco/core/train_hgq.py
def load_pretrained_model(model_path: Path) -> keras.Model:
    """
    Load a serialized Keras model from disk.

    Parameters
    ----------
    model_path : pathlib.Path
        Path to a saved Keras model (e.g., a ``.keras`` directory/file).

    Returns
    -------
    model : keras.Model
        Deserialized Keras model instance.
    """
    model = keras.models.load_model(model_path)
    return model

main

main()

Command-line interface for training a quantized (HGQ) profile reconstruction model.

This function parses command-line arguments, loads a YAML run configuration (with HGQ-specific custom tags), builds the training/validation datasets, instantiates :class:fpga_profile_reco.core.models.QHardNN, and runs training while writing logs and Pareto checkpoints to the configured output directories.

Command Line Parameters
  • config (pathlib.Path) – Path to the YAML run configuration file.
Source code in src/fpga_profile_reco/core/train_hgq.py
def main():
    """
    Command-line interface for training a quantized (HGQ) profile reconstruction model.

    This function parses command-line arguments, loads a YAML run configuration
    (with HGQ-specific custom tags), builds the training/validation datasets,
    instantiates :class:`fpga_profile_reco.core.models.QHardNN`, and runs training
    while writing logs and Pareto checkpoints to the configured output directories.

    Command Line Parameters
    -----------------------
    - `config` : pathlib.Path
        Path to the YAML run configuration file.
    """
    import argparse
    import datetime
    import time

    from fpga_profile_reco.utils.helpers import format_time

    parser = argparse.ArgumentParser(description="Train a quantized model for equilibrium profile reconstruction.")
    parser.add_argument('config', type=Path, help="Path to the YAML run configuration file.")

    args = parser.parse_args()

    # set memory growth for GPUs
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # read run configuration
    config = parse_yaml_config(args.config)

    train_ds, val_ds = get_datasets(train_radial_res=config['dataset']['train_radial_res'],
                                    val_radial_res=config['dataset']['val_radial_res'],
                                    batch_size=config['dataset']['batch_size'],
                                    rfp_only=config['dataset']['rfp_only'])

    # instantiate model
    model = QHardNN(architecture=config['architecture'], quantization=config['quantization'])
    # trigger model build to print summary
    model.build(input_shape=(None, 5))

    model.summary()

    start = time.time()
    print("\n\n===============================\n\n")
    print("Starting training at " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    history = train(model=model, config=config, train_ds=train_ds, val_ds=val_ds)

    print("\n\n===============================\n\n")
    print("Training terminated after", len(history['loss']), "epochs.")
    print("Training finished at " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print(f"Total training time: {format_time(time.time() - start)}")

parse_yaml_config

parse_yaml_config(yaml_config_path: Path) -> dict

Parse a YAML run configuration file with HGQ-specific custom tags.

In addition to standard YAML types, this parser registers constructors on :class:yaml.SafeLoader for custom tags used in the quantized (HGQ) training pipeline, including constraints, quantizer configs, and scheduler objects.

Registered custom tags
  • !tuple: build Python tuples
  • !Min: instantiate :class:hgq.constraints.Min
  • !Max: instantiate :class:hgq.constraints.Max
  • !MinMax: instantiate :class:hgq.constraints.MinMax
  • !QuantizerConfig: instantiate :class:hgq.config.QuantizerConfig
  • !PieceWiseSchedule: instantiate :class:hgq.utils.sugar.PieceWiseSchedule
  • !BetaScheduler: instantiate :class:hgq.utils.sugar.BetaScheduler
  • !CosineAnnealingScheduler: instantiate :class:fpga_profile_reco.utils.schedulers.CosineAnnealingScheduler
  • !CosineAnnealingWithRestartsScheduler: instantiate :class:fpga_profile_reco.utils.schedulers.CosineAnnealingWithRestartsScheduler
Notes

This function registers YAML constructors globally via :func:yaml.add_constructor (for yaml.SafeLoader).

Parameters:

  • yaml_config_path

    (Path) –

    Path to the YAML configuration file.

Returns:

  • config ( dict ) –

    Parsed configuration dictionary.

Source code in src/fpga_profile_reco/core/train_hgq.py
def parse_yaml_config(yaml_config_path: Path) -> dict:
    """
    Parse a YAML run configuration file with HGQ-specific custom tags.

    In addition to standard YAML types, this parser registers constructors on
    :class:`yaml.SafeLoader` for custom tags used in the quantized (HGQ) training
    pipeline, including constraints, quantizer configs, and scheduler objects.

    Registered custom tags
    ----------------------
    - ``!tuple``: build Python tuples
    - ``!Min``: instantiate :class:`hgq.constraints.Min`
    - ``!Max``: instantiate :class:`hgq.constraints.Max`
    - ``!MinMax``: instantiate :class:`hgq.constraints.MinMax`
    - ``!QuantizerConfig``: instantiate :class:`hgq.config.QuantizerConfig`
    - ``!PieceWiseSchedule``: instantiate :class:`hgq.utils.sugar.PieceWiseSchedule`
    - ``!BetaScheduler``: instantiate :class:`hgq.utils.sugar.BetaScheduler`
    - ``!CosineAnnealingScheduler``: instantiate
      :class:`fpga_profile_reco.utils.schedulers.CosineAnnealingScheduler`
    - ``!CosineAnnealingWithRestartsScheduler``: instantiate
      :class:`fpga_profile_reco.utils.schedulers.CosineAnnealingWithRestartsScheduler`

    Notes
    -----
    This function registers YAML constructors globally via
    :func:`yaml.add_constructor` (for ``yaml.SafeLoader``).

    Parameters
    ----------
    yaml_config_path : pathlib.Path
        Path to the YAML configuration file.

    Returns
    -------
    config : dict
        Parsed configuration dictionary.
    """
    # define constructor for custom object tags
    def construct_tuple(loader, node):
        return tuple(loader.construct_sequence(node))

    def construct_min(loader, node):
        mapping = loader.construct_mapping(node)
        return hgq.constraints.Min(min_value=mapping['min_value'])

    def construct_max(loader, node):
        mapping = loader.construct_mapping(node)
        return hgq.constraints.Max(max_value=mapping['max_value'])

    def construct_min_max(loader, node):
        mapping = loader.construct_mapping(node)
        return hgq.constraints.MinMax(min_value=mapping['min_value'], max_value=mapping['max_value'])

    def construct_quantizer_config(loader, node):
        mapping = loader.construct_mapping(node)
        return hgq.config.QuantizerConfig(**mapping)

    def construct_piecewise_schedule(loader, node):
        mapping = loader.construct_mapping(node, deep=True)
        return PieceWiseSchedule(intervals=mapping['intervals'])

    def construct_beta_scheduler(loader, node):
        mapping = loader.construct_mapping(node, deep=True)
        return BetaScheduler(beta_fn=mapping['beta_fn'])

    def construct_cosine_annealing_scheduler(loader, node):
        mapping = loader.construct_mapping(node)
        return CosineAnnealingScheduler(max_T=mapping['max_T'], min_lr=mapping['min_lr'])

    def construct_cosine_annealing_with_restarts_scheduler(loader, node):
        mapping = loader.construct_mapping(node, deep=True)
        return CosineAnnealingWithRestartsScheduler(restart_lrs=mapping['restart_lrs'], min_lrs=mapping['min_lrs'], Ts=mapping['Ts'])

    # register constructors
    yaml.add_constructor('!tuple', construct_tuple, Loader=yaml.SafeLoader)
    yaml.add_constructor('!Min', construct_min, Loader=yaml.SafeLoader)
    yaml.add_constructor('!Max', construct_max, Loader=yaml.SafeLoader)
    yaml.add_constructor('!MinMax', construct_min_max, Loader=yaml.SafeLoader)
    yaml.add_constructor('!QuantizerConfig', construct_quantizer_config, Loader=yaml.SafeLoader)
    yaml.add_constructor('!PieceWiseSchedule', construct_piecewise_schedule, Loader=yaml.SafeLoader)
    yaml.add_constructor('!BetaScheduler', construct_beta_scheduler, Loader=yaml.SafeLoader)
    yaml.add_constructor('!CosineAnnealingScheduler', construct_cosine_annealing_scheduler, Loader=yaml.SafeLoader)
    yaml.add_constructor('!CosineAnnealingWithRestartsScheduler', construct_cosine_annealing_with_restarts_scheduler, Loader=yaml.SafeLoader)

    # read yaml config file
    with open(yaml_config_path, 'r') as f:
        config = yaml.safe_load(f)

    return config
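
The custom-tag mechanism above can be seen in isolation with the simple `!tuple` tag, which needs only PyYAML (the HGQ-specific tags follow the same constructor-registration pattern but require the `hgq` package). The config snippet below is illustrative:

```python
import yaml

def construct_tuple(loader, node):
    # build a Python tuple from a YAML sequence node
    return tuple(loader.construct_sequence(node))

# register the constructor on SafeLoader so yaml.safe_load picks it up
yaml.add_constructor('!tuple', construct_tuple, Loader=yaml.SafeLoader)

config_text = """
architecture:
  hidden_units: !tuple [64, 32]
"""

config = yaml.safe_load(config_text)
print(config['architecture']['hidden_units'])  # -> (64, 32)
```

Note that, as the docstring warns, the registration is global: every subsequent `yaml.safe_load` call in the process will honor the tag.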

set_weights

set_weights(base_model: Model, model: Model) -> None

Copy weights from a base model into another model, layer-by-layer.

This function iterates over layers in base_model and model in lockstep (via :func:zip) and replaces the first two weight arrays of each target layer with those from the corresponding base layer.

Notes
  • This assumes that corresponding layers have compatible weight structures and that the target layer has at least two weight tensors (commonly kernel and bias).
  • Layers are matched purely by position, not by name.

Parameters:

  • base_model

    (Model) –

    Model to copy weights from.

  • model

    (Model) –

    Model to copy weights into.

Returns:

  • None
Source code in src/fpga_profile_reco/core/train_hgq.py
def set_weights(base_model: keras.Model, model: keras.Model) -> None:
    """
    Copy weights from a base model into another model, layer-by-layer.

    This function iterates over layers in ``base_model`` and ``model`` in lockstep
    (via :func:`zip`) and replaces the first two weight arrays of each target
    layer with those from the corresponding base layer.

    Notes
    -----
    - This assumes that corresponding layers have compatible weight structures
      and that the target layer has at least two weight tensors (commonly kernel
      and bias).
    - Layers are matched purely by position, not by name.

    Parameters
    ----------
    base_model : keras.Model
        Model to copy weights from.
    model : keras.Model
        Model to copy weights into.

    Returns
    -------
    None
    """
    for base_layer, layer in zip(base_model.layers, model.layers):
        weight_list = layer.get_weights()
        base_model_weights = base_layer.get_weights()
        # set weights and biases
        weight_list[0] = base_model_weights[0]
        weight_list[1] = base_model_weights[1]
        layer.set_weights(weight_list)
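
The positional copy logic can be illustrated without TensorFlow using a tiny stand-in for Keras layers (`FakeLayer` and `set_weights_sketch` below are hypothetical names for this sketch). The key point is that only the first two weight arrays are replaced; any extra tensors a quantized target layer carries (e.g. quantizer state) are preserved:

```python
class FakeLayer:
    """Minimal stand-in mimicking keras.layers.Layer weight access."""
    def __init__(self, weights):
        self._weights = list(weights)

    def get_weights(self):
        return list(self._weights)

    def set_weights(self, weights):
        self._weights = list(weights)


def set_weights_sketch(base_layers, target_layers):
    # lockstep iteration: layers are matched by position, not by name
    for base_layer, layer in zip(base_layers, target_layers):
        weight_list = layer.get_weights()
        base_weights = base_layer.get_weights()
        weight_list[0] = base_weights[0]  # kernel
        weight_list[1] = base_weights[1]  # bias
        layer.set_weights(weight_list)    # extra tensors stay untouched


base = [FakeLayer(["base_kernel", "base_bias"])]
target = [FakeLayer(["old_kernel", "old_bias", "quantizer_state"])]
set_weights_sketch(base, target)
print(target[0].get_weights())  # -> ['base_kernel', 'base_bias', 'quantizer_state']
```

Because matching is purely positional, inserting or removing a layer in either model silently misaligns the copy; the real `set_weights` inherits the same caveat.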

train

train(model: Model, config: dict, train_ds: Dataset, val_ds: Dataset) -> dict

Train a quantized (HGQ) Keras model using the provided configuration and datasets.

The model is compiled with an Adam optimizer using training.initial_lr. The loss/metrics are assumed to be handled by the model's internal/custom training logic. Training behavior is controlled by callbacks specified in the config, plus HGQ utilities such as EBOP accounting and Pareto checkpointing.

Parameters:

  • model

    (Model) –

    Model to train.

  • config

    (dict) –

    Run configuration dictionary as returned by :func:parse_yaml_config. Expected keys include run_config and training.

  • train_ds

    (Dataset) –

    Training dataset.

  • val_ds

    (Dataset) –

    Validation dataset.

Returns:

  • history ( dict ) –

    History dictionary (i.e., history.history) returned by :meth:keras.Model.fit, mapping metric names to lists of epoch values.

Source code in src/fpga_profile_reco/core/train_hgq.py
def train(model: keras.Model, config: dict, train_ds: tf.data.Dataset, val_ds: tf.data.Dataset) -> dict:
    """
    Train a quantized (HGQ) Keras model using the provided configuration and datasets.

    The model is compiled with an Adam optimizer using ``training.initial_lr``.
    The loss/metrics are assumed to be handled by the model's internal/custom
    training logic. Training behavior is controlled by callbacks specified in
    the config, plus HGQ utilities such as EBOP accounting and Pareto checkpointing.

    Parameters
    ----------
    model : keras.Model
        Model to train.
    config : dict
        Run configuration dictionary as returned by :func:`parse_yaml_config`.
        Expected keys include ``run_config`` and ``training``.
    train_ds : tf.data.Dataset
        Training dataset.
    val_ds : tf.data.Dataset
        Validation dataset.

    Returns
    -------
    history : dict
        History dictionary (i.e., ``history.history``) returned by
        :meth:`keras.Model.fit`, mapping metric names to lists of epoch values.
    """
    run_config = config['run_config']
    training_config = config['training']

    # only compile with optimizer, loss and metrics are handled in the custom training loop
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=training_config['initial_lr']))

    # load pretrained model weights if specified
    if run_config['load_pretrained_model']:
        print("Loading pretrained model weights from:", run_config['pretrained_model_path'])
        base_model = load_pretrained_model(run_config['pretrained_model_path'])
        set_weights(base_model, model)

    # setup various callbacks
    callbacks = []

    callbacks.append(keras.callbacks.TerminateOnNaN())
    if training_config['lr_scheduler']:
        callbacks.append(keras.callbacks.LearningRateScheduler(training_config['lr_scheduler'], verbose=1))
    if training_config['beta_scheduler']:
        callbacks.append(training_config['beta_scheduler'])
    callbacks.append(FreeEBOPs())
    tb_path = cfg.TENSORBOARD_LOGS_DIR / run_config['name']
    tb_path.mkdir(parents=True, exist_ok=True)
    callbacks.append(keras.callbacks.TensorBoard(log_dir=tb_path, histogram_freq=10, update_freq='epoch'))
    chkpt_path = cfg.PARETO_CHKPTS_DIR / run_config['name']
    chkpt_path.mkdir(parents=True, exist_ok=True)
    callbacks.append(ParetoFront(path=chkpt_path,
                                 fname_format='{epoch:04d}-val_loss-{val_loss:.4g}-ebops-{ebops:.4g}.keras',
                                 metrics=['val_loss', 'ebops'],
                                 enable_if=lambda x: x['val_loss'] < 1e-5,  # only save checkpoints once val_loss drops below this threshold
                                 sides=[-1, -1]))
    csv_path = cfg.HISTORY_DIR
    csv_path.mkdir(parents=True, exist_ok=True)
    callbacks.append(keras.callbacks.CSVLogger(filename=csv_path / (run_config['name'] + '.csv'), append=False))

    # run training
    history = model.fit(train_ds,
                        validation_data=val_ds,
                        callbacks=callbacks,
                        verbose=1,
                        epochs=training_config['epochs'])

    return history.history
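
Collecting the keys read by `main` and `train`, a skeleton run configuration looks roughly like the YAML below. The structure is inferred from the code shown above; the concrete values are illustrative placeholders, and the contents of `architecture` and `quantization` depend on what `QHardNN` expects:

```yaml
run_config:
  name: hgq_run_example          # used for TensorBoard/Pareto/CSV output names
  load_pretrained_model: false
  pretrained_model_path: null

dataset:
  train_radial_res: 33           # illustrative values
  val_radial_res: 33
  batch_size: 256
  rfp_only: false

architecture: {}                 # consumed by QHardNN
quantization: {}                 # consumed by QHardNN

training:
  initial_lr: 1.0e-3
  epochs: 100
  lr_scheduler: null             # or e.g. a !CosineAnnealingScheduler tag
  beta_scheduler: null           # or e.g. a !BetaScheduler tag
```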