#!/usr/bin/env python
"""Agent-based simulator of infection spread.
"""
from __future__ import division
import argparse
import bisect
import collections
import itertools
import operator
import json
import random
import sys
import time

import matplotlib.pyplot as plt
import pygame
from pygame.locals import *

__author__ = ("Alexandre Vassalotti <alexandre@peadrop.com>",)

# RGB colour triples used for rendering.
BLACK, WHITE = (0, 0, 0), (255, 255, 255)
RED, GREEN = (255, 0, 0), (0, 255, 0)
BGCOLOR = BLACK

# A status line is written to stderr every INFO_UPDATE_RATE frames.
INFO_UPDATE_RATE = 10
# Populations larger than this are drawn as single pixels instead of circles.
DRAW_POINTS_THRESHOLD = 500
# Window size in pixels; agent coordinates in [0, 1) are scaled by these.
XDIM = 500
YDIM = XDIM

# Simulation parameters; main() assigns them from the command line.
RADIUS = GAMMA = None


class Agent(object):
    """A single individual of the simulated population.

    An agent has a position ``(x, y)`` in the unit square [0, 1) x [0, 1)
    and a ``state`` string: 'susceptible', 'infected' or 'recovered'.
    """

    __slots__ = ('x', 'y', 'state')

    def __init__(self, state='susceptible'):
        # Place the agent uniformly at random (x drawn first, then y).
        self.x, self.y = random.random(), random.random()
        self.state = state

    def __repr__(self):
        fields = (self.x, self.y, self.state)
        return "Agent(x=%f, y=%f, state='%s')" % fields

    def update(self, world):
        """Advance this agent by one day.

        The agent first jumps to a new uniformly random position.  A
        susceptible agent then becomes infected if ``world`` reports an
        infected agent within RADIUS of that position; an infected agent
        recovers with probability GAMMA.  Recovered agents only move.
        """
        self.x, self.y = random.random(), random.random()
        current = self.state
        if current == 'susceptible':
            if world.is_infected_agents(self.x, self.y, RADIUS):
                self.state = 'infected'
            return
        # Short-circuit keeps the random draw (and GAMMA lookup) for
        # infected agents only.
        if current == 'infected' and random.random() < GAMMA:
            self.state = 'recovered'

    def draw(self, screen, draw_point=False):
        """Render the agent on ``screen``, coloured by state.

        With ``draw_point`` a single pixel is set; otherwise a small circle
        of radius 2 is drawn.
        """
        palette = {'susceptible': WHITE, 'infected': RED}
        color = palette.get(self.state, GREEN)
        center = (int(XDIM * self.x), int(YDIM * self.y))
        if not draw_point:
            pygame.draw.circle(screen, color, center, 2)
        else:
            screen.set_at(center, color)


class World(object):
    """The whole population plus a spatial index of the infected agents.

    ``self.agents`` maps a state name ('susceptible', 'infected',
    'recovered') to the list of agents in that state.  The space is the unit
    square with wrap-around edges (a torus).  After construction and after
    every :meth:`update`, the infected agents are sorted by x coordinate and
    their positions are snapshotted, so that :meth:`is_infected_agents` can
    prune candidates with a binary search.
    """

    def __init__(self, population, infected):
        self.agents = collections.defaultdict(list)
        self.population = population
        for _ in range(population - infected):
            self.agents['susceptible'].append(Agent())
        for _ in range(infected):
            self.agents['infected'].append(Agent('infected'))
        self._sort_infected()

    def _sort_infected(self):
        """Rebuild the x-sorted index of the infected agents.

        Positions are copied into ``self._infected_points`` rather than read
        from the agents later.  Agents overwrite their own coordinates while
        :meth:`update` iterates over them, and the infection queries made
        during that pass must all see one consistent snapshot (the
        positions as of the last rebuild).
        """
        infected = self.agents['infected']
        infected.sort(key=operator.attrgetter('x'))
        self._infected_points = [(agent.x, agent.y) for agent in infected]
        # A real list (not a lazy ``map``) so ``bisect`` works on Python 3.
        self.infected_xs = [x for x, _ in self._infected_points]

    def all_agents(self):
        """Return an iterator over every agent, regardless of state."""
        # Unpacking the values up front keeps iteration safe even if a
        # defaultdict lookup elsewhere adds a new (empty) state key.
        return itertools.chain(*self.agents.values())

    def update(self):
        """Advance every agent by one day and regroup agents by state.

        Infection queries made during the pass use the infected positions
        snapshotted at the previous rebuild; agents infected during this
        pass only become contagious after the index is rebuilt at the end.
        """
        agents = collections.defaultdict(list)
        for agent in self.all_agents():
            agent.update(self)
            agents[agent.state].append(agent)
        self.agents = agents
        self._sort_infected()

    def draw(self, screen):
        """Draw all agents; use single pixels for large populations."""
        draw_points = self.population > DRAW_POINTS_THRESHOLD
        for agent in self.all_agents():
            agent.draw(screen, draw_points)

    def is_infected_agents(self, x0, y0, radius):
        """Tell whether an infected agent lies strictly within ``radius``
        of ``(x0, y0)``.

        Distances wrap around the edges of the unit square when needed.
        Candidates are first pruned to an x-window with a binary search over
        the sorted infected x coordinates (sweep and prune), so the query
        scales to large populations.

        Returns True on the first match, False otherwise.
        """
        points = self._infected_points
        if not points:
            return False
        xs = self.infected_xs
        radius_sq = radius ** 2
        lb = x0 - radius
        ub = x0 + radius
        inside_x = lb > 0 and ub < 1.0

        if radius >= 0.5:
            # The window spans the whole (wrapped) x range: every infected
            # agent is a candidate.  The split-window logic below would
            # otherwise miss the middle of the range.
            candidates = points
        elif inside_x:
            start = bisect.bisect_left(xs, lb)
            end = bisect.bisect_right(xs, ub)
            candidates = points[start:end]
        else:
            # The window crosses the x = 0/1 seam: take the two end pieces.
            start = bisect.bisect_left(xs, lb % 1.0)
            end = bisect.bisect_right(xs, ub % 1.0)
            candidates = points[start:] + points[:end]

        if inside_x and y0 - radius > 0 and y0 + radius < 1.0:
            # Fast path: the disc does not touch any edge.
            for px, py in candidates:
                if (px - x0) ** 2 + (py - y0) ** 2 < radius_sq:
                    return True
        else:
            # Use the shorter of the direct and wrapped distance per axis.
            for px, py in candidates:
                dx = abs(px - x0)
                dy = abs(py - y0)
                if min(dx, 1.0 - dx) ** 2 + min(dy, 1.0 - dy) ** 2 < radius_sq:
                    return True
        return False


class WorldStatistics(object):
    """Per-day record of how many agents of ``world`` are in each state.

    Call :meth:`record` once per simulated day; the three ``*_count`` lists
    grow in lockstep.
    """

    def __init__(self, world):
        self.world = world
        self.susceptible_count = []
        self.infected_count = []
        self.recovered_count = []

    def record(self):
        """Append the current size of each state group of ``self.world``.

        A state with no entry counts as 0.  ``.get`` is used so that the
        observer never inserts keys into the world's mapping.
        """
        agents = self.world.agents
        self.susceptible_count.append(len(agents.get('susceptible', ())))
        self.infected_count.append(len(agents.get('infected', ())))
        self.recovered_count.append(len(agents.get('recovered', ())))

    def plot(self):
        """Show the three recorded curves in a blocking matplotlib window."""
        plt.figure()
        plt.plot(self.susceptible_count, 'b', label='Susceptible')
        plt.plot(self.infected_count, 'r', label='Infected')
        plt.plot(self.recovered_count, 'g', label='Recovered')
        plt.legend()
        plt.xlabel('Time')
        plt.ylabel('Individual Count')
        plt.title('Agent-based simulation')
        plt.show()

    def save(self, filename):
        """Write the recorded counts to ``filename`` as indented JSON.

        The object has keys 'susceptible', 'infected' and 'recovered', each
        mapping to a list of per-day counts.
        """
        with open(filename, "w") as f:
            json.dump({'susceptible': self.susceptible_count,
                       'infected': self.infected_count,
                       'recovered': self.recovered_count}, f, indent=2)


def _simulate(screen, world, stats, days):
    """Advance, record and draw ``world`` for at most ``days`` frames.

    Stops early when the window is closed or the Space key is released.
    Every INFO_UPDATE_RATE frames a status line with the day index, the
    measured frame rate and the S/I/R counts is written to stderr.
    """
    frames = 0      # frames timed since the last status line
    elapsed = 0.0   # seconds spent on those frames
    for day in range(days):
        start_time = time.time()

        quit_requested = False
        for e in pygame.event.get():
            if e.type == QUIT or (e.type == KEYUP and e.key == K_SPACE):
                quit_requested = True
        if quit_requested:
            break

        world.update()
        stats.record()

        screen.fill(BGCOLOR)
        world.draw(screen)
        pygame.display.flip()

        elapsed += time.time() - start_time
        frames += 1
        if day % INFO_UPDATE_RATE == 0:
            # Average over the frames actually timed (only one on day 0),
            # and avoid dividing by zero when the clock is coarse.
            fps = frames / elapsed if elapsed > 0 else float('inf')
            info_str = ("%04d %.1f FPS\tS:%d\tI:%d\tR:%d\r" %
                        (day, fps,
                         len(world.agents['susceptible']),
                         len(world.agents['infected']),
                         len(world.agents['recovered'])))

            sys.stderr.write(" " * len(info_str) + "\r")
            sys.stderr.write(info_str)
            sys.stderr.flush()
            frames = 0
            elapsed = 0.0


def main():
    """Run the simulator from the command line.

    Parses the options and opens a pygame window.  The world is advanced one
    day per frame for at most ``--days`` days, stopping early if the window
    is closed or Space is released.  Afterwards the recorded S/I/R counts
    are optionally plotted and/or saved.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--days', default=500, type=int,
        help='number of days to simulate')
    parser.add_argument('--population', metavar='N', default=5000, type=int,
        help='population size')
    parser.add_argument('--infected', metavar='I0', default=5, type=int,
        help='initial number of infected')
    parser.add_argument('--radius', default=0.003, type=float,
        help='radius of infection')
    parser.add_argument('--gamma', default=0.01, type=float,
        help='rate of recovery')
    parser.add_argument('--plot', action='store_true',
        help='display a plot at the end of the simulation')
    parser.add_argument('--save', metavar='FILENAME', default=None, type=str,
        help='save statistics to the given file in JSON format')
    args = parser.parse_args()
    if args.population < 0:
        parser.error('--population must be non-negative')
    if not 0 <= args.infected <= args.population:
        parser.error('--infected must be between 0 and --population')

    # Agent.update() reads these module-level settings.
    global GAMMA, RADIUS
    GAMMA = args.gamma
    RADIUS = args.radius

    pygame.init()
    try:
        screen = pygame.display.set_mode((XDIM, YDIM))
        pygame.display.set_caption('Pandemic')

        world = World(args.population, args.infected)
        stats = WorldStatistics(world)

        sys.stderr.write("Press 'Space' to end the simulation.\n")
        _simulate(screen, world, stats, args.days)
    finally:
        # Close the window before plotting.  Otherwise it stays open and
        # unresponsive while plt.show() blocks.
        pygame.quit()

    if args.plot:
        stats.plot()
    if args.save:
        stats.save(args.save)

    sys.stderr.write("\nGood Bye!\n")


if __name__ == "__main__":
    # Start the simulator only when run as a script, not when imported.
    main()