"""Train a BayesFlow workflow on an sbibm task and log the results to MLflow.

The script reads ``config.yaml`` from the working directory, trains the
configured inference (and optional summary) network either offline or online,
draws posterior samples for the task's observations 1-10 and writes samples,
log-probabilities and sbibm metrics (C2ST, MMD) into the MLflow artifact
directory of the active run.
"""

import os
from pathlib import Path
from urllib.parse import unquote, urlparse
from omegaconf import OmegaConf
import mlflow

# specify the keras backend we want to use
# NOTE: this must happen before keras (and bayesflow) are imported, so the
# import order below is intentional and should not be regrouped.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import torch
import numpy as np
import bayesflow as bf
import sbibm
from sbibm.metrics import c2st, mmd
import pandas as pd
import timeit
from mlflow_bayesflow_benchmark_plugin import log_bayesflow_params


def _build_network(section, finder):
    """Instantiate a network from a config section.

    Parameters
    ----------
    section
        Config node providing ``type`` and, optionally, ``kwargs``.
    finder
        Callable taking ``(network_type, **kwargs)`` and returning the network,
        e.g. ``bf.utils.find_inference_network``.

    Returns
    -------
    The object returned by ``finder``.
    """
    # A missing or explicitly empty ``kwargs`` entry both map to "no kwargs".
    kwargs = OmegaConf.to_container(section).get("kwargs", {}) or {}
    network_type = section.type
    # Prefer an object registered with keras under that name, if there is one;
    # otherwise hand the plain type name to the finder.
    registered = keras.saving.get_registered_object(network_type)
    if registered is not None:
        network_type = registered
    return finder(network_type, **kwargs)


if __name__ == "__main__":
    config = OmegaConf.load("config.yaml")

    # Observation indices evaluated throughout the script (1-based).
    observation_ids = list(range(1, 11))

    with mlflow.start_run():
        # Resolve the run's artifact URI to a local directory and make sure
        # it exists, so that CSV files can be written into it directly.
        artifact_path = Path(unquote(urlparse(mlflow.get_artifact_uri()).path))
        artifact_path.mkdir(parents=True, exist_ok=True)

        # required parameters
        mlflow.log_param("backend", keras.backend.backend())
        # NOTE(review): placeholder value, logged again with the measured
        # time after training; presumably kept so the metric always exists.
        mlflow.log_metric("training_time", 10.0)

        task = sbibm.get_task(config.task.name)
        prior = task.get_prior()
        likelihood = task.get_simulator()
        observations = np.concatenate(
            [task.get_observation(num_observation=i) for i in observation_ids],
            axis=0,
        )

        def bf_prior(batch_shape, **kwargs):
            """Draw ``batch_shape[0]`` parameter samples from the task prior."""
            return dict(
                parameters=prior(num_samples=batch_shape[0], **kwargs).squeeze(-1)
            )

        def bf_likelihood(batch_shape, parameters, **kwargs):
            """Simulate observables for the given parameters."""
            return dict(observables=likelihood(parameters, **kwargs))

        simulator = bf.simulators.SequentialSimulator(
            [
                bf.simulators.LambdaSimulator(bf_prior, is_batched=True),
                bf.simulators.LambdaSimulator(bf_likelihood, is_batched=True),
            ]
        )

        adapter = (
            bf.adapters.Adapter.create_default(inference_variables=["parameters"])
            # standardize data variables to zero mean and unit variance
            .standardize(exclude="inference_variables")
            # rename the variables to match the required approximator inputs
            .rename("observables", "inference_conditions")
        )

        inference_network = _build_network(
            config.architecture.inference_network,
            bf.utils.find_inference_network,
        )

        # The summary network is optional; its type and kwargs come from its
        # own config section.
        summary_network = None
        summary_network_config = config.architecture.summary_network
        if summary_network_config is not None:
            summary_network = _build_network(
                summary_network_config,
                bf.utils.find_summary_network,
            )

        workflow = bf.BasicWorkflow(
            simulator=simulator,
            adapter=adapter,
            inference_network=inference_network,
            summary_network=summary_network,
            initial_learning_rate=config.training.initial_learning_rate,
            optimizer=config.training.optimizer,
        )

        validation_data = simulator.sample(64)

        if config.training.mode == "offline":
            # The predicate accepts every sample; the training set is
            # simulated before the timer starts, so "training_time" covers
            # fitting only.
            training_data = simulator.rejection_sample(
                config.training.num_simulations,
                sample_size=1000,
                predicate=lambda x: np.full((x["observables"].shape[0],), True),
            )
            start_time = timeit.default_timer()
            workflow.fit_offline(
                training_data,
                epochs=config.training.epochs,
                batch_size=config.training.batch_size,
                validation_data=validation_data,
                verbose=2,
            )
        elif config.training.mode == "online":
            start_time = timeit.default_timer()
            workflow.fit_online(
                epochs=config.training.epochs,
                num_batches_per_epoch=config.training.num_batches_per_epoch,
                batch_size=config.training.batch_size,
                validation_data=validation_data,
                verbose=2,
            )
        else:
            raise ValueError(
                f"Invalid training mode config.training.mode='{config.training.mode}'"
            )
        mlflow.log_metric("training_time", timeit.default_timer() - start_time)

        true_params = np.concatenate(
            [task.get_true_parameters(i) for i in observation_ids], axis=0
        )

        # Best effort: a failure to evaluate the log-probability is reported
        # and the run continues with sampling and the sbibm metrics.
        try:
            log_prob = workflow.log_prob(
                {
                    "observables": observations,
                    "parameters": true_params,
                }
            )
            log_prob_path = artifact_path / "log_prob.csv"
            pd.DataFrame(
                {"num_observation": observation_ids, "log_prob": log_prob}
            ).to_csv(log_prob_path, index=False)
            mlflow.log_metric("log_prob", np.mean(log_prob))
        except Exception as e:
            print("Could not calculate log_prob:", e)

        # Time only the sampling call itself.
        start_time = timeit.default_timer()
        posterior_samples = workflow.sample(
            num_samples=config.task.num_posterior_samples,
            conditions={"observables": observations},
        )["parameters"]
        mlflow.log_metric("sampling_time", timeit.default_timer() - start_time)

        # Following the advice from the sbibm repository:
        # https://github.com/sbi-benchmark/results/blob/741183cdece7077453efbedac0fad7e74afa10e5/benchmarking_sbi/run.py#L200
        # Use c2st with z-scoring, and MMD without z-scoring
        metric_fns = {
            "sbibm.c2st": lambda reference_samples, posterior_samples: c2st(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=True,
            )[0]
            .cpu()
            .numpy(),
            "sbibm.mmd": lambda reference_samples, posterior_samples: mmd(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=False,
            ),
        }

        metrics = {metric_name: [] for metric_name in metric_fns}
        for i, num_observation in enumerate(observation_ids):
            # Truncate the reference samples to the configured sample count
            # before comparing them with the approximate posterior.
            reference_samples = task.get_reference_posterior_samples(
                num_observation=num_observation
            )[: config.task.num_posterior_samples, :]
            observation_samples = posterior_samples[
                i, : config.task.num_posterior_samples
            ]

            pd.DataFrame(
                reference_samples.cpu().numpy(),
                columns=task.get_labels_parameters(),
            ).to_csv(
                artifact_path / f"{num_observation:02}_reference_samples.csv.bz2",
                index=False,
            )
            pd.DataFrame(
                keras.ops.convert_to_numpy(observation_samples),
                columns=task.get_labels_parameters(),
            ).to_csv(
                artifact_path / f"{num_observation:02}_posterior_samples.csv.bz2",
                index=False,
            )

            for metric_name, metric_fn in metric_fns.items():
                metrics[metric_name].append(
                    float(metric_fn(reference_samples, observation_samples))
                )

        # One combined table with every metric ...
        metrics_path = artifact_path / "metrics.csv"
        pd.DataFrame({"num_observation": observation_ids, **metrics}).to_csv(
            metrics_path, index=False
        )

        # ... plus one file per metric, and its mean as an MLflow metric.
        for metric_name, metric_values in metrics.items():
            metric_path = artifact_path / f"{metric_name}.csv"
            pd.DataFrame(
                {"num_observation": observation_ids, metric_name: metric_values}
            ).to_csv(metric_path, index=False)
            mlflow.log_metric(metric_name, np.mean(metric_values))

        log_bayesflow_params(workflow.approximator)