import os import sys from pathlib import Path from urllib.parse import unquote, urlparse from omegaconf import OmegaConf import mlflow import logging # specify the keras backend we want to use os.environ["KERAS_BACKEND"] = "jax" import keras import torch import numpy as np import bayesflow as bf import sbibm from sbibm.metrics import c2st, mmd import pandas as pd import timeit import tqdm from mlflow_bayesflow_benchmark_plugin import log_bayesflow_params logger = logging.getLogger() if __name__ == "__main__": config = OmegaConf.load("config.yaml") with mlflow.start_run(): artifact_path = Path(unquote(urlparse(mlflow.get_artifact_uri()).path)) artifact_path.mkdir(parents=True, exist_ok=True) # required parameters logger.info(f"Task: {config.task.name}") if config.task.name in ["sir", "lotka_volterra"]: logger.info( "This task requires the Julia backend. Installing dependencies it might take a while..." ) os.environ["JULIA_SYSIMAGE_DIFFEQTORCH"] = str( Path(sys.prefix) / "julia_sysimage_diffeqtorch.so" ) logger.info( f"Setting environment variable: JULIA_SYSIMAGE_DIFFEQTORCH={os.environ['JULIA_SYSIMAGE_DIFFEQTORCH']}" ) from diffeqtorch.install import ( install_pyjulia, install_julia_deps, install_julia_sysimage, ) install_pyjulia() install_julia_deps() install_julia_sysimage() if config.task.name == "slcp_distractors": import pyro torch.serialization.add_safe_globals( [ pyro.distributions.torch.MixtureSameFamily, pyro.distributions.torch.Categorical, pyro.distributions.torch.Independent, pyro.distributions.multivariate_studentt.MultivariateStudentT, pyro.distributions.torch.Chi2, ] ) task = sbibm.get_task(config.task.name) if config.training.mode == "reference": posterior_samples = [] for i in range(1, 11): seed = 2025 np.random.seed(seed) torch.manual_seed(seed) samples = task._sample_reference_posterior( num_samples=config.task.num_posterior_samples, num_observation=i, ) num_unique = torch.unique(samples, dim=0).shape[0] assert num_unique == config.task.num_posterior_samples posterior_samples.append(samples.cpu().numpy()) posterior_samples = np.stack(posterior_samples, axis=0) else: use_summary_network = config.architecture.get("summary_network", None) is not None prior = task.get_prior() likelihood = task.get_simulator() observations = np.concatenate( [task.get_observation(num_observation=i) for i in range(1, 11)], axis=0 ) def bf_prior(batch_shape, **kwargs): return dict(parameters=prior(num_samples=batch_shape[0], **kwargs)) def bf_likelihood(batch_shape, parameters, **kwargs): observables = likelihood(parameters, **kwargs) return dict(observables=observables) simulator = bf.simulators.SequentialSimulator( [ bf.simulators.LambdaSimulator(bf_prior, is_batched=True), bf.simulators.LambdaSimulator(bf_likelihood, is_batched=True), ] ) adapter = ( bf.adapters.Adapter.create_default(inference_variables=["parameters"]) .standardize(["observables", "inference_variables"]) # rename the variables to match the required approximator inputs .rename( "observables", "summary_variables" if use_summary_network else "inference_conditions", ) ) inference_network_kwargs = ( OmegaConf.to_container(config.architecture.inference_network).get( "kwargs", {} ) or {} ) inference_network_type = config.architecture.inference_network.type if keras.saving.get_registered_object(inference_network_type) is not None: inference_network_type = keras.saving.get_registered_object( inference_network_type ) inference_network = bf.utils.find_inference_network( inference_network_type, **inference_network_kwargs, ) if use_summary_network: summary_network_kwargs = ( OmegaConf.to_container(config.architecture.summary_network).get( "kwargs", {} ) or {} ) summary_network_type = config.architecture.summary_network.type if keras.saving.get_registered_object(summary_network_type) is not None: summary_network_type = keras.saving.get_registered_object( summary_network_type ) summary_network = bf.utils.find_summary_network( summary_network_type, **summary_network_kwargs, ) adapter = adapter.expand_dims("summary_variables", axis=-1) else: summary_network = None workflow = bf.BasicWorkflow( simulator=simulator, adapter=adapter, inference_network=inference_network, summary_network=summary_network, initial_learning_rate=config.training.initial_learning_rate, optimizer=config.training.optimizer, ) num_validation_datasets = 64 if config.training.mode == "offline": logger.info("Simulating training data...") training_data = simulator.rejection_sample( config.training.num_simulations, sample_size=1000, predicate=lambda x: np.full((x["observables"].shape[0],), True), ) start_time = timeit.default_timer() history = workflow.fit_offline( training_data, epochs=config.training.epochs, batch_size=config.training.batch_size, validation_data=num_validation_datasets, verbose=2, ) elif config.training.mode == "online": start_time = timeit.default_timer() history = workflow.fit_online( epochs=config.training.epochs, num_batches_per_epoch=config.training.num_batches_per_epoch, batch_size=config.training.batch_size, validation_data=num_validation_datasets, verbose=2, ) else: raise ValueError( f"Invalid training mode config.training.mode='{config.training.mode}'" ) mlflow.log_metric("training_time", timeit.default_timer() - start_time) true_params = np.concatenate( [task.get_true_parameters(i) for i in range(1, 11)], axis=0 ) try: log_prob = workflow.log_prob( { "observables": observations, "parameters": true_params, } ) log_prob_path = artifact_path / "log_prob.csv" pd.DataFrame( {"num_observation": list(range(1, 11)), "log_prob": log_prob} ).to_csv(log_prob_path, index=False) mlflow.log_metric("log_prob", np.mean(log_prob)) except Exception as e: print("Could not calclate log_prob:", e) start_time = timeit.default_timer() observation = task.get_observation(num_observation=1) posterior_samples = workflow.sample( num_samples=config.task.num_posterior_samples, conditions={"observables": observations}, )["parameters"] mlflow.log_metric("sampling_time", timeit.default_timer() - start_time) log_bayesflow_params(workflow.approximator) # Following the advice from the sbibm repository: # https://github.com/sbi-benchmark/results/blob/741183cdece7077453efbedac0fad7e74afa10e5/benchmarking_sbi/run.py#L200 # Use c2st with z-scoring, and MMD without z-scoring metric_fns = { "sbibm.c2st": lambda reference_samples, posterior_samples: c2st( X=reference_samples, Y=torch.Tensor(posterior_samples), z_score=True, )[0] .cpu() .numpy(), "sbibm.mmd": lambda reference_samples, posterior_samples: mmd( X=reference_samples, Y=torch.Tensor(posterior_samples), z_score=False, ), } c2st_accuracies = [] mmd_values = [] metrics = {k: [] for k in metric_fns.keys()} print("Draw samples and calculate metrics for 10 observations...") for i, num_observation in tqdm.tqdm(enumerate(range(1, 11)), total=10): reference_samples = task.get_reference_posterior_samples( num_observation=num_observation )[: config.task.num_posterior_samples, :] pd.DataFrame( reference_samples.cpu().numpy(), columns=task.get_labels_parameters() ).to_csv( artifact_path / f"{num_observation:02}_reference_samples.csv.gz", index=False, ) pd.DataFrame( keras.ops.convert_to_numpy( posterior_samples[i, : config.task.num_posterior_samples] ), columns=task.get_labels_parameters(), ).to_csv( artifact_path / f"{num_observation:02}_posterior_samples.csv.gz", index=False, ) for metric_name, metric_fn in metric_fns.items(): metrics[metric_name].append( float( metric_fn( reference_samples, posterior_samples[i, : config.task.num_posterior_samples], ) ) ) metrics_path = artifact_path / "metrics.csv" pd.DataFrame({"num_observation": list(range(1, 11)), **metrics}).to_csv( metrics_path, index=False ) for metric_name, metric_values in metrics.items(): metric_path = artifact_path / f"{metric_name}.csv" pd.DataFrame( {"num_observation": list(range(1, 11)), metric_name: metric_values} ).to_csv(metric_path, index=False) mlflow.log_metric(metric_name, np.mean(metric_values)) print(metric_name, np.mean(metric_values))