karm216's Issues

Practicing Comment

Hi, my name is karm. I am working with Prof. @nipunbatra and lab colleague @patel-zeel.
I made the desired changes to the code and tested them on the following examples.

MLE Model 1

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer.inspect import get_model_relations, render_model

def model_mle_1(data):
    mu = pyro.param('mu', torch.tensor(0.), constraint=constraints.unit_interval)
    sd = pyro.param('sd', torch.tensor(1.), constraint=constraints.greater_than_eq(0))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Normal(mu, sd), obs=data)
data = torch.tensor([1., 2., 3.])
get_model_relations(model_mle_1, model_args=(data,))

{'sample_sample': {'obs': []},
 'sample_param': {'obs': ['sd', 'mu']},
 'sample_dist': {'obs': 'Normal'},
 'param_constraint': {'mu': Interval(lower_bound=0.0, upper_bound=1.0),
                      'sd': GreaterThanEq(lower_bound=0)},
 'plate_sample': {'plate_data': ['obs']},
 'observed': ['obs']}

render_model(model_mle_1, model_args=(data,), render_distributions=True)

[image: rendered graph of model_mle_1]

render_model(model_mle_1, model_args=(data,), render_distributions=True, render_params=True)

[image: rendered graph of model_mle_1 with params]

MAP Model 1

def model_map_1(data):
    k1 = pyro.param('k1', torch.tensor(1.))
    mu = pyro.sample('mu', dist.Normal(0, k1))
    sd = pyro.sample('sd', dist.LogNormal(mu, k1))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Normal(mu, sd), obs=data)
data = torch.tensor([1., 2., 3.])
get_model_relations(model_map_1, model_args=(data,))

{'sample_sample': {'mu': [], 'sd': ['mu'], 'obs': ['sd', 'mu']},
 'sample_param': {'mu': ['k1'], 'sd': ['k1'], 'obs': []},
 'sample_dist': {'mu': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
 'param_constraint': {'k1': Real()},
 'plate_sample': {'plate_data': ['obs']},
 'observed': ['obs']}
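To make the dictionary above concrete, here is a minimal sketch (not Pyro code) that derives directed edges from the relations printed for model_map_1. The constraint object Real() is replaced by the string "Real()" so the snippet runs without Pyro, and the helper relations_to_edges is hypothetical; it only illustrates how sample_sample and sample_param translate into sample-to-sample and param-to-sample edges, with param edges included only when render_params is true.

```python
# Relations printed above for model_map_1; Real() is replaced by a string so
# that this sketch runs without Pyro or PyTorch.
relations = {
    "sample_sample": {"mu": [], "sd": ["mu"], "obs": ["sd", "mu"]},
    "sample_param": {"mu": ["k1"], "sd": ["k1"], "obs": []},
    "sample_dist": {"mu": "Normal", "sd": "LogNormal", "obs": "Normal"},
    "param_constraint": {"k1": "Real()"},
    "plate_sample": {"plate_data": ["obs"]},
    "observed": ["obs"],
}


def relations_to_edges(relations, render_params=False):
    """Hypothetical helper: list (parent, child) edges implied by the relations."""
    # sample -> sample edges are always drawn.
    edges = [(parent, child)
             for child, parents in relations["sample_sample"].items()
             for parent in parents]
    if render_params:
        # param -> sample edges are drawn only when params are rendered.
        edges += [(param, child)
                  for child, params in relations["sample_param"].items()
                  for param in params]
    return edges


print(relations_to_edges(relations))
# [('mu', 'sd'), ('sd', 'obs'), ('mu', 'obs')]
print(relations_to_edges(relations, render_params=True))
# [('mu', 'sd'), ('sd', 'obs'), ('mu', 'obs'), ('k1', 'mu'), ('k1', 'sd')]
```

Note that k1 only appears as a parent of mu and sd, matching sample_param; obs has no direct param parent.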

render_model(model_map_1, model_args=(data,), render_distributions=True)

[image: rendered graph of model_map_1]

render_model(model_map_1, model_args=(data,), render_distributions=True, render_params=True)

[image: rendered graph of model_map_1 with params]

MAP Model 2

def model_map_2(data):
    t = pyro.param('t', torch.tensor(1.), constraints.integer)
    a = pyro.sample('a', dist.Bernoulli(t))
    b = pyro.param('b', torch.tensor(2.))
    with pyro.plate('plate_data', len(data)):
        pyro.sample('obs', dist.Beta(a, b), obs=data)
data = torch.tensor([1., 2., 3.])
get_model_relations(model_map_2, model_args=(data,))

{'sample_sample': {'mu': [], 'sd': ['mu'], 'obs': ['sd', 'mu']},
 'sample_param': {'mu': ['k1'], 'sd': ['k1'], 'obs': []},
 'sample_dist': {'mu': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
 'param_constraint': {'k1': Real()},
 'plate_sample': {'plate_data': ['obs']},
 'observed': ['obs']}

render_model(model_map_2, model_args=(data,), render_distributions=True)

[image: rendered graph of model_map_2]

render_model(model_map_2, model_args=(data,), render_distributions=True, render_params=True)

[image: rendered graph of model_map_2 with params]

Changes made in code

Broadly, I made the following changes in pyro/infer/inspect.py (the pyro.infer.inspect module).

  1. I added a key named sample_param to the dictionary returned by get_model_relations(). For each sample site, it lists the params that the site depends on.
    While inspecting the output of trace.nodes in get_model_relations(), I found that params have no provenance tracking, and without provenance tracking we cannot recover which params a sample depends on. The class TrackProvenance(Messenger) already has a method, _pyro_post_sample(), that assigns provenance to samples, so I added a similar method for params, _pyro_post_param(), to the same class. This method is called while the trace is recorded, i.e. during trace = poutine.trace(model).get_trace(*model_args, **model_kwargs).
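# New method of class TrackProvenance(Messenger) in pyro/infer/inspect.py;
# ProvenanceTensor and detach_provenance come from pyro.ops.provenance.
def _pyro_post_param(self, msg):
    if msg["type"] == "param":
        provenance = frozenset({msg["name"]})  # track only direct dependencies
        value = detach_provenance(msg["value"])  # drop provenance inherited from upstream values
        msg["value"] = ProvenanceTensor(value, provenance)  # tag the param value with its own name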

Then, to fill in sample_param, I followed the same procedure that fills in sample_sample.
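The idea behind _pyro_post_param can be illustrated without Pyro. The sketch below uses a toy class, Tracked, as a stand-in for Pyro's ProvenanceTensor (it is not Pyro's implementation): arithmetic merges the provenance sets of its operands, and post_site mimics the post-sample/post-param hooks by discarding inherited provenance and tagging a value with its own site name, so only direct dependencies are recorded. The names Tracked, post_site, and split_dependencies are hypothetical, and the sampled values are made up.

```python
class Tracked:
    """Toy stand-in for a provenance-carrying tensor (not Pyro's ProvenanceTensor)."""

    def __init__(self, value, provenance=frozenset()):
        self.value = value                       # plain float
        self.provenance = frozenset(provenance)  # names of sites it was computed from

    @staticmethod
    def split(x):
        """Return (value, provenance); plain numbers carry no provenance."""
        if isinstance(x, Tracked):
            return x.value, x.provenance
        return x, frozenset()

    def __add__(self, other):
        value, prov = Tracked.split(other)
        return Tracked(self.value + value, self.provenance | prov)

    def __mul__(self, other):
        value, prov = Tracked.split(other)
        return Tracked(self.value * value, self.provenance | prov)

    __radd__ = __add__
    __rmul__ = __mul__


def post_site(name, value):
    """Mimic the post hooks: detach inherited provenance, tag with the site's own name."""
    raw, _ = Tracked.split(value)
    return Tracked(raw, {name})


def split_dependencies(dist_args, sample_names, param_names):
    """Hypothetical helper: split a site's direct parents into (samples, params)."""
    upstream = frozenset().union(*(Tracked.split(a)[1] for a in dist_args))
    return sorted(upstream & sample_names), sorted(upstream & param_names)


# Mirror model_map_1: k1 is a param, mu ~ Normal(0, k1), sd ~ LogNormal(mu, k1).
samples, params = frozenset({"mu", "sd", "obs"}), frozenset({"k1"})
k1 = post_site("k1", 1.0)       # the tagging that _pyro_post_param adds
mu = post_site("mu", 0.3)       # made-up sampled value
sd_raw = mu * 0.5 + k1          # stands in for a reparameterized draw; carries {mu, k1}
sd = post_site("sd", sd_raw)    # re-tagged, so its provenance is just {"sd"}

print(split_dependencies([mu, k1], samples, params))   # parents of sd
# (['mu'], ['k1'])
print(split_dependencies([mu, sd], samples, params))   # parents of obs
# (['mu', 'sd'], [])
print(split_dependencies([mu, 1.0], samples, params))  # k1 left untagged
# (['mu'], [])
```

The last call shows why the param hook matters: if k1 is never tagged, sd's dependence on k1 is invisible, which is exactly the missing param provenance described above.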

  2. I added another key, param_constraint, which stores the constraint of each param. generate_graph_specification() needs this result.
  3. I added an argument render_params: bool = False to both render_model() and generate_graph_specification(). It makes showing params in the graph optional.
  4. In generate_graph_specification(), the dictionary node_data looks like this for a sample variable:
node_data[rv] = {
    "is_observed": ....,
    "distribution": ....,
}

For params only, I added an extra key, constraint, to node_data. Note that the following changes apply only when render_params=True.

node_data[param] = {
    "is_observed": False,
    "distribution": None,
    "constraint": constraint,
}

Further, edge_list and plate_groups are also updated with the param data.
5. In the render_graph() method, I draw params with the plain shape and added code to show each param's constraint.
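The steps above (param_constraint, the render_params flag, the extra constraint key in node_data, and param edges) can be sketched end to end in plain Python. This is a simplified illustration under stated assumptions, not the code in pyro/infer/inspect.py: graph_spec and to_dot are hypothetical names, constraints are plain strings, params are not assigned to plates, and the DOT text is written by hand instead of through the graphviz package. Param nodes get shape=plain with the constraint in the label, as described in step 5; filling observed samples is an assumed styling choice.

```python
# Relations printed earlier for model_mle_1, with constraints as plain strings.
relations = {
    "sample_sample": {"obs": []},
    "sample_param": {"obs": ["sd", "mu"]},
    "sample_dist": {"obs": "Normal"},
    "param_constraint": {"mu": "Interval(lower_bound=0.0, upper_bound=1.0)",
                         "sd": "GreaterThanEq(lower_bound=0)"},
    "plate_sample": {"plate_data": ["obs"]},
    "observed": ["obs"],
}


def graph_spec(relations, render_params=False):
    """Hypothetical sketch of a graph specification with optional param nodes."""
    node_data = {
        rv: {"is_observed": rv in relations["observed"], "distribution": dist_name}
        for rv, dist_name in relations["sample_dist"].items()
    }
    edge_list = [(parent, child)
                 for child, parents in relations["sample_sample"].items()
                 for parent in parents]
    plate_groups = {plate: list(sites) for plate, sites in relations["plate_sample"].items()}
    if render_params:
        # Params get the extra "constraint" key; they are never observed.
        for param, constraint in relations["param_constraint"].items():
            node_data[param] = {"is_observed": False, "distribution": None,
                                "constraint": constraint}
        edge_list += [(param, child)
                      for child, params in relations["sample_param"].items()
                      for param in params]
    return {"node_data": node_data, "edge_list": edge_list, "plate_groups": plate_groups}


def to_dot(spec):
    """Write the specification as Graphviz DOT text."""
    lines = ["digraph model {"]
    for plate, sites in spec["plate_groups"].items():
        members = " ".join(f'"{s}";' for s in sites)
        lines.append(f'  subgraph "cluster_{plate}" {{ label="{plate}"; {members} }}')
    for name, data in spec["node_data"].items():
        if "constraint" in data:
            # Param node: plain shape, constraint shown next to the name.
            lines.append(f'  "{name}" [shape=plain, label="{name} : {data["constraint"]}"];')
        else:
            # Sample node: ellipse, filled when observed (assumed styling).
            fill = ", style=filled" if data["is_observed"] else ""
            lines.append(f'  "{name}" [shape=ellipse{fill}];')
    for parent, child in spec["edge_list"]:
        lines.append(f'  "{parent}" -> "{child}";')
    lines.append("}")
    return "\n".join(lines)


print(to_dot(graph_spec(relations, render_params=True)))
```

With render_params left at False, the mu and sd nodes and their edges are simply omitted, so the default graph is unchanged.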

Please give me your feedback on this. If the dictionary and the graphs meet your expectations, may I open a PR?