import numpy as np
import matplotlib.pylab as plt


def random_point(num_points,num_dim,total_size,starting_value):
    """
    Sample points uniformly at random from an axis-aligned box.

    Every coordinate of every point is drawn independently from
    [starting_value, starting_value + total_size).

    Inputs:
    num_points:      number of points to generate (e.g. 10)
    num_dim:         dimension of each point (length of the x vector)
    total_size:      width of the domain along each axis
                     (for a domain of -5 to 5 this is 10)
    starting_value:  lower bound of the domain along each axis
                     (for the same example this is -5)

    returns:
    array of shape (num_points, num_dim)
    """
    # draw samples on the unit interval, then stretch and shift them onto the domain
    unit_samples = np.random.rand(num_points, num_dim)
    return unit_samples * total_size + starting_value



def strong_wolfe(func, grad_func, x, pk, c1, c2, alpha=1, alpha_max=100, max_iters=3000, verbose=False):
    """
    Strong Wolfe condition line search method (bracketing phase).

    Follows the structure of Nocedal & Wright, Algorithm 3.5. The trial step
    is doubled until either an acceptable step is found or an interval that
    must contain one is bracketed. In the bracketed case, ``zoom`` refines
    the interval.

    Inputs:
    func:           objective function
    grad_func:      gradient of the objective function
    x:              design variables
    pk:             search direction at the kth iteration (expected to be a
                    descent direction, i.e. grad_func(x) . pk < 0)
    c1:             sufficient decrease (Armijo) parameter
    c2:             curvature parameter (normally c1 < c2 < 1)
    alpha:          initial estimate for step length
    alpha_max:      max value of alpha (this value itself is still tried)
    max_iters:      max number of bracketing iterations; also forwarded to zoom
    verbose:        print progress messages

    returns:
    alpha:          step length satisfying the strong Wolfe conditions, or
                    None when alpha reaches alpha_max without finding or
                    bracketing an acceptable step. If max_iters runs out
                    first, the last (untested) trial step is returned.
    """

    # function value and gradient at alpha = 0
    fk = func(x)
    gk = grad_func(x)

    # derivative of the merit function phi(alpha) = f(x + alpha*pk) at alpha = 0
    pdot_gk = np.dot(gk, pk)

    # values at the previous trial step (alpha = 0 initially)
    alpha_old = 0.0
    pdot_gj_old = pdot_gk
    fj_old = fk

    for j in range(max_iters):
        # evaluate the merit function and its derivative at the trial step
        x_trial = x + alpha*pk
        fj = func(x_trial)
        gj = grad_func(x_trial)
        pdot_gj = np.dot(gj, pk)

        # sufficient decrease violated, or no improvement over the previous
        # step: an acceptable step lies between alpha_old and alpha
        if fj > fk + c1*alpha*pdot_gk or (j > 0 and fj > fj_old):
            if verbose:
                print('Sufficient decrease conditions violated: interval found')
            return zoom(func, grad_func, fj_old, pdot_gj_old, alpha_old,
                        fj, pdot_gj, alpha, x, fk, gk, pk, c1, c2, max_iters,
                        verbose=verbose)

        # strong Wolfe conditions satisfied directly
        if np.fabs(pdot_gj) <= c2*np.fabs(pdot_gk):
            if verbose:
                print('Strong Wolfe alpha found directly')
            return alpha

        # slope turned non-negative: a minimizer lies between alpha_old and
        # alpha. Here alpha is the better end (it satisfies sufficient
        # decrease with the lower merit value), so it is passed as the "low"
        # end of the bracket (Nocedal & Wright: zoom(alpha_i, alpha_{i-1})).
        if pdot_gj >= 0.0:
            if verbose:
                print('Slope condition violated; interval found')
            return zoom(func, grad_func, fj, pdot_gj, alpha,
                        fj_old, pdot_gj_old, alpha_old, x, fk, gk, pk, c1, c2, max_iters,
                        verbose=verbose)

        # record the values at this step for the next iteration
        fj_old = fj
        pdot_gj_old = pdot_gj
        alpha_old = alpha

        # alpha_max has been tried and still nothing was bracketed: give up.
        # The return is deliberately outside the verbose check.
        if alpha >= alpha_max:
            if verbose:
                print('Line search failed here')
            return None

        # expand the step, but never beyond alpha_max
        alpha = min(2*alpha, alpha_max)

    if verbose:
        print('Line Search Unsuccessful')
    return alpha



def zoom(func, grad_func, f_low, pdot_low, alpha_low, f_high, pdot_high, alpha_high,
         x, fk, gk, pk, c1, c2, max_iters, verbose=False):
    """
    Zoom phase of the strong Wolfe line search (Nocedal & Wright, Algorithm 3.6).

    Repeatedly bisects the bracket between alpha_low and alpha_high until a
    step satisfying the strong Wolfe conditions is found.

    Inputs:
    func, grad_func:   objective function and its gradient
    f_low, pdot_low,
    alpha_low:         merit value, directional derivative and step at the
                       "low" end of the bracket (best step so far that
                       satisfies sufficient decrease)
    f_high, pdot_high,
    alpha_high:        the same quantities at the other end of the bracket.
                       alpha_high may be smaller than alpha_low.
    x, fk, gk, pk:     base point, f(x), grad f(x) and search direction
    c1, c2:            sufficient decrease / curvature parameters
    max_iters:         max number of bisection iterations
    verbose:           print progress messages

    Note: bisection only needs alpha_low, alpha_high and f_low. f_high,
    pdot_high and pdot_low are accepted for interface compatibility.

    returns:
    alpha_j:           step satisfying the strong Wolfe conditions, or the
                       last trial step if max_iters is exhausted
    """

    # derivative of the merit function at alpha = 0
    pdot_gk = np.dot(pk, gk)

    # defined up front so the function still returns a step if max_iters == 0
    alpha_j = 0.5*(alpha_high + alpha_low)

    for _ in range(max_iters):

        # pick an alpha value that bisects the interval
        alpha_j = 0.5*(alpha_high + alpha_low)

        # evaluate the merit function
        x_trial = x + alpha_j*pk
        fj = func(x_trial)

        # sufficient decrease violated, or no better than the low end:
        # alpha_j becomes the new high end
        if fj > fk + c1*alpha_j*pdot_gk or fj >= f_low:
            if verbose:
                print('Zoom: Sufficient decrease conditions violated')
            alpha_high = alpha_j
            continue

        # derivative of the merit function at alpha_j
        gj = grad_func(x_trial)
        pdot_gj = np.dot(gj, pk)

        # the strong Wolfe conditions are satisfied
        if np.fabs(pdot_gj) <= c2*np.fabs(pdot_gk):
            if verbose:
                print('Zoom: Wolfe conditions satisfied')
            return alpha_j

        if verbose:
            print('Zoom: Curvature condition violated')

        # if the slope at alpha_j points toward alpha_high, the minimizer
        # lies on the alpha_low side, so the old low end becomes the high end
        if pdot_gj*(alpha_high - alpha_low) >= 0.0:
            alpha_high = alpha_low

        # alpha_j is the new best point
        alpha_low = alpha_j
        f_low = fj

    return alpha_j



def himmelblau():
    """
    Return Himmelblau's function together with its gradient and Hessian.

    f(x1, x2) = (x1**2 + x2 - 11)**2 + (x1 + x2**2 - 7)**2

    Each returned callable takes a vector X and reads X[0], X[1].

    returns:
    F:    objective value (scalar)
    dF:   gradient, array of shape (2,)
    ddF:  Hessian, array of shape (2, 2)
    """
    def F(X):
        x1, x2 = X[0], X[1]
        first = x1**2 + x2 - 11
        second = x1 + x2**2 - 7
        return first**2 + second**2

    def dF(X):
        x1, x2 = X[0], X[1]
        first = x1**2 + x2 - 11
        second = x1 + x2**2 - 7
        return np.array([4*x1*first + 2*second,
                         2*first + 4*x2*second])

    def ddF(X):
        x1, x2 = X[0], X[1]
        cross = 4*x1 + 4*x2
        return np.array([[12*x1**2 + 4*x2 - 42, cross],
                         [cross, 12*x2**2 + 4*x1 - 26]])

    return F, dF, ddF

def rosenbrock():
    """
    Return the 2-D Rosenbrock function together with its gradient and Hessian.

    f(x1, x2) = (1 - x1)**2 + 100*(x2 - x1**2)**2

    Each returned callable takes a vector X and reads X[0], X[1].

    returns:
    F:    objective value (scalar)
    dF:   gradient, array of shape (2,)
    ddF:  Hessian, array of shape (2, 2)
    """
    def F(X):
        x1, x2 = X[0], X[1]
        return (1 - x1)**2 + 100*(x2 - x1**2)**2

    def dF(X):
        x1, x2 = X[0], X[1]
        valley = x2 - x1**2
        return np.array([2*(x1 - 1) - 400*valley*x1,
                         200*valley])

    def ddF(X):
        x1, x2 = X[0], X[1]
        off_diag = -400*x1
        return np.array([[2 - 400*x2 + 1200*x1**2, off_diag],
                         [off_diag, 200]])

    return F, dF, ddF


class BFGS:
    """
    BFGS quasi-Newton approximation.

    Maintains both a Hessian approximation ``B`` and an inverse-Hessian
    approximation ``D`` (each n x n). Both are ``None`` until the first
    update that satisfies the curvature condition. At that point ``D`` is
    initialised to (y.s / y.y) * I and ``B`` to its inverse (y.y / y.s) * I,
    so the two approximations stay inverses of one another through
    subsequent updates.
    """

    def __init__(self, n):
        self.n = n
        self.D = None  # inverse-Hessian approximation (n x n) or None
        self.B = None  # Hessian approximation (n x n) or None

    def update(self, s, y):
        """
        Perform the BFGS update.

        Inputs:
        s:  step taken (presumably x_{k+1} - x_k)
        y:  change in gradient (presumably g_{k+1} - g_k)

        The update is skipped when y.s is not strictly positive. Such a pair
        would give rho = 1/(y.s) that is infinite or negative, destroying
        positive definiteness.
        """
        ys = np.dot(y, s)

        # curvature condition; "not >" also rejects NaN
        if not ys > 0.0:
            return

        rho = 1.0 / ys

        # initial inverse-Hessian scaling on the first accepted pair
        if self.D is None:
            self.D = (ys / np.dot(y, y)) * np.eye(self.n)

        # inverse update: D <- W^T D W + rho s s^T, with W = I - rho y s^T
        W = np.eye(self.n) - rho*np.outer(y, s)
        self.D = np.dot(W.T, np.dot(self.D, W)) + rho*np.outer(s, s)

        # initial Hessian scaling: the inverse of the initial D above
        if self.B is None:
            self.B = (np.dot(y, y) / ys) * np.eye(self.n)

        # direct update: B <- B - (B s)(B s)^T / (s^T B s) + rho y y^T
        r = np.dot(self.B, s)
        beta = 1.0 / np.dot(s, r)
        self.B = self.B - beta*np.outer(r, r) + rho*np.outer(y, y)
    
class SR1:
    """
    Symmetric-rank-one (SR1) quasi-Newton approximation.

    Keeps a Hessian approximation ``B`` and an inverse-Hessian approximation
    ``D``, both starting as the n x n identity. Each is updated
    independently and may skip an update on its own.
    """

    def __init__(self, n, tol=1e-12):
        self.n = n
        self.D = np.eye(n)  # inverse-Hessian approximation
        self.B = np.eye(n)  # Hessian approximation
        self.tol = tol      # relative threshold below which an update is skipped

    def update(self, s, y):
        """
        Apply one SR1 correction to B and to D, in place.

        Inputs:
        s:  step taken (presumably x_{k+1} - x_k)
        y:  change in gradient (presumably g_{k+1} - g_k)

        A correction is skipped when |denominator| / (s.s) does not exceed
        ``tol``. This guards against division by a (near-)zero denominator.
        """
        step_sq = np.dot(s, s)

        # Hessian approximation: B += u u^T / (u.s), with u = y - B s
        u = y - self.B.dot(s)
        denom_b = np.dot(u, s)
        if self.tol < np.fabs(denom_b) / step_sq:
            self.B += np.outer(u, u) / denom_b

        # inverse approximation: D += v v^T / (v.y), with v = s - D y
        # NOTE(review): this skip test is also normalized by s.s; the dual
        # form would use y.y — confirm this is intended.
        v = s - self.D.dot(y)
        denom_d = np.dot(v, y)
        if self.tol < np.fabs(denom_d) / step_sq:
            self.D += np.outer(v, v) / denom_d