Comments (6)
Thanks for the bug report! We'll take a closer look at it shortly.
I have a fix in #93 for the error you were seeing, but we still need to look into the samplers AeMCMC produces (or lack thereof) for your model.
Thank you, that was quick! I've tried the patch, and construct_sampler no longer fails. However, there is a new error, which may very well be my mistake.
Here is the new code:
import numpy as np
import aesara
import aesara.tensor as at
from aemcmc.basic import construct_sampler
from aesara.tensor.random.utils import RandomStream


def logistic_fit(X_val, y_val):
    N, M, T = X_val.shape
    srng = RandomStream(0)
    X = at.tensor3("X")
    sigma_rv = srng.exponential(1, size=X.shape[1])
    beta_t_rv = at.cumsum(srng.normal(0, 1/sigma_rv, size=(X.shape[1], X.shape[2])), axis=1)
    eta = at.tensordot(X, beta_t_rv, 2)
    p = at.sigmoid(-eta)
    Y_rv = srng.bernoulli(p, name="Y")
    y_vv = Y_rv.clone()
    y_vv.name = "y"
    sample_vars = [sigma_rv, beta_t_rv]
    sampler, initial_values = construct_sampler({Y_rv: y_vv}, srng)
    inputs = [X, y_vv] + [initial_values[rv] for rv in sample_vars]
    outputs = [sampler.sample_steps[rv] for rv in sample_vars]
    sample_step = aesara.function(
        inputs,
        outputs,
        updates=sampler.updates,
        on_unused_input="ignore",
    )
    sigma_val = np.ones(M)
    beta_pst_vals = []
    sigma_pst_val, beta_pst_val = (
        sigma_val,
        np.zeros((M, T))  # np.zeros takes the shape as a single tuple
    )
    # Run 100 Gibbs steps, feeding each step's output back in as the next state
    for i in range(100):
        sigma_pst_val, beta_pst_val = sample_step(
            X_val,
            y_val,
            sigma_pst_val,
            beta_pst_val
        )
        beta_pst_vals += [beta_pst_val]
    # Posterior mean of the time-varying coefficients
    beta_pst_mean = np.mean(beta_pst_vals, axis=0)
    return beta_pst_mean


# X_val = np.load("X_val.npy")
# y_val = np.load("y_val.npy")
X_val = np.zeros((1000, 50, 10))
y_val = np.zeros(1000)
beta = logistic_fit(X_val, y_val)
Here is the error (note the warning as well):
/Users/acristia/anaconda3/lib/python3.8/site-packages/aehmc/utils.py:43: UserWarning: The following parameters need to be computed in order to determine the shapes in this parameter map: [<TensorType(float64, (None, None))>]
  warnings.warn(
Traceback (most recent call last):
  File "examples/gibbs_sample.py", line 61, in <module>
    beta = logistic_fit(X_val, y_val)
  File "examples/gibbs_sample.py", line 28, in logistic_fit
    inputs = [X, y_vv] + [initial_values[rv] for rv in sample_vars]
  File "examples/gibbs_sample.py", line 28, in <listcomp>
    inputs = [X, y_vv] + [initial_values[rv] for rv in sample_vars]
KeyError: CumOp{1, add}.0
Let me know if I can provide more useful information!
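Independently of AeMCMC, the generative model in the reproducer can be simulated directly in NumPy, which is a quick way to sanity-check the shapes involved. This is a minimal sketch: `simulate_random_walk_logistic` and its random design matrix are hypothetical, and it assumes each of the `M` coefficients follows its own Gaussian random walk over the `T` steps with increment scale `1/sigma[m]`. In NumPy, a scale of shape `(M,)` would be matched against the last axis of `size=(M, T)`, so the sketch uses `sigma[:, None]`; whether the Aesara graph above behaves the same way is an assumption that depends on Aesara following NumPy's broadcasting rules for random variable parameters.

```python
import numpy as np


def simulate_random_walk_logistic(N, M, T, seed=0):
    """Draw (X, sigma, beta_t, y) from the prior of the reproducer's model.

    Hypothetical helper. Assumes each of the M coefficients follows its own
    Gaussian random walk over T steps with increment scale 1 / sigma[m].
    """
    rng = np.random.default_rng(seed)
    # Illustrative design matrix; the reproducer uses an all-zero X_val instead
    X = rng.normal(size=(N, M, T))
    # sigma ~ Exponential(1), one value per coefficient
    sigma = rng.exponential(1.0, size=M)
    # sigma[:, None] has shape (M, 1), so it broadcasts against (M, T);
    # a bare (M,) array would be matched against the T axis instead
    increments = rng.normal(0.0, 1.0 / sigma[:, None], size=(M, T))
    # Cumulative sum along time turns the increments into a random walk
    beta_t = np.cumsum(increments, axis=1)
    # Contract the M and T axes: (N, M, T) x (M, T) -> (N,)
    eta = np.tensordot(X, beta_t, axes=2)
    # sigmoid(-eta) = 1 / (1 + exp(eta)), computed without overflow
    p = np.exp(-np.logaddexp(0.0, eta))
    y = rng.binomial(1, p)
    return X, sigma, beta_t, y
```

For example, `simulate_random_walk_logistic(1000, 50, 10)` returns arrays with the same shapes as `X_val` and `y_val` in the reproducer, which could stand in for the all-zero placeholders.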
That CumOp warning is something we still need to address. We can open another issue for it, though.
Looked like an error to me?
Yes, it is. I'll reopen this and take a look soon.
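One plausible reading of the traceback, not confirmed in this thread, is that `initial_values` only has entries for the random variables `construct_sampler` assigned samplers to, while `beta_t_rv` is the output of `at.cumsum` (the `CumOp{1, add}.0` in the `KeyError`) rather than a random variable itself. Until that is resolved, a small lookup that reports every missing key together with the available ones makes this kind of mismatch easier to diagnose. The helper `select_initial_values` below is hypothetical and not part of AeMCMC; it only assumes `initial_values` behaves like a Python mapping.

```python
def select_initial_values(initial_values, requested):
    """Look up initial values, reporting every missing key at once.

    Hypothetical helper: `initial_values` stands in for the mapping returned
    by `construct_sampler`, and `requested` is the list of variables to sample.
    """
    missing = [var for var in requested if var not in initial_values]
    if missing:
        # Show what is missing and what is actually available in one message
        raise KeyError(
            f"no initial value for {missing!r}; available keys: "
            f"{list(initial_values)!r}"
        )
    return [initial_values[var] for var in requested]
```

In the reproducer, `select_initial_values(initial_values, sample_vars)` could replace the list comprehension on the `inputs` line; the resulting error message should then list which variables do have initial values.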