#!/usr/bin/env python
"""System dynamics simulator for the Kermack-McKendrick model.
"""
from __future__ import division
import argparse
import json
import matplotlib.pyplot as plt
from scipy import integrate

# Author metadata, stored as a one-element tuple (note the trailing comma).
__author__ = ("Alexandre Vassalotti <alexandre@peadrop.com>",)

def simulate(days, population, initial_infected, beta, nu):
    """Integrate the SIR equations on a grid of whole days.

    The system handed to the ODE solver is::

        dS/dt = -beta * I * S
        dI/dt =  beta * I * S - nu * I
        dR/dt =  nu * I

    The infection term is not divided by the population size, so ``beta``
    multiplies the raw product ``I * S``.

    Args:
        days: number of sample points; the solution is evaluated at
            t = 0, 1, ..., days - 1.
        population: total number of individuals.
        initial_infected: individuals infected at t = 0. Everyone else
            starts susceptible and nobody starts recovered.
        beta: rate of infection.
        nu: rate of recovery.

    Returns:
        A dict with the keys 'susceptible', 'infected' and 'recovered', each
        mapping to a list of ``days`` builtin floats.
    """
    def sir_model(X, t):
        S, I, R = X
        new_infections = beta * I * S
        recoveries = nu * I
        return (-new_infections,
                new_infections - recoveries,
                recoveries)

    initial_susceptible = population - initial_infected
    # range() rather than xrange(): xrange only exists on Python 2, while
    # range() is accepted by odeint on both Python 2 and Python 3.
    results = integrate.odeint(sir_model,
                               (initial_susceptible, initial_infected, 0),
                               range(days))

    # tolist() yields builtin floats instead of NumPy scalars, so the result
    # can be serialized or compared without any NumPy awareness.
    return {'susceptible': results[:, 0].tolist(),
            'infected': results[:, 1].tolist(),
            'recovered': results[:, 2].tolist()}

def plot(results):
    """Draw the three compartment curves on a new figure and show it.

    ``results`` must provide the keys 'susceptible', 'infected' and
    'recovered'; each series is plotted against its index in blue, red and
    green respectively.
    """
    # (key in results, matplotlib format string, legend label)
    curves = (
        ('susceptible', 'b', 'Susceptible'),
        ('infected', 'r', 'Infected'),
        ('recovered', 'g', 'Recovered'),
    )

    plt.figure()
    for key, fmt, legend_label in curves:
        plt.plot(results[key], fmt, label=legend_label)
    plt.legend()

    plt.xlabel('Time')
    plt.ylabel('Individual Count')
    plt.title('System dynamics simulation')
    plt.show()

def save(results, filename):
    """Write ``results`` to ``filename`` as JSON indented by two spaces.

    The file is opened in text write mode, so any existing content is
    replaced.
    """
    with open(filename, "w") as out_file:
        json.dump(obj=results, fp=out_file, indent=2)

def main(argv=None):
    """Command-line entry point: parse options, simulate, then plot/save.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Invalid numeric options are reported through ``parser.error``, which
    prints a usage message and exits with status 2. When neither ``--plot``
    nor ``--save`` is given, a hint is printed and no simulation is run,
    since its result would be discarded.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--days', default=500, type=int,
        help='number of days to simulate')
    parser.add_argument('--population', metavar='N', default=5000, type=float,
        help='population size')
    parser.add_argument('--infected', metavar='I0', default=5, type=float,
        help='initial number of infected')
    parser.add_argument('--beta', default=0.10, type=float,
        help='rate of infection')
    parser.add_argument('--nu', default=0.01, type=float,
        help='rate of recovery')
    parser.add_argument('--plot', action='store_true',
        help='display a plot at the end of the simulation')
    parser.add_argument('--save', metavar='FILENAME', default=None, type=str,
        help='save statistics to the given file in JSON format')
    args = parser.parse_args(argv)

    # Reject inputs that would otherwise reach the ODE solver with an empty
    # time grid or a negative number of susceptible individuals.
    if args.days < 1:
        parser.error('--days must be at least 1')
    if not 0 <= args.infected <= args.population:
        parser.error('--infected must be between 0 and the population size')

    if not args.plot and not args.save:
        # Parenthesized single-argument form prints the same text on both
        # Python 2 and Python 3.
        print("You need to set the flags --plot or --save to see results.")
        return

    results = simulate(args.days, args.population, args.infected,
                       args.beta, args.nu)

    if args.plot:
        plot(results)
    if args.save:
        save(results, args.save)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()