246 lines
9.6 KiB (Stored with Git LFS)
Python
246 lines
9.6 KiB (Stored with Git LFS)
Python
import os
|
|
from pathlib import Path
|
|
from urllib.parse import unquote, urlparse
|
|
from omegaconf import OmegaConf
|
|
import mlflow
|
|
|
|
|
|
# specify the keras backend we want to use
|
|
os.environ["KERAS_BACKEND"] = "jax"
|
|
|
|
import keras
|
|
|
|
import torch
|
|
import numpy as np
|
|
import bayesflow as bf
|
|
import sbibm
|
|
from sbibm.metrics import c2st, mmd
|
|
import pandas as pd
|
|
import timeit
|
|
import tqdm
|
|
from mlflow_bayesflow_benchmark_plugin import log_bayesflow_params
|
|
|
|
def _build_network(network_config, find_network):
    """Instantiate the network described by one ``config.architecture`` section.

    Args:
        network_config: Config node providing ``type`` and, optionally,
            ``kwargs``. A missing or null ``kwargs`` entry means "no extra
            arguments".
        find_network: Factory invoked as ``find_network(network_type, **kwargs)``
            (the script passes ``bf.utils.find_inference_network`` or
            ``bf.utils.find_summary_network``).

    Returns:
        Whatever ``find_network`` returns.
    """
    network_kwargs = OmegaConf.to_container(network_config).get("kwargs", {}) or {}
    network_type = network_config.type
    # If an object is registered with Keras under this name, pass that object
    # on; otherwise the configured value is handed to ``find_network`` as is.
    registered = keras.saving.get_registered_object(network_type)
    if registered is not None:
        network_type = registered
    return find_network(network_type, **network_kwargs)


if __name__ == "__main__":
    config = OmegaConf.load("config.yaml")

    # The script works on observations 1..10 of the task throughout.
    observation_ids = range(1, 11)

    with mlflow.start_run():
        # Local directory behind the run's artifact URI; all CSVs are written here.
        artifact_path = Path(unquote(urlparse(mlflow.get_artifact_uri()).path))
        artifact_path.mkdir(parents=True, exist_ok=True)

        # required parameters
        task = sbibm.get_task(config.task.name)
        num_posterior_samples = config.task.num_posterior_samples

        if config.training.mode == "reference":
            # No training: draw fresh reference posterior samples and evaluate
            # them against the stored reference samples below.
            seed = 2025
            reference_draws = []
            for num_observation in observation_ids:
                # Re-seed with the same value before every draw, as the
                # original script did.
                np.random.seed(seed)
                torch.manual_seed(seed)

                samples = task._sample_reference_posterior(
                    num_samples=num_posterior_samples,
                    num_observation=num_observation,
                )
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                num_unique = torch.unique(samples, dim=0).shape[0]
                if num_unique != num_posterior_samples:
                    raise RuntimeError(
                        f"Expected {num_posterior_samples} unique reference samples "
                        f"for observation {num_observation}, got {num_unique}"
                    )
                reference_draws.append(samples.cpu().numpy())
            posterior_samples = np.stack(reference_draws, axis=0)
        else:
            prior = task.get_prior()
            likelihood = task.get_simulator()
            observations = np.concatenate(
                [task.get_observation(num_observation=i) for i in observation_ids],
                axis=0,
            )

            def bf_prior(batch_shape, **kwargs):
                """Sample ``batch_shape[0]`` parameter sets from the task prior."""
                return dict(
                    parameters=prior(num_samples=batch_shape[0], **kwargs).squeeze(-1)
                )

            def bf_likelihood(batch_shape, parameters, **kwargs):
                """Run the task simulator on ``parameters``."""
                return dict(observables=likelihood(parameters, **kwargs))

            simulator = bf.simulators.SequentialSimulator(
                [
                    bf.simulators.LambdaSimulator(bf_prior, is_batched=True),
                    bf.simulators.LambdaSimulator(bf_likelihood, is_batched=True),
                ]
            )

            adapter = (
                bf.adapters.Adapter.create_default(inference_variables=["parameters"])
                # standardize data variables to zero mean and unit variance
                .standardize(exclude="inference_variables")
                # rename the variables to match the required approximator inputs
                .rename("observables", "inference_conditions")
            )

            inference_network = _build_network(
                config.architecture.inference_network,
                bf.utils.find_inference_network,
            )

            # The summary network is optional. When configured, it is built from
            # its OWN ``type``/``kwargs`` (the original mistakenly reused the
            # inference network's type here).
            summary_network = config.architecture.summary_network
            if summary_network is not None:
                summary_network = _build_network(
                    summary_network,
                    bf.utils.find_summary_network,
                )

            workflow = bf.BasicWorkflow(
                simulator=simulator,
                adapter=adapter,
                inference_network=inference_network,
                summary_network=summary_network,
                initial_learning_rate=config.training.initial_learning_rate,
                optimizer=config.training.optimizer,
            )

            validation_data = simulator.sample(
                64,
            )

            if config.training.mode == "offline":
                # The always-true predicate accepts every simulated batch.
                training_data = simulator.rejection_sample(
                    config.training.num_simulations,
                    sample_size=1000,
                    predicate=lambda x: np.full((x["observables"].shape[0],), True),
                )

                # The timer starts after the training set has been simulated.
                start_time = timeit.default_timer()
                history = workflow.fit_offline(
                    training_data,
                    epochs=config.training.epochs,
                    batch_size=config.training.batch_size,
                    validation_data=validation_data,
                    verbose=2,
                )
            elif config.training.mode == "online":
                start_time = timeit.default_timer()
                history = workflow.fit_online(
                    epochs=config.training.epochs,
                    num_batches_per_epoch=config.training.num_batches_per_epoch,
                    batch_size=config.training.batch_size,
                    validation_data=validation_data,
                    verbose=2,
                )
            else:
                raise ValueError(
                    f"Invalid training mode config.training.mode='{config.training.mode}'"
                )

            mlflow.log_metric("training_time", timeit.default_timer() - start_time)

            true_params = np.concatenate(
                [task.get_true_parameters(i) for i in observation_ids], axis=0
            )
            # Best effort: a failure to evaluate the log-probability is reported
            # and the run continues.
            try:
                log_prob = workflow.log_prob(
                    {
                        "observables": observations,
                        "parameters": true_params,
                    }
                )

                log_prob_path = artifact_path / "log_prob.csv"
                pd.DataFrame(
                    {"num_observation": list(observation_ids), "log_prob": log_prob}
                ).to_csv(log_prob_path, index=False)
                mlflow.log_metric("log_prob", np.mean(log_prob))
            except Exception as e:
                print("Could not calculate log_prob:", e)

            # Time only the posterior sampling call itself.
            start_time = timeit.default_timer()
            posterior_samples = workflow.sample(
                num_samples=num_posterior_samples,
                conditions={"observables": observations},
            )["parameters"]
            mlflow.log_metric("sampling_time", timeit.default_timer() - start_time)

            log_bayesflow_params(workflow.approximator)

        # Following the advice from the sbibm repository:
        # https://github.com/sbi-benchmark/results/blob/741183cdece7077453efbedac0fad7e74afa10e5/benchmarking_sbi/run.py#L200
        # Use c2st with z-scoring, and MMD without z-scoring
        metric_fns = {
            "sbibm.c2st": lambda reference_samples, posterior_samples: c2st(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=True,
            )[0]
            .cpu()
            .numpy(),
            "sbibm.mmd": lambda reference_samples, posterior_samples: mmd(
                X=reference_samples,
                Y=torch.Tensor(posterior_samples),
                z_score=False,
            ),
        }

        metrics = {metric_name: [] for metric_name in metric_fns}
        parameter_labels = task.get_labels_parameters()

        print("Draw samples and calculate metrics for 10 observations...")
        for i, num_observation in tqdm.tqdm(
            enumerate(observation_ids), total=len(observation_ids)
        ):
            reference_samples = task.get_reference_posterior_samples(
                num_observation=num_observation
            )[:num_posterior_samples, :]
            # Row ``i`` of ``posterior_samples`` belongs to observation ``i + 1``.
            approximate_samples = posterior_samples[i, :num_posterior_samples]

            pd.DataFrame(
                reference_samples.cpu().numpy(), columns=parameter_labels
            ).to_csv(
                artifact_path / f"{num_observation:02}_reference_samples.csv.gz",
                index=False,
            )

            pd.DataFrame(
                keras.ops.convert_to_numpy(approximate_samples),
                columns=parameter_labels,
            ).to_csv(
                artifact_path / f"{num_observation:02}_posterior_samples.csv.gz",
                index=False,
            )

            for metric_name, metric_fn in metric_fns.items():
                metrics[metric_name].append(
                    float(metric_fn(reference_samples, approximate_samples))
                )

        # One combined table with every metric per observation ...
        metrics_path = artifact_path / "metrics.csv"
        pd.DataFrame({"num_observation": list(observation_ids), **metrics}).to_csv(
            metrics_path, index=False
        )

        # ... plus one CSV per metric, and its mean over observations in MLflow.
        for metric_name, metric_values in metrics.items():
            metric_path = artifact_path / f"{metric_name}.csv"
            pd.DataFrame(
                {"num_observation": list(observation_ids), metric_name: metric_values}
            ).to_csv(metric_path, index=False)
            mlflow.log_metric(metric_name, np.mean(metric_values))