
Lectures from the Mathematical Modeling IV course at FGV-EMAp

License: GPL3
Kernel: Python 3 (system-wide)

Estimating parameters of ODE models

One of the central problems in modeling is how to find the "correct" values of a model's parameters. We have already explored optimization methods for this purpose. In this notebook, we will treat it as an inference problem instead. To do so, we assume that our parameters are random variables. Let $\theta$ be the set of parameters of our model and $q(\theta)$ the joint probability distribution of these parameters. Let $\mathcal{M}$ be our model and $\phi$ the set of its outputs, $\{x_i(t)\}_{i=1}^{n}$, where $n$ is the dimension of the model, so that $\phi = \mathcal{M}(\theta)$.

Assuming that the model parameters $\theta$ are random variables implies that the model outputs also have a probability distribution. We will solve this inference problem with Bayesian inference. We begin by assigning a prior distribution $q(\theta)$ to $\theta$, which induces a distribution $q(\phi)$ on the model outputs. We then use Bayes' formula to estimate simultaneously the posterior distribution of $\theta$, $\pi(\theta)$, and that of $\phi$, $\pi(\phi)$:

$$\pi(\theta \mid \text{data}) \propto p(\text{data} \mid \theta)\, q(\theta)$$
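For a single parameter, the proportionality above can be evaluated directly on a grid: multiply likelihood and prior point by point, then normalize. The sketch below assumes a toy problem that is not part of this notebook: estimating the mean `mu` of normally distributed data with known standard deviation 1 under a Normal(0, 10) prior. All names are illustrative.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
data = rng.normal(loc=3.0, scale=1.0, size=50)  # synthetic observations

# Grid of candidate values for the unknown parameter mu
grid = np.linspace(-5, 10, 3001)
dx = grid[1] - grid[0]

# log prior q(mu) and log likelihood p(data | mu), evaluated on the whole grid
log_prior = stats.norm.logpdf(grid, loc=0.0, scale=10.0)
log_lik = stats.norm.logpdf(data[:, None], loc=grid[None, :], scale=1.0).sum(axis=0)

# Unnormalized log posterior; subtract the max before exponentiating for numerical stability
log_post = log_prior + log_lik
post = np.exp(log_post - log_post.max())
post /= post.sum() * dx  # normalize so that the density integrates to 1

post_mean = np.sum(grid * post) * dx
print(f"posterior mean of mu: {post_mean:.3f} (sample mean: {data.mean():.3f})")
```

Because the prior is wide, the posterior mean lands very close to the sample mean. The number of grid points grows exponentially with the number of parameters, which is why this brute-force approach does not carry over to models with several parameters.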

To estimate this model, we will need to resort to numerical methods such as Markov chain Monte Carlo (MCMC) algorithms.
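To make the idea of MCMC concrete, here is a minimal random-walk Metropolis sampler written with NumPy only. It is an illustration of the general technique, not what PyMC runs (PyMC assigns the NUTS sampler to continuous models). It targets a toy posterior assumed for this example: the mean `mu` of normal data with known standard deviation 1 and a Normal(0, 10) prior. The function names are hypothetical.

```python
import numpy as np

def log_posterior(mu, data):
    """Unnormalized log posterior: Normal(0, 10) prior + Normal(mu, 1) likelihood."""
    log_prior = -0.5 * (mu / 10.0) ** 2
    log_lik = -0.5 * np.sum((data - mu) ** 2)
    return log_prior + log_lik

def metropolis(log_target, x0, n_samples, step, rng):
    """Random-walk Metropolis with a symmetric Gaussian proposal."""
    samples = np.empty(n_samples)
    x, logp = x0, log_target(x0)
    accepted = 0
    for i in range(n_samples):
        proposal = x + step * rng.normal()
        logp_prop = log_target(proposal)
        # Accept with probability min(1, p(proposal) / p(x)); the symmetric proposal cancels out
        if np.log(rng.uniform()) < logp_prop - logp:
            x, logp = proposal, logp_prop
            accepted += 1
        samples[i] = x
    return samples, accepted / n_samples

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=1.0, size=50)
chain, acc_rate = metropolis(lambda m: log_posterior(m, data), x0=0.0,
                             n_samples=20_000, step=0.3, rng=rng)
burned = chain[2_000:]  # discard the warm-up draws
print(f"acceptance rate: {acc_rate:.2f}, posterior mean: {burned.mean():.3f}")
```

Only the unnormalized posterior is ever evaluated, so the proportionality constant in Bayes' formula never has to be computed.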

Further reading

  1. Coelho et al. (2011) A Bayesian Framework for Parameter Estimation in Dynamical Models

In this lesson, we will use the PyMC library to estimate the parameters of an SIR model. To run this notebook, we need to install PyMC:

pip install pymc
!pip install pymc
!pip install graphviz
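%matplotlib inline
import pymc as pm
from pymc.ode import DifferentialEquation
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import arviz as az
import aesara as ae

# Matplotlib 3.6 renamed this style to 'seaborn-v0_8-darkgrid'; fall back to the old name on older versions
plt.style.use('seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in plt.style.available else 'seaborn-darkgrid')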


# SIR right-hand side: y = [S, I], p = [beta, gamma]
def SIR(y, t, p):
    ds = -p[0] * y[0] * y[1]
    di = p[0] * y[0] * y[1] - p[1] * y[1]
    return [ds, di]

times = np.arange(0, 5, 0.25)
beta, gamma = 4, 1.0

# Generate simulated curves
y = odeint(SIR, t=times, y0=[0.99, 0.01], args=((beta, gamma),), rtol=1e-8)

# Simulate data assuming a log-normal distribution whose median equals the simulated series
# (the initial time point t = 0 is dropped)
yobs = np.random.lognormal(mean=np.log(y[1::]), sigma=[0.2, 0.3])

plt.plot(times[1::], yobs, marker='o', linestyle='none')
plt.plot(times, y[:, 0], color='C0', alpha=0.5, label='$S(t)$')
plt.plot(times, y[:, 1], color='C1', alpha=0.5, label='$I(t)$')
plt.legend();
[Figure: simulated S(t) and I(t) curves (lines) and the noisy observations yobs (points)]
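Note that `np.random.lognormal(mean=np.log(y), sigma=s)` parameterizes the distribution of log(Y): the simulated observations have *median* y, while their mean is y·exp(s²/2). A quick Monte Carlo check, with an arbitrary value y = 0.3 and s = 0.3 chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
y_true, s = 0.3, 0.3
# Same parameterization as the notebook: `mean` and `sigma` refer to log(Y), not to Y
draws = rng.lognormal(mean=np.log(y_true), sigma=s, size=200_000)

print(f"sample median: {np.median(draws):.4f}  (y = {y_true})")
print(f"sample mean:   {draws.mean():.4f}  (y*exp(s^2/2) = {y_true * np.exp(s**2 / 2):.4f})")
```

The same convention applies to the `pm.Lognormal` likelihood used later, where `mu=pm.math.log(sir_curves)` centers the log of the observations on the log of the ODE solution.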

Below we define our model using the DifferentialEquation class from PyMC's pymc.ode module.

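# The ODE is solved from t0 = 0 and evaluated at the 19 observation times (the same as times[1:])
sir_model = DifferentialEquation(
    func=SIR,
    times=np.arange(0.25, 5, 0.25),
    n_states=2,  # S and I
    n_theta=2,   # beta and gamma
    t0=0,
)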
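Conceptually, calling `sir_model(y0=..., theta=...)` plays the role of $\mathcal{M}(\theta)$: it returns the ODE solution at the 19 observation times, one column per state (PyMC's class additionally propagates gradients through the solver so that NUTS can be used). A plain SciPy stand-in, under the hypothetical name `forward_model`, makes the shapes explicit:

```python
import numpy as np
from scipy.integrate import odeint

def SIR(y, t, p):
    """Same right-hand side as in the notebook: p = (beta, gamma)."""
    ds = -p[0] * y[0] * y[1]
    di = p[0] * y[0] * y[1] - p[1] * y[1]
    return [ds, di]

def forward_model(theta, y0=(0.99, 0.01), times=np.arange(0.25, 5, 0.25)):
    """Hypothetical stand-in for M(theta): integrate from t0 = 0 and return the
    solution at `times` only (the initial condition row is dropped)."""
    t_full = np.concatenate(([0.0], times))
    sol = odeint(SIR, y0=list(y0), t=t_full, args=(tuple(theta),), rtol=1e-8)
    return sol[1:]  # shape (len(times), 2): columns S and I

phi = forward_model((4.0, 1.0))
print(phi.shape)  # (19, 2), the same shape as yobs
```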

Building the inference model

Now we need to define the prior distributions of the model parameters and the likelihood of the observed data given the model output $\phi$.
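Written out, the probabilistic model coded in the next cell is as follows, with $j \in \{S, I\}$ indexing the two states and $t_k$ the 19 observation times. The second argument of each distribution is a scale (standard deviation, or standard deviation of the log for the log-normal), following PyMC's parameterization:

```latex
\begin{aligned}
\sigma_j &\sim \operatorname{HalfCauchy}(1), \qquad j \in \{S, I\} \\
R_0 &\sim \mathcal{N}(2,\ 3)\ \text{truncated to } [1, \infty) \\
\gamma &\sim \operatorname{LogNormal}(\log 2,\ 2) \\
\beta &= \gamma\, R_0 \\
\phi &= \mathcal{M}(\beta, \gamma) = \big(S(t_k),\ I(t_k)\big)_{k=1}^{19} \\
Y_{kj} &\sim \operatorname{LogNormal}\big(\log \phi_j(t_k),\ \sigma_j\big)
\end{aligned}
```

Placing the prior on $R_0 = \beta/\gamma$ rather than directly on $\beta$ lets the truncation at 1 encode the assumption that an epidemic does occur.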

with pm.Model() as model:
    # Observation noise scale, one per state (S and I)
    sigma = pm.HalfCauchy('sigma', 1, shape=2)

    # Prior distributions
    # R0 is bounded below by 1 so that there is always an epidemic.
    R0 = pm.Truncated('R0', pm.Normal.dist(2, 3), lower=1)
    gam = pm.Lognormal('gamma', pm.math.log(2), 2)
    beta = pm.Deterministic('beta', gam * R0)

    # Solve the ODE for the current parameter values
    sir_curves = sir_model(y0=[0.99, 0.01], theta=[beta, gam])

    # Log-normal likelihood centered (in log space) on the ODE solution
    Y = pm.Lognormal('Y', mu=pm.math.log(sir_curves), sigma=sigma, observed=yobs)

    # db = pm.backends.HDF5('traces.h5')  # saves the samples to disk, avoiding keeping everything in memory
    trace = pm.sample(2000, tune=1000)  # , cores=4)
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [sigma, R0, gamma]
100.00% [12000/12000 23:03<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1386 seconds.
pm.model_to_graphviz(model)
[Figure: model graph produced by pm.model_to_graphviz]
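# In PyMC >= 4, pm.sample already returns an arviz.InferenceData, so the PyMC3-era conversion is not needed:
# data = az.from_pymc3(trace=trace)
# data
trace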
arviz.InferenceData
    • posterior
      Dimensions:      (chain: 4, draw: 2000, sigma_dim_0: 2)
      Data variables:
          sigma        (chain, draw, sigma_dim_0) float64 0.1619 0.2758 ... 0.2347
          R0           (chain, draw) float64 4.08 4.119 4.13 ... 3.936 4.063 4.025
          gamma        (chain, draw) float64 0.9666 0.95 0.9637 ... 0.9797 0.9816
          beta         (chain, draw) float64 3.944 3.913 3.98 ... 3.966 3.98 3.951
      Attributes:
          created_at:                 2022-10-17T12:01:33.246403
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              1385.5495262145996
          tuning_steps:               1000

    • log_likelihood
      Dimensions:  (chain: 4, draw: 2000, Y_dim_0: 19, Y_dim_1: 2)
      Data variables:
          Y        (chain, draw, Y_dim_0, Y_dim_1) float64 -2.837 3.991 ... 3.359

    • sample_stats
      Dimensions:  (chain: 4, draw: 2000)
      Data variables (16): max_energy_error, largest_eigval, perf_counter_diff, energy,
          diverging, acceptance_rate, lp, step_size, index_in_trajectory, n_steps,
          perf_counter_start, tree_depth, step_size_bar, energy_error, smallest_eigval,
          process_time_diff
          (diverging is False for every draw)

    • observed_data
      Dimensions:  (Y_dim_0: 19, Y_dim_1: 2)
      Data variables:
          Y        (Y_dim_0, Y_dim_1) float64 1.483 0.02371 1.061 ... 0.02964 0.05862

az.plot_posterior(trace, round_to=2, hdi_prob=0.95);
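`hdi_prob=0.95` asks ArviZ to mark the 95% highest density interval (HDI). This is the narrowest interval that contains 95% of the posterior draws. Unlike an equal-tailed interval, it shifts toward the mode when the posterior is skewed. The `hdi_3%` and `hdi_97%` columns of `az.summary` below report the same kind of interval at ArviZ's default probability of 0.94.

Below is a minimal NumPy sketch of the sorted-sample approach. It uses synthetic draws instead of the notebook's `trace`. The helper name `hdi` is hypothetical, and this is not ArviZ's own code.

```python
import numpy as np

def hdi(samples, prob=0.95):
    """Hypothetical helper: shortest interval containing a fraction `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = len(x)
    k = int(np.floor(prob * n))  # each candidate interval spans k sorted points
    # Width of every candidate interval [x[i], x[i + k]]
    widths = x[k:] - x[: n - k]
    i = int(np.argmin(widths))   # the narrowest one is the HDI
    return x[i], x[i + k]

# Synthetic posterior draws (an illustrative assumption, not the notebook's trace)
rng = np.random.default_rng(0)
draws = rng.normal(loc=4.0, scale=0.1, size=20_000)
lo, hi = hdi(draws, 0.95)
```

For these symmetric, unimodal draws the result is close to mean ± 1.96 sd, because the HDI and the equal-tailed interval coincide in that case. For a skewed distribution such as an exponential, the HDI is narrower and starts near zero.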
az.plot_trace(trace);
az.summary(trace.posterior, kind='stats')
          mean     sd  hdi_3%  hdi_97%
sigma[0]  0.206  0.038   0.141    0.277
sigma[1]  0.266  0.050   0.181    0.360
R0        4.091  0.095   3.920    4.273
gamma     0.969  0.034   0.904    1.033
beta      3.960  0.072   3.822    4.096
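The `Sampling: [R0, Y, gamma, sigma]` output of `pm.sample_prior_predictive` below lists the sampled variables, and `beta` is not among them. So `beta` is presumably a deterministic function of `R0` and `gamma`. For the SIR model on population fractions, R0 = beta/gamma. Assuming that relation, a quick check with the posterior means shows that the table is internally consistent:

```latex
R_0 = \frac{\beta}{\gamma}
\quad\Longrightarrow\quad
\mathbb{E}[\beta] \approx \mathbb{E}[R_0]\,\mathbb{E}[\gamma]
  = 4.091 \times 0.969 \approx 3.964
```

This agrees with the reported posterior mean of 3.960 for `beta`. The small gap is expected, because the mean of a product differs from the product of the means unless the two factors are uncorrelated.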
az.summary(trace.posterior, kind='diagnostics')
          mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
sigma[0]      0.001    0.000    5851.0    4557.0    1.0
sigma[1]      0.001    0.001    5197.0    4512.0    1.0
R0            0.002    0.001    3343.0    3314.0    1.0
gamma         0.001    0.000    3159.0    2826.0    1.0
beta          0.001    0.001    4485.0    4357.0    1.0
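All `r_hat` values equal 1.0 and the effective sample sizes (`ess_bulk`, `ess_tail`) are in the thousands, which suggests no convergence problems. R̂ compares the variance between chains with the variance within them.

The minimal NumPy sketch below implements the classic split-R̂ of Gelman et al. ArviZ's default is a rank-normalized variant, so its numbers will not match `az.summary` exactly. The function name `split_rhat` and the synthetic chains are illustrative assumptions.

```python
import numpy as np

def split_rhat(chains):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1] // 2
    # Split each chain in half -> 2 * n_chains sequences of length n_draws
    halves = np.concatenate(
        [chains[:, :n_draws], chains[:, n_draws:2 * n_draws]], axis=0
    )
    seq_means = halves.mean(axis=1)
    seq_vars = halves.var(axis=1, ddof=1)
    W = seq_vars.mean()                  # mean within-sequence variance
    B = n_draws * seq_means.var(ddof=1)  # between-sequence variance
    # Pooled estimate of the marginal posterior variance
    var_hat = (n_draws - 1) / n_draws * W + B / n_draws
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(42)
mixed = rng.normal(size=(4, 2000))                      # 4 well-mixed synthetic chains
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # last chain shifted by 3
```

`split_rhat(mixed)` stays very close to 1. Shifting a single chain by 3 units inflates the between-chain variance and pushes R̂ far above the commonly used threshold of 1.01.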

Recalling where we started: Sampling from the Prior

with model:
    prior_idata = pm.sample_prior_predictive()
Sampling: [R0, Y, gamma, sigma]
prior_idata

From these samples we must run simulations to reconstruct the prior distribution that they induce on our model's output. To do so, we write a simple function that runs one simulation per parameter draw (using only the draws of the first chain) and collects the results into an Xarray `DataArray`.

import numpy as np
import xarray as xr
from scipy.integrate import odeint

def simul_output(idata, group='posterior'):
    # Select the prior or the posterior group of the InferenceData object
    samples = idata.prior if group == "prior" else idata.posterior
    t = np.arange(0.25, 5, 0.25)  # 19 output time points
    runs = []
    # Iterate over the (beta, gamma) draws of the first chain ([0])
    for b, g in zip(np.array(samples['beta'])[0], np.array(samples['gamma'])[0]):
        # Integrate the SIR function defined earlier in the notebook
        y = odeint(SIR, t=t, y0=[0.99, 0.01], args=((b, g),), rtol=1e-8)
        runs.append(xr.DataArray(y, coords=[range(y.shape[0]), ['S', 'I']],
                                 dims=["time", "state"]))
    # Stack all simulations along a new "reps" dimension in a single call
    return xr.concat(runs, "reps")

sims = simul_output(prior_idata, "prior")
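The function above pushes each prior draw of the parameters through the ODE. The resulting ensemble of trajectories approximates the induced distribution q(φ) of the model output.

The sketch below shows the same idea without PyMC or Xarray, so it runs on its own. The model's real `SIR` function and priors are defined outside this section. Here `sir_rhs` and the lognormal priors are hypothetical stand-ins chosen only for illustration.

```python
import numpy as np
from scipy.integrate import odeint

def sir_rhs(y, t, params):
    """Hypothetical SIR right-hand side on population fractions (S, I)."""
    S, I = y
    beta, gamma = params
    return [-beta * S * I, beta * S * I - gamma * I]

rng = np.random.default_rng(0)
n_reps = 50
# Hypothetical priors (illustration only): gamma and R0 lognormal, beta = R0 * gamma
gamma = rng.lognormal(mean=0.0, sigma=0.2, size=n_reps)
R0 = rng.lognormal(mean=np.log(3.0), sigma=0.3, size=n_reps)
beta = R0 * gamma

t = np.arange(0.25, 5, 0.25)  # same 19-point grid as simul_output
# One ODE solution per prior draw, stacked into shape (reps, time, state)
sims = np.stack([
    odeint(sir_rhs, [0.99, 0.01], t, args=((b, g),), rtol=1e-8)
    for b, g in zip(beta, gamma)
])
```

Collecting the runs in a list and stacking once avoids the repeated copying that growing an array inside the loop would cause.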

Now that we have the 500 simulations, we can compute the median and the 95% credible interval (the 2.5% and 97.5% quantiles across simulations):

sims.quantile([0.025, 0.5, 0.975], 'reps')
<xarray.DataArray (quantile: 3, time: 19, state: 2)>
array([[[ 9.90000000e-01,  1.00000000e-02],
        [ 2.16018439e-03, -3.85432445e-11],
        [ 1.07836902e-03, -3.09197532e-11],
        [ 9.86009775e-04, -5.96713974e-11],
        [ 9.38196955e-04, -3.76586987e-11],
        [ 8.42955232e-04, -5.43293292e-11],
        [ 8.33691614e-04, -4.99809063e-11],
        [ 7.26865718e-04, -4.11994600e-11],
        [ 7.26865718e-04, -4.04094170e-11],
        [ 7.26865718e-04, -4.79877910e-11],
        [ 6.93128072e-04, -1.89150841e-10],
        [ 6.27549461e-04, -2.20812685e-10],
        [ 5.15100662e-04, -2.16863071e-10],
        [ 4.49543152e-04, -3.32939435e-10],
        [ 4.14458419e-04, -3.28770814e-10],
        [ 3.62329718e-04, -1.73921638e-10],
        [ 3.28046993e-04, -2.04740131e-10],
        [ 3.28046993e-04, -1.68541313e-10],
        [ 3.24894294e-04, -2.75348373e-10]],

...

       [[ 9.90000000e-01,  1.00000000e-02],
        [ 9.89757485e-01,  4.43880969e-01],
        [ 9.89511991e-01,  4.79171810e-01],
        [ 9.89263484e-01,  4.62201128e-01],
        [ 9.89011920e-01,  4.91115204e-01],
        [ 9.88757279e-01,  4.42538064e-01],
        [ 9.88499524e-01,  4.49963167e-01],
        [ 9.88238612e-01,  4.71145485e-01],
        [ 9.87974499e-01,  4.83254767e-01],
        [ 9.87707147e-01,  4.86908414e-01],
        [ 9.87436521e-01,  4.83254376e-01],
        [ 9.87162584e-01,  4.60860509e-01],
        [ 9.86885299e-01,  4.53085076e-01],
        [ 9.86604627e-01,  4.60804143e-01],
        [ 9.86320528e-01,  4.47760658e-01],
        [ 9.86032962e-01,  4.42334877e-01],
        [ 9.85741889e-01,  4.22256827e-01],
        [ 9.85447268e-01,  4.13330658e-01],
        [ 9.85149061e-01,  4.13504170e-01]]])
Coordinates:
  * time      (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
  * state     (state) <U1 'S' 'I'
  * quantile  (quantile) float64 0.025 0.5 0.975
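The call above computes pointwise quantiles across the `reps` dimension, separately for every time point and state. A minimal NumPy equivalent on a synthetic ensemble follows. The array `ens` is an illustrative assumption that only mimics the `(reps, time, state)` layout of `sims`.

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic stand-in for an ensemble of simulations: shape (reps, time, state)
ens = rng.normal(loc=0.5, scale=0.1, size=(500, 19, 2))

# 2.5%, 50% and 97.5% quantiles across replicates -> shape (3, time, state)
bands = np.quantile(ens, [0.025, 0.5, 0.975], axis=0)
lower, median, upper = bands

# Fraction of simulated values falling inside the equal-tailed 95% band
coverage = ((ens >= lower) & (ens <= upper)).mean()
```

In the prior bands above, the lower quantile of `I` contains tiny negative values (on the order of -1e-10). These are most likely round-off from the ODE solver. `np.clip(bands, 0, None)` is one way to hide them before plotting.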

All that remains is to plot the curves:

import matplotlib.pyplot as plt

def plot_bands(xarr):
    # Median and 95% band (2.5% and 97.5% quantiles) across the simulations
    bands = xarr.quantile([0.025, 0.5, 0.975], 'reps')
    t = range(19)  # index of the 19 output time points
    for state in ['S', 'I']:
        # Shaded 95% band, then the median curve, for each state
        plt.fill_between(t, bands.sel(quantile=0.025, state=state),
                         bands.sel(quantile=0.975, state=state), alpha=0.3)
        plt.plot(t, bands.sel(quantile=0.5, state=state), label=state)
    plt.legend()

plot_bands(sims)

Sampling from the Posterior Predictive Distribution

Now we do the same for the posterior distribution of the parameters.

with model:
    trace.extend(pm.sample_posterior_predictive(trace, random_seed=3546))
Sampling: [Y]
100.00% [8000/8000 05:52<00:00]
trace
post_sims = simul_output(trace)
post_sims
<xarray.DataArray (reps: 2000, time: 19, state: 2)>
array([[[0.99      , 0.01      ],
        [0.97572711, 0.02071364],
        [0.94710405, 0.04203926],
        ...,
        [0.02993243, 0.11254129],
        [0.0270859 , 0.09089571],
        [0.02498817, 0.07323623]],

       [[0.99      , 0.01      ],
        [0.97586603, 0.02064264],
        [0.94760569, 0.04176796],
        ...,
        [0.02946108, 0.11719789],
        [0.02656658, 0.09498329],
        [0.02443322, 0.07679208]],

       [[0.99      , 0.01      ],
        [0.97551715, 0.02091461],
        [0.94621383, 0.04283333],
        ...,
...
        ...,
        [0.02646676, 0.12550069],
        [0.02369624, 0.10238098],
        [0.02165464, 0.08332676]],

       [[0.99      , 0.01      ],
        [0.97623391, 0.02043848],
        [0.94894813, 0.04098761],
        ...,
        [0.02864484, 0.12945965],
        [0.02560142, 0.10580999],
        [0.02335845, 0.08626397]],

       [[0.99      , 0.01      ],
        [0.9758465 , 0.02073243],
        [0.94742746, 0.04212981],
        ...,
        [0.02753694, 0.12140939],
        [0.02473769, 0.09873996],
        [0.02267466, 0.08011455]]])
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
  * state    (state) <U1 'S' 'I'
Dimensions without coordinates: reps
plot_bands(post_sims)
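The bands from `simul_output(trace)` describe the noise-free model trajectories, that is, the posterior of the model output π(φ). `pm.sample_posterior_predictive` instead draws new observations `Y` and therefore also includes the observation noise of the likelihood.

The sketch below makes that difference explicit under several stated assumptions:

- The likelihood is taken to be lognormal, with one `sigma` per state. The real likelihood is defined outside this section.
- The posterior draws are stand-ins: independent normals around the means and standard deviations reported by `az.summary` above.
- `sir_rhs` is a hypothetical stand-in for the notebook's `SIR`.

```python
import numpy as np
from scipy.integrate import odeint

def sir_rhs(y, t, params):
    """Hypothetical SIR right-hand side on population fractions (S, I)."""
    S, I = y
    beta, gamma = params
    return [-beta * S * I, beta * S * I - gamma * I]

rng = np.random.default_rng(3546)
n = 200
# Stand-in posterior draws (assumption: independent normals, not the real trace)
beta = rng.normal(3.960, 0.072, size=n)
gamma = rng.normal(0.969, 0.034, size=n)
sigma = np.array([0.206, 0.266])  # noise scale for S and I

t = np.arange(0.25, 5, 0.25)
# Noise-free trajectories: a sample from the posterior of the model output
curves = np.stack([odeint(sir_rhs, [0.99, 0.01], t, args=((b, g),), rtol=1e-8)
                   for b, g in zip(beta, gamma)])
# Posterior predictive draws, ASSUMING Y ~ LogNormal(log(curve), sigma);
# clipping guards against tiny negative solver output before taking the log
y_rep = rng.lognormal(mean=np.log(np.clip(curves, 1e-12, None)), sigma=sigma)
```

Bands computed from `y_rep` are much wider than those from `curves`. At the first time point every trajectory equals the initial condition, so all the spread there comes from the observation noise.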
%load_ext watermark
%watermark -n -u -v -iv -w
The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Last updated: Wed Oct 19 2022

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 7.34.0

xarray    : 2022.6.0
pymc      : 4.2.2
arviz     : 0.12.1
aesara    : 2.8.7
matplotlib: 3.5.3
numpy     : 1.23.2

Watermark: 2.3.1