import os
import sys
from pathlib import Path
import shutil
from urllib.parse import unquote, urlparse
from omegaconf import OmegaConf
import mlflow
import logging

# specify the keras backend we want to use
os.environ["KERAS_BACKEND"] = "jax"

import keras
import torch
import numpy as np
import bayesflow as bf
import sbibm
from sbibm.metrics import c2st, mmd
import pandas as pd
import timeit
import tqdm
from mlflow_bayesflow_benchmark_plugin import log_bayesflow_params

logger = logging.getLogger()

if __name__ == "__main__":
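    # Expected layout of config.yaml (an illustrative sketch; the keys follow the
    # accesses below, the concrete values are only examples):
    #
    #   task:
    #     name: two_moons
    #     num_posterior_samples: 10000
    #   architecture:
    #     inference_network:
    #       type: coupling_flow
    #       kwargs: {}
    #     summary_network: null          # or a {type, kwargs} block
    #   training:
    #     mode: online                   # "offline", "online", or "reference"
    #     optimizer: adam
    #     initial_learning_rate: 1.0e-3
    #     epochs: 100
    #     batch_size: 64
    #     num_simulations: 10000         # offline mode only
    #     num_batches_per_epoch: 100     # online mode only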
    config = OmegaConf.load("config.yaml")

    with mlflow.start_run():
        artifact_path = Path(unquote(urlparse(mlflow.get_artifact_uri()).path))
        artifact_path.mkdir(parents=True, exist_ok=True)

        # required parameters
        logger.info(f"Task: {config.task.name}")
        if config.task.name in ["sir", "lotka_volterra"]:
            logger.info(
                "This task requires the Julia backend. Installing dependencies; this might take a while..."
            )
            os.environ["JULIA_SYSIMAGE_DIFFEQTORCH"] = str(
                Path(sys.prefix) / "julia_sysimage_diffeqtorch.so"
            )
            logger.info(
                f"Setting environment variable: JULIA_SYSIMAGE_DIFFEQTORCH={os.environ['JULIA_SYSIMAGE_DIFFEQTORCH']}"
            )

            import diffeqtorch
            from diffeqtorch.install import (
                install_pyjulia,
                install_julia_deps,
                install_julia_sysimage,
            )

            srcpath = Path(".") / "DiffEqTorchManifest.toml"
            dstpath = Path(diffeqtorch.__file__).parent / "julia" / "Manifest.toml"
            logger.info(
                f"Copying Manifest.toml '{srcpath}' to Julia project dir '{dstpath}'..."
            )
            try:
                shutil.copyfile(srcpath, dstpath)
            except shutil.SameFileError:
                pass

            install_pyjulia()
            install_julia_deps()
            install_julia_sysimage()
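        # slcp_distractors deserializes pyro distribution objects with torch.load,
        # which defaults to weights_only=True in recent PyTorch releases, so the
        # involved distribution classes have to be allow-listed up front.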
        if config.task.name == "slcp_distractors":
            import pyro

            torch.serialization.add_safe_globals(
                [
                    pyro.distributions.torch.MixtureSameFamily,
                    pyro.distributions.torch.Categorical,
                    pyro.distributions.torch.Independent,
                    pyro.distributions.multivariate_studentt.MultivariateStudentT,
                    pyro.distributions.torch.Chi2,
                ]
            )

        task = sbibm.get_task(config.task.name)
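        # "reference" mode skips training entirely and only re-draws ground-truth
        # samples from the task's reference posterior (a baseline for the metrics
        # computed below); every other mode trains a BayesFlow approximator.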
        if config.training.mode == "reference":
            posterior_samples = []
            for i in range(1, 11):
                seed = 2025
                np.random.seed(seed)
                torch.manual_seed(seed)
                samples = task._sample_reference_posterior(
                    num_samples=config.task.num_posterior_samples,
                    num_observation=i,
                )
                num_unique = torch.unique(samples, dim=0).shape[0]
                assert num_unique == config.task.num_posterior_samples
                posterior_samples.append(samples.cpu().numpy())
            posterior_samples = np.stack(posterior_samples, axis=0)
        else:
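            # Amortized training: wrap the sbibm prior and simulator as BayesFlow
            # simulators, adapt the data layout, build the configured networks, and
            # fit them with the selected training scheme.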
            use_summary_network = (
                config.architecture.get("summary_network", None) is not None
            )
            prior = task.get_prior()
            likelihood = task.get_simulator()
            observations = np.concatenate(
                [task.get_observation(num_observation=i) for i in range(1, 11)], axis=0
            )

            def bf_prior(batch_shape, **kwargs):
                return dict(parameters=prior(num_samples=batch_shape[0], **kwargs))

            def bf_likelihood(batch_shape, parameters, **kwargs):
                observables = likelihood(parameters, **kwargs)
                return dict(observables=observables)
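            # Chain the prior and the simulator so that every simulated batch carries
            # both "parameters" and the matching "observables".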
            simulator = bf.simulators.SequentialSimulator(
                [
                    bf.simulators.LambdaSimulator(bf_prior, is_batched=True),
                    bf.simulators.LambdaSimulator(bf_likelihood, is_batched=True),
                ]
            )
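            # The adapter standardizes parameters and observables, gives the
            # observables a trailing feature axis (so set-based summary networks see
            # one feature per observation dimension), and renames them to the input
            # key the approximator expects.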
            adapter = (
                bf.adapters.Adapter.create_default(inference_variables=["parameters"])
                .expand_dims("observables", axis=-1)
                .standardize(["observables", "inference_variables"])
                # rename the variables to match the required approximator inputs
                .rename(
                    "observables",
                    "summary_variables"
                    if use_summary_network
                    else "inference_conditions",
                )
            )
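            # The configured network type is either a registered custom Keras class
            # (resolved via keras.saving.get_registered_object) or a plain name that
            # bf.utils.find_inference_network maps to a built-in BayesFlow network.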
            inference_network_kwargs = (
                OmegaConf.to_container(config.architecture.inference_network).get(
                    "kwargs", {}
                )
                or {}
            )
            inference_network_type = config.architecture.inference_network.type
            if keras.saving.get_registered_object(inference_network_type) is not None:
                inference_network_type = keras.saving.get_registered_object(
                    inference_network_type
                )
            inference_network = bf.utils.find_inference_network(
                inference_network_type,
                **inference_network_kwargs,
            )
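            # The summary network is optional and resolved the same way; without one,
            # the observations are fed to the inference network directly, so the
            # trailing singleton axis added above is squeezed away again.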
            if use_summary_network:
                summary_network_kwargs = (
                    OmegaConf.to_container(config.architecture.summary_network).get(
                        "kwargs", {}
                    )
                    or {}
                )
                summary_network_type = config.architecture.summary_network.type
                if keras.saving.get_registered_object(summary_network_type) is not None:
                    summary_network_type = keras.saving.get_registered_object(
                        summary_network_type
                    )
                summary_network = bf.utils.find_summary_network(
                    summary_network_type,
                    **summary_network_kwargs,
                )
            else:
                summary_network = None
                adapter = adapter.squeeze(
                    "summary_variables"
                    if use_summary_network
                    else "inference_conditions",
                    axis=-1,
                )
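            # The workflow ties simulator, adapter, and networks together and owns
            # the optimizer, training loop, and sampling interface.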
            workflow = bf.BasicWorkflow(
                simulator=simulator,
                adapter=adapter,
                inference_network=inference_network,
                summary_network=summary_network,
                initial_learning_rate=config.training.initial_learning_rate,
                optimizer=config.training.optimizer,
            )

            num_validation_datasets = 64
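            # "offline" pre-simulates a fixed training set (rejection_sample with an
            # always-true predicate just draws in chunks of 1000 without rejecting
            # anything); "online" simulates fresh batches throughout training.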
            if config.training.mode == "offline":
                logger.info("Simulating training data...")
                training_data = simulator.rejection_sample(
                    config.training.num_simulations,
                    sample_size=1000,
                    predicate=lambda x: np.full((x["observables"].shape[0],), True),
                )
                start_time = timeit.default_timer()
                history = workflow.fit_offline(
                    training_data,
                    epochs=config.training.epochs,
                    batch_size=config.training.batch_size,
                    validation_data=num_validation_datasets,
                    verbose=2,
                )
            elif config.training.mode == "online":
                start_time = timeit.default_timer()
                history = workflow.fit_online(
                    epochs=config.training.epochs,
                    num_batches_per_epoch=config.training.num_batches_per_epoch,
                    batch_size=config.training.batch_size,
                    validation_data=num_validation_datasets,
                    verbose=2,
                )
            else:
                raise ValueError(
                    f"Invalid training mode config.training.mode='{config.training.mode}'"
                )

            mlflow.log_metric("training_time", timeit.default_timer() - start_time)
            true_params = np.concatenate(
                [task.get_true_parameters(i) for i in range(1, 11)], axis=0
            )
            try:
                log_prob = workflow.log_prob(
                    {
                        "observables": observations,
                        "parameters": true_params,
                    }
                )
                log_prob_path = artifact_path / "log_prob.csv"
                pd.DataFrame(
                    {"num_observation": list(range(1, 11)), "log_prob": log_prob}
                ).to_csv(log_prob_path, index=False)
                mlflow.log_metric("log_prob", np.mean(log_prob))
            except Exception as e:
                print("Could not calculate log_prob:", e)
            start_time = timeit.default_timer()
            observation = task.get_observation(num_observation=1)
            posterior_samples = workflow.sample(
                num_samples=config.task.num_posterior_samples,
                conditions={"observables": observations},
            )["parameters"]
            mlflow.log_metric("sampling_time", timeit.default_timer() - start_time)

            log_bayesflow_params(workflow.approximator)

        # Following the advice from the sbibm repository:
        # https://github.com/sbi-benchmark/results/blob/741183cdece7077453efbedac0fad7e74afa10e5/benchmarking_sbi/run.py#L200
        # Use c2st with z-scoring, and MMD without z-scoring
        metric_fns = {
            "sbibm.c2st": lambda reference_samples, posterior_samples: c2st(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=True,
            )[0]
            .cpu()
            .numpy(),
            "sbibm.mmd": lambda reference_samples, posterior_samples: mmd(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=False,
            ),
        }

        c2st_accuracies = []
        mmd_values = []
        metrics = {k: [] for k in metric_fns.keys()}

        print("Draw samples and calculate metrics for 10 observations...")
        for i, num_observation in tqdm.tqdm(enumerate(range(1, 11)), total=10):
            reference_samples = task.get_reference_posterior_samples(
                num_observation=num_observation
            )[: config.task.num_posterior_samples, :]
            pd.DataFrame(
                reference_samples.cpu().numpy(), columns=task.get_labels_parameters()
            ).to_csv(
                artifact_path / f"{num_observation:02}_reference_samples.csv.gz",
                index=False,
            )
            pd.DataFrame(
                keras.ops.convert_to_numpy(
                    posterior_samples[i, : config.task.num_posterior_samples]
                ),
                columns=task.get_labels_parameters(),
            ).to_csv(
                artifact_path / f"{num_observation:02}_posterior_samples.csv.gz",
                index=False,
            )
            for metric_name, metric_fn in metric_fns.items():
                metrics[metric_name].append(
                    float(
                        metric_fn(
                            reference_samples,
                            posterior_samples[i, : config.task.num_posterior_samples],
                        )
                    )
                )

        metrics_path = artifact_path / "metrics.csv"
        pd.DataFrame({"num_observation": list(range(1, 11)), **metrics}).to_csv(
            metrics_path, index=False
        )

        for metric_name, metric_values in metrics.items():
            metric_path = artifact_path / f"{metric_name}.csv"
            pd.DataFrame(
                {"num_observation": list(range(1, 11)), metric_name: metric_values}
            ).to_csv(metric_path, index=False)
            mlflow.log_metric(metric_name, np.mean(metric_values))
            print(metric_name, np.mean(metric_values))