# valentins-bf-benchmark-runs/079890cd838d49e5be32febd5bf07a62/artifacts/source/train.py
import os
from pathlib import Path
from urllib.parse import unquote, urlparse
from omegaconf import OmegaConf
import mlflow
# specify the keras backend we want to use
os.environ["KERAS_BACKEND"] = "jax"
import keras
import torch
import numpy as np
import bayesflow as bf
import sbibm
from sbibm.metrics import c2st, mmd
import pandas as pd
import timeit
import tqdm
from mlflow_bayesflow_benchmark_plugin import log_bayesflow_params
if __name__ == "__main__":
    config = OmegaConf.load("config.yaml")
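    # Sketch of the config.yaml this script expects (field names inferred from
    # the accesses below; the concrete values are illustrative assumptions only):
    #   task:
    #     name: two_moons
    #     num_posterior_samples: 10000
    #   architecture:
    #     inference_network: {type: coupling_flow, kwargs: {}}
    #     summary_network: null
    #   training:
    #     mode: online  # one of: reference, offline, online
    #     initial_learning_rate: 1.0e-3
    #     optimizer: adamw
    #     epochs: 100
    #     batch_size: 64
    #     num_simulations: 10000        # offline mode only
    #     num_batches_per_epoch: 100    # online mode only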
    with mlflow.start_run():
        artifact_path = Path(unquote(urlparse(mlflow.get_artifact_uri()).path))
        artifact_path.mkdir(parents=True, exist_ok=True)
        # required parameters
        task = sbibm.get_task(config.task.name)
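        # "reference" mode: no training; draw ground-truth posterior samples
        # directly from the sbibm task as a baseline run.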
if config.training.mode == "reference":
posterior_samples = []
for i in range(1, 11):
seed = 2025
np.random.seed(seed)
torch.manual_seed(seed)
samples = task._sample_reference_posterior(
num_samples=config.task.num_posterior_samples,
num_observation=i,
)
num_unique = torch.unique(samples, dim=0).shape[0]
assert num_unique == config.task.num_posterior_samples
posterior_samples.append(samples.cpu().numpy())
posterior_samples = np.stack(posterior_samples, axis=0)
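        # "offline"/"online" modes: train a BayesFlow approximator on the task's
        # prior and simulator, then sample from the learned posterior.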
        else:
            prior = task.get_prior()
            likelihood = task.get_simulator()
            observations = np.concatenate(
                [task.get_observation(num_observation=i) for i in range(1, 11)], axis=0
            )
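            # Wrap the sbibm prior and simulator into batched, dict-returning
            # callables matching BayesFlow's simulator interface.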
            def bf_prior(batch_shape, **kwargs):
                return dict(
                    parameters=prior(num_samples=batch_shape[0], **kwargs).squeeze(-1)
                )

            def bf_likelihood(batch_shape, parameters, **kwargs):
                return dict(observables=likelihood(parameters, **kwargs))

            simulator = bf.simulators.SequentialSimulator(
                [
                    bf.simulators.LambdaSimulator(bf_prior, is_batched=True),
                    bf.simulators.LambdaSimulator(bf_likelihood, is_batched=True),
                ]
            )
            adapter = (
                bf.adapters.Adapter.create_default(inference_variables=["parameters"])
                # standardize data variables to zero mean and unit variance
                .standardize(exclude="inference_variables")
                # rename the variables to match the required approximator inputs
                .rename("observables", "inference_conditions")
            )
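            # The network type in the config may either be a name known to
            # bayesflow or a custom class registered with keras.saving; in the
            # latter case, resolve the registered class before looking it up.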
            inference_network_kwargs = (
                OmegaConf.to_container(config.architecture.inference_network).get(
                    "kwargs", {}
                )
                or {}
            )
            inference_network_type = config.architecture.inference_network.type
            if keras.saving.get_registered_object(inference_network_type) is not None:
                inference_network_type = keras.saving.get_registered_object(
                    inference_network_type
                )
            inference_network = bf.utils.find_inference_network(
                inference_network_type,
                **inference_network_kwargs,
            )
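            # The summary network is optional; when configured, resolve it the
            # same way as the inference network.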
            summary_network = config.architecture.summary_network
            if summary_network is not None:
                summary_network_kwargs = (
                    OmegaConf.to_container(config.architecture.summary_network).get(
                        "kwargs", {}
                    )
                    or {}
                )
                summary_network_type = config.architecture.summary_network.type
                if keras.saving.get_registered_object(summary_network_type) is not None:
                    summary_network_type = keras.saving.get_registered_object(
                        summary_network_type
                    )
                summary_network = bf.utils.find_summary_network(
                    summary_network_type,
                    **summary_network_kwargs,
                )
            workflow = bf.BasicWorkflow(
                simulator=simulator,
                adapter=adapter,
                inference_network=inference_network,
                summary_network=summary_network,
                initial_learning_rate=config.training.initial_learning_rate,
                optimizer=config.training.optimizer,
            )
            validation_data = simulator.sample(64)
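            # Offline mode pre-simulates a fixed training set; online mode
            # simulates fresh batches during training.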
if config.training.mode == "offline":
                training_data = simulator.rejection_sample(
                    config.training.num_simulations,
                    sample_size=1000,
                    predicate=lambda x: np.full((x["observables"].shape[0],), True),
                )
                start_time = timeit.default_timer()
                history = workflow.fit_offline(
                    training_data,
                    epochs=config.training.epochs,
                    batch_size=config.training.batch_size,
                    validation_data=validation_data,
                    verbose=2,
                )
            elif config.training.mode == "online":
                start_time = timeit.default_timer()
                history = workflow.fit_online(
                    epochs=config.training.epochs,
                    num_batches_per_epoch=config.training.num_batches_per_epoch,
                    batch_size=config.training.batch_size,
                    validation_data=validation_data,
                    verbose=2,
                )
            else:
                raise ValueError(
                    f"Invalid training mode config.training.mode='{config.training.mode}'"
                )
mlflow.log_metric("training_time", timeit.default_timer() - start_time)
            true_params = np.concatenate(
                [task.get_true_parameters(i) for i in range(1, 11)], axis=0
            )
            try:
                log_prob = workflow.log_prob(
                    {
                        "observables": observations,
                        "parameters": true_params,
                    }
                )
                log_prob_path = artifact_path / "log_prob.csv"
                pd.DataFrame(
                    {"num_observation": list(range(1, 11)), "log_prob": log_prob}
                ).to_csv(log_prob_path, index=False)
                mlflow.log_metric("log_prob", np.mean(log_prob))
            except Exception as e:
                print("Could not calculate log_prob:", e)
            start_time = timeit.default_timer()
            posterior_samples = workflow.sample(
                num_samples=config.task.num_posterior_samples,
                conditions={"observables": observations},
            )["parameters"]
            mlflow.log_metric("sampling_time", timeit.default_timer() - start_time)
            log_bayesflow_params(workflow.approximator)
        # Following the advice from the sbibm repository:
        # https://github.com/sbi-benchmark/results/blob/741183cdece7077453efbedac0fad7e74afa10e5/benchmarking_sbi/run.py#L200
        # Use c2st with z-scoring, and MMD without z-scoring
        metric_fns = {
            "sbibm.c2st": lambda reference_samples, posterior_samples: c2st(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=True,
            )[0]
            .cpu()
            .numpy(),
            "sbibm.mmd": lambda reference_samples, posterior_samples: mmd(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=False,
            ),
        }
        metrics = {k: [] for k in metric_fns.keys()}
        print("Save samples and calculate metrics for 10 observations...")
        for i, num_observation in tqdm.tqdm(enumerate(range(1, 11)), total=10):
            reference_samples = task.get_reference_posterior_samples(
                num_observation=num_observation
            )[: config.task.num_posterior_samples, :]
            pd.DataFrame(
                reference_samples.cpu().numpy(), columns=task.get_labels_parameters()
            ).to_csv(
                artifact_path / f"{num_observation:02}_reference_samples.csv.gz",
                index=False,
            )
            pd.DataFrame(
                keras.ops.convert_to_numpy(
                    posterior_samples[i, : config.task.num_posterior_samples]
                ),
                columns=task.get_labels_parameters(),
            ).to_csv(
                artifact_path / f"{num_observation:02}_posterior_samples.csv.gz",
                index=False,
            )
            for metric_name, metric_fn in metric_fns.items():
                metrics[metric_name].append(
                    float(
                        metric_fn(
                            reference_samples,
                            posterior_samples[i, : config.task.num_posterior_samples],
                        )
                    )
                )
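        # Persist the per-observation metrics (combined and one file per metric)
        # and log the mean of each metric to MLflow.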
metrics_path = artifact_path / "metrics.csv"
pd.DataFrame({"num_observation": list(range(1, 11)), **metrics}).to_csv(
metrics_path, index=False
)
for metric_name, metric_values in metrics.items():
metric_path = artifact_path / f"{metric_name}.csv"
pd.DataFrame(
{"num_observation": list(range(1, 11)), metric_name: metric_values}
).to_csv(metric_path, index=False)
mlflow.log_metric(metric_name, np.mean(metric_values))