import math
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation, rc

WIDTH, HEIGHT = 300, 300

class Ball(plt.Circle):
    """A circular matplotlib patch that also carries a velocity.

    Any argument left as ``None`` is randomized:

    * ``radius`` -- integer in [4, 10)
    * ``xy`` -- integer centre in [radius, WIDTH - radius) x
      [radius, HEIGHT - radius)
    * ``v`` -- each component uniform in [-5, 5)
    * ``color`` -- random RGB triple with components in [0, 1)

    ``xy``, ``radius`` and ``color`` (as ``facecolor``) are passed to
    ``plt.Circle``; the velocity is stored on ``self.v``.
    """

    def __init__(self, radius=None, xy=None,
                 v=None, color=None):
        # Identity checks, not ``== None``: comparing a NumPy array to None
        # is element-wise, and using the resulting array in ``if`` raises.
        if radius is None:
            radius = np.random.randint(4, 10)
        if xy is None:
            x = np.random.randint(radius, WIDTH - radius)
            y = np.random.randint(radius, HEIGHT - radius)
            xy = (x, y)
        if v is None:
            dx = -5 + np.random.random() * 10
            dy = -5 + np.random.random() * 10
            v = (dx, dy)
        # Velocity (vx, vy): displacement added to the centre on every
        # frame by move(); replaced by the collision code on impact.
        self.v = v
        if color is None:
            color = np.random.random(3)
        super().__init__(xy, radius=radius, facecolor=color)

def update_v_if_collide(b1, b2):
    """Bounce two balls off each other if they touch and are approaching.

    Reads ``center``, ``radius`` and ``v`` from both balls. When the discs
    touch or overlap *and* are moving towards each other, ``b1.v`` and
    ``b2.v`` are replaced by the post-collision velocities from
    ``update_v``. Positions are never modified.

    Args:
        b1, b2: objects exposing ``center`` as (x, y), ``radius`` and ``v``
            as (vx, vy) -- e.g. ``Ball`` instances.
    """
    x1, y1 = b1.center
    x2, y2 = b2.center
    r1 = b1.radius
    r2 = b2.radius
    vx1, vy1 = b1.v
    vx2, vy2 = b2.v
    dx = x1 - x2
    dy = y1 - y2
    # Squared centre distance minus squared contact distance:
    # > 0 apart, == 0 exactly touching, < 0 overlapping.
    t = dx**2 + dy**2 - (r1 + r2)**2
    if t > 0:
        return
    # Relative velocity projected on the line of centres. A non-negative
    # value means the pair is already separating -- typically because the
    # bounce was applied on an earlier frame and they have not yet moved
    # apart. Bouncing again would drive them back together and make them
    # stick, so skip it.
    if (vx1 - vx2) * dx + (vy1 - vy2) * dy >= 0:
        return
    if t < 0:
        # Overlapping: step both balls back one frame so the collision
        # normal is computed closer to the moment of first contact.
        x1 -= vx1; y1 -= vy1
        x2 -= vx2; y2 -= vy2
    b1.v, b2.v = update_v(x1, y1, r1, vx1, vy1,
                          x2, y2, r2, vx2, vy2)

def update_v(x1, y1, r1, vx1, vy1, x2, y2, r2, vx2, vy2):
    """Return the velocities of two discs after a 2-D elastic collision.

    Uses the vector form of the two-dimensional elastic collision
    equations (https://en.wikipedia.org/wiki/Elastic_collision)::

        v1' = v1 - 2*m2/(m1+m2) * <v1-v2, x1-x2> / |x1-x2|**2 * (x1-x2)
        v2' = v2 - 2*m1/(m1+m2) * <v2-v1, x2-x1> / |x2-x1|**2 * (x2-x1)

    Each disc's mass is taken proportional to its area (``r**2``), i.e.
    uniform density, so bigger balls are heavier. Only the velocity
    components along the line of centres change; tangential components are
    preserved. Momentum and kinetic energy are conserved.

    Args:
        x1, y1, r1, vx1, vy1: centre, radius and velocity of the first disc.
        x2, y2, r2, vx2, vy2: centre, radius and velocity of the second disc.

    Returns:
        ``((vx1', vy1'), (vx2', vy2'))``. The input velocities are returned
        unchanged when the centres coincide (no defined collision normal)
        or both radii are zero (no mass), which would otherwise divide by
        zero.
    """
    m1 = r1 * r1
    m2 = r2 * r2
    total_mass = m1 + m2
    nx = x1 - x2
    ny = y1 - y2
    dist_sq = nx * nx + ny * ny
    if dist_sq == 0 or total_mass == 0:
        return (vx1, vy1), (vx2, vy2)
    # <v2-v1, x2-x1> == <v1-v2, x1-x2>, so one projection serves both balls.
    proj = ((vx1 - vx2) * nx + (vy1 - vy2) * ny) / dist_sq
    k1 = 2 * m2 / total_mass * proj
    k2 = 2 * m1 / total_mass * proj
    vx1 -= k1 * nx
    vy1 -= k1 * ny
    # Ball 2's normal is (x2 - x1) = -(nx, ny), hence the flipped sign.
    vx2 += k2 * nx
    vy2 += k2 * ny
    return (vx1, vy1), (vx2, vy2)

def _reflect_axis(pos, vel, radius, limit):
    """Step one coordinate by *vel*, bouncing off walls at 0 and *limit*.

    If the stepped disc is not within ``[radius, limit - radius]``, the
    velocity is reversed, the reversed velocity is added (moving the disc
    back), and the position is clamped into the box. Returns ``(pos, vel)``.
    """
    pos += vel
    if not (radius <= pos <= limit - radius):
        vel = -vel
        pos += vel
        if pos - radius < 0:
            pos = radius
        if pos + radius > limit:
            pos = limit - radius
    return pos, vel


def move(ball):
    """Advance *ball* one frame inside the ``WIDTH`` x ``HEIGHT`` box.

    Adds ``ball.v`` to ``ball.center``. On each axis where the disc would
    leave the box, that velocity component is reversed and the position is
    pulled back inside. Results are written to ``ball.center`` and
    ``ball.v``.
    """
    cx, cy = ball.center
    vel_x, vel_y = ball.v
    rad = ball.radius
    cx, vel_x = _reflect_axis(cx, vel_x, rad, WIDTH)
    cy, vel_y = _reflect_axis(cy, vel_y, rad, HEIGHT)
    ball.center = (cx, cy)
    ball.v = (vel_x, vel_y)

def animate(i):
    """Advance the simulation by one frame (``FuncAnimation`` callback).

    Moves every ball, then checks each unordered pair of balls once for a
    collision.

    Args:
        i: frame index supplied by ``FuncAnimation``; not used.

    Returns:
        The module-level ``balls`` list: the artists to redraw, which
        ``FuncAnimation`` needs when ``blit=True``.
    """
    for ball in balls:
        move(ball)
    # Visit every unordered pair exactly once (first before second).
    for k, first in enumerate(balls):
        for second in balls[k + 1:]:
            update_v_if_collide(first, second)
    return balls

#--------------------------------------------------
# Script section (runs at import time): build the figure, add the balls
# and start the animation.
fig, ax = plt.subplots()
fig.set_size_inches(4, 4)
ax.set_xlim((0, WIDTH))
ax.set_ylim((0, HEIGHT))
# Patches are drawn in data coordinates; force equal x/y scaling so the
# balls render as circles rather than slightly squashed ellipses.
ax.set_aspect('equal')

balls = []
n = 9
# A row of n randomly sized balls, evenly spaced, 70 units below the top,
# all drifting slowly downwards.
for i in range(n):
    b = Ball(xy=((i + 1) * WIDTH / (n + 1), HEIGHT - 70), v=(0, -1))
    balls.append(b)
    ax.add_patch(b)
# One large red ball near the bottom, moving up and slightly to the right.
b = Ball(xy=(WIDTH / 2, 40), radius=30, v=(1, 5), color='red')
balls.append(b)
ax.add_patch(b)
# frames=None runs forever, so caching frame data would grow without bound;
# disable it explicitly (recent matplotlib versions also warn about this).
# The Animation must stay referenced (here via `anim`), otherwise it is
# garbage-collected and the animation stops.
anim = animation.FuncAnimation(fig, animate,
                               frames=None,
                               interval=30,
                               blit=True,
                               cache_frame_data=False)
plt.show()
