karm216's Introduction
Practicing Comment
Hi, my name is karm. I am working with Prof. @nipunbatra and my lab colleague @patel-zeel.
I made the desired changes to the code and tested them on the following examples.
MLE Model 1
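import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints  # also provides constraints.integer (used in MAP Model 2)
from pyro.infer.inspect import get_model_relations, render_model  # render_model: modified version with render_params

def model_mle_1(data):
    mu = pyro.param('mu', torch.tensor(0.), constraint=constraints.unit_interval)
    sd = pyro.param('sd', torch.tensor(1.), constraint=constraints.greater_than_eq(0))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Normal(mu, sd), obs=data)

data = torch.tensor([1., 2., 3.])
get_model_relations(model_mle_1, model_args=(data,))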
{'sample_sample': {'obs': []},
'sample_param': {'obs': ['sd', 'mu']},
'sample_dist': {'obs': 'Normal'},
'param_constraint': {'mu': Interval(lower_bound=0.0, upper_bound=1.0),
'sd': GreaterThanEq(lower_bound=0)},
'plate_sample': {'plate_data': ['obs']},
'observed': ['obs']}
render_model(model_mle_1, model_args=(data,), render_distributions=True)
render_model(model_mle_1, model_args=(data,), render_distributions=True, render_params=True)
MAP Model 1
def model_map_1(data):
    k1 = pyro.param('k1', torch.tensor(1.))
    mu = pyro.sample('mu', dist.Normal(0, k1))
    sd = pyro.sample('sd', dist.LogNormal(mu, k1))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Normal(mu, sd), obs=data)

data = torch.tensor([1., 2., 3.])
get_model_relations(model_map_1, model_args=(data,))
{'sample_sample': {'mu': [], 'sd': ['mu'], 'obs': ['sd', 'mu']},
'sample_param': {'mu': ['k1'], 'sd': ['k1'], 'obs': []},
'sample_dist': {'mu': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
'param_constraint': {'k1': Real()},
'plate_sample': {'plate_data': ['obs']},
'observed': ['obs']}
render_model(model_map_1, model_args=(data,), render_distributions=True)
render_model(model_map_1, model_args=(data,), render_distributions=True, render_params=True)
MAP Model 2
def model_map_2(data):
    t = pyro.param('t', torch.tensor(1.), constraints.integer)
    a = pyro.sample('a', dist.Bernoulli(t))
    b = pyro.param('b', torch.tensor(2.))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Beta(a, b), obs=data)

data = torch.tensor([1., 2., 3.])
get_model_relations(model_map_2, model_args=(data,))
render_model(model_map_2, model_args=(data,), render_distributions=True)
render_model(model_map_2, model_args=(data,), render_distributions=True, render_params=True)
Changes made in code
Broadly, I made the following changes in pyro/infer/inspect.py.
1. I added a key named sample_param to the dictionary returned by get_model_relations(); for each sample site, it lists the params that the site depends on.
In get_model_relations(), I inspected the output of trace.nodes and found that params have no provenance tracking. Without provenance tracking, we cannot find which params a sample depends on. The class TrackProvenance(Messenger) already has a method, _pyro_post_sample(), that assigns provenance to samples, so I added a similar method for params, _pyro_post_param(), to the same class. It is called while the trace is recorded by trace = poutine.trace(model).get_trace(*model_args, **model_kwargs).
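# Method of TrackProvenance; ProvenanceTensor and detach_provenance come from pyro.ops.provenance.
def _pyro_post_param(self, msg):
    # Messenger calls this hook after each "param" message has been processed.
    if msg["type"] == "param":
        provenance = frozenset({msg["name"]})  # track only direct dependencies
        # Drop any provenance already attached to the value, then tag it with this param's name.
        value = detach_provenance(msg["value"])
        msg["value"] = ProvenanceTensor(value, provenance)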
Then, to fill in sample_param, I followed the same procedure that the code already uses to fill in sample_sample.
2. I added another key, param_constraint, that stores the constraint of each param. generate_graph_specification() needs this result.
3. I added an argument render_params: bool = False to both render_model() and generate_graph_specification(). It makes showing params in the graph optional.
4. In generate_graph_specification(), the dictionary node_data looks like this for a sample variable:
node_data[rv] = {
    "is_observed": ...,
    "distribution": ...,
}
For params only, I added an additional key, constraint, to node_data. Note that the following changes apply only when render_params=True:
node_data[param] = {
    "is_observed": False,
    "distribution": None,
    "constraint": constraint,
}
Likewise, edge_list and plate_groups are updated to include the params.
5. In the render_graph() method, I draw params with the plain node shape and added code to show each param's constraint.
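The first two changes listed above (the new sample_param and param_constraint keys) boil down to splitting each sample's direct provenance by site type. The pure-Python sketch below mimics that bookkeeping without Pyro, assuming each sample site's distribution arguments carry a set of direct parent names (which is what _pyro_post_sample and _pyro_post_param arrange). The helper build_relations, its inputs, and the "Real()" default are hypothetical illustrations, not the code in pyro.infer.inspect.

```python
from typing import Dict, FrozenSet, List, Tuple


def build_relations(
    site_types: Dict[str, str],
    provenance: Dict[str, FrozenSet[str]],
    declared_constraints: Dict[str, str],
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, str]]:
    """Hypothetical helper mirroring sample_sample, sample_param and param_constraint.

    site_types: site name -> "sample" or "param", in trace order.
    provenance: sample site -> names its distribution arguments directly depend on.
    declared_constraints: param name -> constraint repr (assumed default "Real()").
    """
    sample_sample: Dict[str, List[str]] = {}
    sample_param: Dict[str, List[str]] = {}
    param_constraint: Dict[str, str] = {}
    for name, kind in site_types.items():
        if kind == "param":
            # Assumption: a param declared without a constraint is reported as Real().
            param_constraint[name] = declared_constraints.get(name, "Real()")
        elif kind == "sample":
            parents = provenance.get(name, frozenset())
            # Split direct parents by site type; sort for a stable order.
            sample_sample[name] = sorted(p for p in parents if site_types.get(p) == "sample")
            sample_param[name] = sorted(p for p in parents if site_types.get(p) == "param")
    return sample_sample, sample_param, param_constraint


# Hand-written toy input shaped like MAP Model 1 (not produced by Pyro).
site_types = {"k1": "param", "mu": "sample", "sd": "sample", "obs": "sample"}
provenance = {
    "mu": frozenset({"k1"}),
    "sd": frozenset({"mu", "k1"}),
    "obs": frozenset({"mu", "sd"}),
}
ss, sp, pc = build_relations(site_types, provenance, {})
print(ss)  # {'mu': [], 'sd': ['mu'], 'obs': ['mu', 'sd']}
print(sp)  # {'mu': ['k1'], 'sd': ['k1'], 'obs': []}
print(pc)  # {'k1': 'Real()'}
```

Because provenance is reset at every sample site, obs lists only mu and sd, not k1, which matches the "direct dependencies only" output for MAP Model 1.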
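The render_params flag and the node_data changes described above can be sketched as a small Pyro-free function. It assumes the relations dictionary has the keys shown in the get_model_relations outputs, that constraints are plain strings, that every edge points from a parent to the sample depending on it, and that sites outside every plate are grouped under a None key. The function name graph_spec_sketch and the choice to put params in the None group are hypothetical.

```python
from typing import Any, Dict, List, Optional, Tuple


def graph_spec_sketch(relations: Dict[str, Any], render_params: bool = False) -> Dict[str, Any]:
    """Hypothetical sketch: build node_data, edge_list and plate_groups from relations."""
    node_data: Dict[str, Dict[str, Any]] = {}
    for rv, dist_name in relations["sample_dist"].items():
        node_data[rv] = {
            "is_observed": rv in relations["observed"],
            "distribution": dist_name,
        }
    # One edge from each sample parent to the sample that depends on it.
    edge_list: List[Tuple[str, str]] = [
        (parent, rv)
        for rv, parents in relations["sample_sample"].items()
        for parent in parents
    ]
    plate_groups: Dict[Optional[str], List[str]] = {
        plate: list(rvs) for plate, rvs in relations["plate_sample"].items()
    }
    in_plate = {rv for rvs in plate_groups.values() for rv in rvs}
    # Assumption: sites outside every plate are grouped under the None key.
    plate_groups[None] = [rv for rv in node_data if rv not in in_plate]

    if render_params:
        for param, constraint in relations["param_constraint"].items():
            node_data[param] = {
                "is_observed": False,
                "distribution": None,
                "constraint": constraint,
            }
            # Assumption: params are declared outside any plate, as in the examples above.
            plate_groups[None].append(param)
        # One edge from each param to every sample whose distribution uses it.
        edge_list += [
            (param, rv)
            for rv, params in relations["sample_param"].items()
            for param in params
        ]
    return {"node_data": node_data, "edge_list": edge_list, "plate_groups": plate_groups}


# Relations copied from the MAP Model 1 output, with the constraint written as a string.
relations = {
    "sample_sample": {"mu": [], "sd": ["mu"], "obs": ["sd", "mu"]},
    "sample_param": {"mu": ["k1"], "sd": ["k1"], "obs": []},
    "sample_dist": {"mu": "Normal", "sd": "LogNormal", "obs": "Normal"},
    "param_constraint": {"k1": "Real()"},
    "plate_sample": {"plate_data": ["obs"]},
    "observed": ["obs"],
}
spec_plain = graph_spec_sketch(relations)
spec_params = graph_spec_sketch(relations, render_params=True)
```

With render_params left at False the spec contains only sample nodes, so existing callers of render_model() would see no change.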
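The last change draws params with Graphviz's plain node shape and shows their constraint. The sketch below emits the corresponding DOT statements as strings, so it needs neither Pyro nor the graphviz package; the helper names and the label layout (param name, then its constraint on a second line) are assumptions for illustration, not necessarily what the modified render_graph() produces.

```python
def dot_escape(text: str) -> str:
    """Escape double quotes so the text stays inside a quoted DOT string."""
    return text.replace('"', '\\"')


def param_node_dot(name: str, constraint: str) -> str:
    """Hypothetical helper: a DOT node statement for a param, drawn with shape=plain."""
    # A literal backslash-n inside a DOT label is Graphviz's centered line break.
    label = dot_escape(name) + "\\n" + dot_escape(constraint)
    return f'"{dot_escape(name)}" [label="{label}", shape=plain];'


def sketch_graph(params: dict, edges: list) -> str:
    """Assemble a tiny digraph with param nodes and param -> sample edges."""
    lines = ["digraph model {"]
    lines += ["  " + param_node_dot(n, c) for n, c in params.items()]
    lines += [f'  "{dot_escape(a)}" -> "{dot_escape(b)}";' for a, b in edges]
    lines.append("}")
    return "\n".join(lines)


print(param_node_dot("mu", "Interval(lower_bound=0.0, upper_bound=1.0)"))
print(sketch_graph({"k1": "Real()"}, [("k1", "mu"), ("k1", "sd")]))
```

Feeding the second printed string to Graphviz's dot tool (for example dot -Tsvg) would draw k1 as a bare text label with its constraint underneath, pointing at mu and sd.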
Please give me your feedback on this. If the dictionary and graph meet your expectations, can I open a PR?