import torch
from fugw.solvers import FUGWSolver
# Seed PyTorch's RNG so repeated runs start from the same random state.
torch.manual_seed(1)
# Let cuDNN auto-tune its kernels.
# NOTE(review): per the PyTorch reproducibility notes, benchmark mode can pick
# different algorithms between runs, which may weaken the seeding above.
torch.backends.cudnn.benchmark = True

# Block-coordinate-descent settings, forwarded to FUGWSolver further down.
nits_bcd = 100
eval_bcd = 2


def _normalize_by_max(matrix, name):
    """Return ``matrix / matrix.max()`` after checking the division is safe.

    A plain ``matrix / matrix.max()`` silently produces NaN/inf when the
    matrix already holds NaN or inf entries (``inf / inf``) or when its
    maximum is exactly zero (``0 / 0``). Such values would then be handed to
    the solver, so fail loudly here instead.

    Args:
        matrix: Array-like exposing ``.max()`` and supporting division.
        name: Label used in the error message.

    Returns:
        ``matrix`` divided by its maximum entry.

    Raises:
        ValueError: If ``matrix`` has non-finite entries or a zero maximum.
    """
    # as_tensor is a no-op for tensors and lets the check accept ndarrays too.
    if not torch.isfinite(torch.as_tensor(matrix)).all():
        raise ValueError(
            f"{name} contains NaN or inf entries; clean it before normalizing"
        )
    peak = matrix.max()
    if peak == 0:
        raise ValueError(
            f"{name}.max() is 0; dividing by it would produce NaN/inf"
        )
    return matrix / peak


# Rescale each input by its largest entry before handing it to the solver.
Ds_normalized = _normalize_by_max(Ds, "Ds")
Dt_normalized = _normalize_by_max(Dt, "Dt")
F_normalized = _normalize_by_max(c_z, "c_z")
# Options for the FUGW solver, gathered in one mapping so they can be
# inspected or logged before the solver is built.
# NOTE(review): parameter meanings below follow the fugw documentation as
# recalled; confirm against the installed fugw version.
_fugw_options = {
    # Outer block-coordinate-descent (BCD) loop.
    "nits_bcd": nits_bcd,
    "tol_bcd": 1e-7,
    "eval_bcd": eval_bcd,
    # Inner unbalanced-OT solver.
    "nits_uot": 1000,
    "tol_uot": 1e-7,
    "eval_uot": 10,
    # Loss-based stopping tolerance.
    "tol_loss": 1e-5,
    # Presumably only read by the "ibpp" solver; verify.
    "ibpp_eps_base": 1e5,
}
fugw = FUGWSolver(**_fugw_options)
# Solver choices, kept as module-level names because they are passed below.
divergence, reg_mode, solver = "kl", "independent", "sinkhorn"

# Run the solver. alpha, rho_s, rho_t and eps are defined outside this view.
# NOTE(review): if the returned coupling contains NaN, check first that the
# three input matrices are finite and that eps is not tiny relative to the
# normalized costs; confirm against the fugw documentation.
res = fugw.solve(
    # Problem data: feature cost and the two intra-domain matrices.
    F=F_normalized,
    Ds=Ds_normalized,
    Dt=Dt_normalized,
    # Objective weights and regularization strengths.
    alpha=alpha,
    rho_s=rho_s,
    rho_t=rho_t,
    eps=eps,
    # Algorithmic choices.
    divergence=divergence,
    reg_mode=reg_mode,
    solver=solver,
    # No warm start, no per-iteration callback, quiet output.
    init_plan=None,
    callback_bcd=None,
    verbose=False,
)
Why does the coupling returned by `fugw.solve` contain NaN values? Under what conditions is this most likely to happen?