|
def from_dictionary(self, dictionary):
    """Populate this prior dictionary from a mapping of key -> prior spec.

    Each value in ``dictionary`` is converted according to its type:

    * ``Prior`` instances are kept unchanged.
    * Real numbers (``int``, ``float`` and any other ``numbers.Real``,
      e.g. NumPy scalars) become ``DeltaFunction(peak=val)``.
    * Strings are parsed either as a number (-> ``DeltaFunction``) or as a
      ``ClassName(args)`` / ``package.module.ClassName(args)`` repr.  Bare
      class names are looked up in the package that contains this module.
      ``MultivariateGaussianDist`` / ``MultivariateNormalDist`` entries are
      removed from the result and kept only so that later
      ``MultivariateGaussian`` / ``MultivariateNormal`` entries can refer to
      them with ``dist=<key>``.  Keys named ``conversion_function`` or
      ``condition_func`` (case-insensitive) are also set as attributes on
      ``self``.
    * Dicts carrying ``__module__`` / ``__name__`` / ``kwargs`` entries are
      instantiated as ``cls(**kwargs)``; if the class cannot be found, the
      dict is left as-is and a warning is logged.

    Note that ``dictionary`` is modified in place before its contents are
    merged into ``self`` via ``self.update``.

    Parameters
    ----------
    dictionary: dict
        Mapping of parameter names to prior specifications.

    Raises
    ------
    ModuleNotFoundError
        If the module named in a string repr cannot be imported.
    TypeError
        If a value has an unsupported type, a ``Name(...)`` string names a
        class that cannot be found, or the class's ``from_repr`` rejects the
        arguments.
    ValueError
        If a multivariate Gaussian entry has no ``dist=`` argument or refers
        to a distribution that was not defined earlier in ``dictionary``.
    """
    import numbers

    # Distributions parsed from MultivariateGaussianDist entries, keyed by
    # their dictionary key, for use by later MultivariateGaussian entries.
    mvgkwargs = {}
    # Iterate over a snapshot of the keys: entries may be popped below.
    for key in list(dictionary.keys()):
        val = dictionary[key]
        if isinstance(val, Prior):
            continue
        elif isinstance(val, numbers.Real):
            # numbers.Real also covers NumPy integer/float scalars, which are
            # not all subclasses of the builtin int/float.
            dictionary[key] = DeltaFunction(peak=val)
        elif isinstance(val, str):
            cls, _, rest = val.partition("(")
            # Text between the first "(" and the last character (expected to
            # be the closing parenthesis).
            args = rest[:-1]
            try:
                dictionary[key] = DeltaFunction(peak=float(cls))
                logger.debug("{} converted to DeltaFunction prior".format(key))
                continue
            except ValueError:
                pass
            if "." in cls:
                module, _, cls = cls.rpartition(".")
            else:
                # Strip this file's module name from __name__, i.e. look the
                # class up in the parent package.
                module = __name__.replace(
                    "." + os.path.basename(__file__).replace(".py", ""), ""
                )
            try:
                # If the module lacks the attribute, cls stays a plain str;
                # that case is handled below.
                cls = getattr(import_module(module), cls, cls)
            except ModuleNotFoundError:
                logger.error(
                    "Cannot import prior class {} for entry: {}={}".format(
                        cls, key, val
                    )
                )
                raise
            if key.lower() in ["conversion_function", "condition_func"]:
                # NOTE(review): the original string value also stays in
                # ``dictionary`` and is merged into ``self`` below — confirm
                # this is intended.
                setattr(self, key, cls)
            elif isinstance(cls, str):
                # Class lookup failed: a "Name(...)" string is an error, while
                # a plain string without arguments is left untouched.
                if "(" in val:
                    raise TypeError("Unable to parse prior class {}".format(cls))
                continue
            elif cls.__name__ in [
                "MultivariateGaussianDist",
                "MultivariateNormalDist",
            ]:
                # Remove the distribution from the output; keep it only so
                # later entries can reference it by key.  Dict keys are
                # unique, so no existing mvgkwargs entry can be overwritten.
                dictionary.pop(key)
                mvgkwargs[key] = cls.from_repr(args)
            elif cls.__name__ in ["MultivariateGaussian", "MultivariateNormal"]:
                mgkwargs = {
                    item[0].strip(): cls._parse_argument_string(item[1])
                    for item in cls._split_repr(
                        ", ".join(
                            [arg for arg in args.split(",") if "dist=" not in arg]
                        )
                    ).items()
                }
                # Find "dist=<key>" anywhere in the argument list: it need not
                # be the first argument nor be followed by a comma.
                keymatch = re.search(r"(?:^|,)\s*dist=(?P<distkey>[^,\s]+)", args)
                if keymatch is None:
                    raise ValueError(
                        "'dist' argument for MultivariateGaussian is not specified"
                    )
                distkey = keymatch["distkey"]
                if distkey not in mvgkwargs:
                    raise ValueError(
                        f"MultivariateGaussianDist {distkey} must be defined before {cls.__name__}"
                    )
                mgkwargs["dist"] = mvgkwargs[distkey]
                dictionary[key] = cls(**mgkwargs)
            else:
                try:
                    dictionary[key] = cls.from_repr(args)
                except TypeError as e:
                    raise TypeError(
                        "Unable to parse prior, bad entry: {} "
                        "= {}. Error message {}".format(key, val, e)
                    ) from e
        elif isinstance(val, dict):
            module_name = val.get("__module__", "none")
            class_name = val.get("__name__", "none")
            # Only the class lookup is guarded: a missing module *or* a
            # missing class falls back to leaving the dict untouched, while
            # errors raised by the prior's constructor still propagate.
            try:
                _class = getattr(import_module(module_name), class_name)
            except (ImportError, AttributeError):
                logger.debug(
                    "Cannot import prior module {}.{}".format(module_name, class_name)
                )
                logger.warning(
                    "Cannot convert {} into a prior object. "
                    "Leaving as dictionary.".format(key)
                )
                continue
            dictionary[key] = _class(**val.get("kwargs", dict()))
        else:
            raise TypeError(
                "Unable to parse prior, bad entry: {} "
                "= {} of type {}".format(key, val, type(val))
            )
    self.update(dictionary)