-- Basic particle swarm optimization (PSO) algorithm
-- Bradford N. Barr - 2010
--
-- PSO Algorithm: 
--
-- Initialize the particles
--      Create a list of particles
--      Set their positions and velocities to 0 or to random values
-- 
-- For each particle
--      Calculate the fitness value for each
--      If the fitness value is better than the best
--          Save the new fitness as best
--
-- For each particle
--      Find in the particle neighborhood the particle with the best fitness
--      (this implementation uses the whole swarm as the neighborhood)
--      Calculate particle velocity
--      Apply velocity constraints (not implemented)
--      Update particle position
--      Apply position constraints (not implemented)

-- Create the initial swarm.
--   num  : number of particles to create
--   dim  : number of dimensions of each particle
--   rand : when non-nil, every position and velocity component is drawn
--          from random(); otherwise every component starts at 0
-- Each particle is a table { p = position[], v = velocity[], f = fitness },
-- where f starts at the placeholder 999999 until calc_fitness evaluates it.
function init_particles(num, dim, rand)
    local swarm = {}
    for n=1, num do
        local pos, vel = {}, {}
        for d=1, dim do
            -- Per dimension, the position draw happens before the velocity
            -- draw, so a given seed always yields the same swarm.
            local p0, v0 = 0, 0
            if rand then
                p0 = random()
                v0 = random()
            end
            tinsert(pos, p0)
            tinsert(vel, v0)
        end
        tinsert(swarm, { p = pos, v = vel, f = 999999 })
    end
    return swarm
end

-- Return an independent snapshot of `particle`: fresh position and velocity
-- arrays plus its fitness. Later changes to `particle` (a new fitness value,
-- replaced p/v arrays) do not affect the snapshot.
function copy_particle (particle)
    local copy = { p = {}, v = {}, f = particle.f }
    for j=1, getn(particle.p) do
        tinsert(copy.p, particle.p[j])
        tinsert(copy.v, particle.v[j])
    end
    return copy
end

-- Evaluate every particle and track the best positions seen.
--   particles    : array of particles { p = {...}, v = {...}, f = number }
--   fitness_func : function(particle) -> number; lower is better
--   pbest        : best snapshot found by earlier calls, or nil if none yet
-- Stores each particle's fitness in particle.f and returns two values:
--   pbest : a snapshot (copy) of the best particle seen across all calls
--   gbest : the best particle of this evaluation; a reference into
--           `particles`, not a copy (nil when `particles` is empty)
function calc_fitness (particles, fitness_func, pbest)
    local num = getn(particles)
    local gbest = particles[1]

    -- If the caller's best is one of the particles themselves, snapshot it
    -- first. Its fitness is overwritten below and its position is replaced
    -- by update_particles, which would otherwise corrupt the recorded best.
    if pbest then
        for i=1, num do
            if particles[i] == pbest then
                pbest = copy_particle(pbest)
                break
            end
        end
    end

    for i=1, num do
        local particle = particles[i]
        particle.f = fitness_func(particle)

        -- Keep a copy, never a reference, for the same reason as above.
        if not pbest or particle.f < pbest.f then
            pbest = copy_particle(particle)
        end

        if particle.f < gbest.f then
            gbest = particle
        end
    end

    return pbest, gbest
end

-- Compute a particle's next velocity, one component per dimension d:
--   v'[d] = v[d] + (pphi*r1) * (pbest.p[d] - p[d]) + (gphi*r2) * (gbest.p[d] - p[d])
-- where r1 and r2 are fresh random() draws for each dimension (r1 first).
-- Returns a new array; the particle itself is not modified.
function update_velocity (particle, pbest, gbest, pphi, gphi)
    local pos = particle.p
    local vel = particle.v
    local result = {}
    local count = getn(vel)
    for d=1, count do
        local pweight = pphi*random()
        local gweight = gphi*random()
        local toward_p = pbest.p[d] - pos[d]
        local toward_g = gbest.p[d] - pos[d]
        tinsert(result, vel[d] + (pweight * toward_p) + (gweight * toward_g))
    end
    return result
end

-- Advance a particle by one step of its current velocity.
-- Returns a new position array with components p[d] + v[d]; the particle
-- itself is left unchanged.
function update_position (particle)
    local here, step = particle.p, particle.v
    local moved = {}
    local count = getn(here)
    for d=1, count do
        local next_coord = here[d] + step[d]
        tinsert(moved, next_coord)
    end
    return moved
end

-- Move every particle one step: replace its velocity (update_velocity), then
-- replace its position using that new velocity (update_position).
--   particles  : array of particles { p = {...}, v = {...}, f = number }
--   pbest      : best snapshot seen so far (only its position is used)
--   gbest      : best particle of the latest evaluation (only its position
--                is used); may be a member of `particles`
--   pphi, gphi : weights of the pull toward pbest and gbest
-- Does nothing when `particles` is empty.
function update_particles (particles, pbest, gbest, pphi, gphi)
    local count = getn(particles)
    if count == 0 then
        return
    end

    -- Freeze the attractor positions before moving anyone. gbest (and
    -- possibly pbest) can be a member of `particles`. If it moved mid-loop,
    -- the particles after it would be pulled toward its new, never-evaluated
    -- position instead of the best position actually found.
    local pbest_fixed = { p = {}, f = pbest.f }
    for j=1, getn(pbest.p) do
        tinsert(pbest_fixed.p, pbest.p[j])
    end
    local gbest_fixed = { p = {}, f = gbest.f }
    for j=1, getn(gbest.p) do
        tinsert(gbest_fixed.p, gbest.p[j])
    end

    for i=1, count do
        local particle = particles[i]
        particle.v = update_velocity(particle, pbest_fixed, gbest_fixed, pphi, gphi)
        particle.p = update_position(particle)
    end
end

-- Run particle swarm optimisation and return the best particle found.
--   param : table of optional settings (defaults in parentheses)
--     iterations (100) maximum number of iterations
--     success    (.1)  stop once the best fitness is <= this value
--     seed             random seed; if nil, one is drawn from random(10000)
--     rand       (1)   randomise initial positions and velocities; set it
--                      to 0 (or false) to start every particle at 0
--     pphi       (2)   weight of the pull toward the best position so far
--     gphi       (2)   weight of the pull toward the iteration's best particle
--     num        (10)  number of particles
--     dim        (2)   number of dimensions
--   fitness_func    : function(particle) -> number; lower is better
--   print_particles : optional callback(particles, iterations_left,
--                     max_iterations), called after every iteration
-- Returns a snapshot { p, v, f } of the best particle seen.
function run_pso (param, fitness_func, print_particles)
    -- NOTE(review): without param.seed the seed comes from the generator's
    -- current, unseeded state, which may be identical on every run.
    randomseed(param.seed or random(10000))

    local max_iterations = param.iterations or 100
    local iterations = max_iterations
    local success = param.success or .1
    local pphi = param.pphi or 2
    local gphi = param.gphi or 2
    local num = param.num or 10
    local dim = param.dim or 2

    -- Random initialisation is the default; 0 or false disables it.
    local rand = param.rand
    if rand == nil then
        rand = 1
    end
    if rand == 0 then
        rand = nil
    end

    local particles = init_particles(num, dim, rand)

    -- Hold the best-so-far as a separate copy of the first particle. A plain
    -- reference would let calc_fitness overwrite its fitness and
    -- update_particles move it, corrupting the recorded best.
    local first = particles[1]
    local pbest = { p = {}, v = {}, f = first.f }
    for j=1, getn(first.p) do
        tinsert(pbest.p, first.p[j])
        tinsert(pbest.v, first.v[j])
    end
    local gbest = first

    while iterations > 0 and pbest.f > success do
        pbest, gbest = calc_fitness(particles, fitness_func, pbest)
        update_particles(particles, pbest, gbest, pphi, gphi)
        if print_particles then
            -- Pass the resolved budget: param.iterations may be nil.
            print_particles(particles, iterations, max_iterations)
        end
        iterations = iterations - 1
    end

    return pbest
end

-- Dump the swarm state for one iteration and append gnuplot commands that
-- render it.
--   particles      : array of particles { p = {...}, v = {...}, f = number }
--   iterations     : iterations left, as counted down by run_pso
--   max_iterations : total iteration budget (used to number the output)
-- Writes "iter<N>.dat" in the current directory with one line per particle:
-- the position components, then the velocity components, then the fitness.
-- Appends to "plot.plt" the commands for a 3D plot (first two position
-- components against fitness) and a 2D plot (first two position components).
-- The plot settings are written only on the first iteration.
-- Raises an error if either file cannot be opened.
-- NOTE: the gnuplot commands refer to the .dat files and the PNG outputs under
-- the hard-coded directory ~/Code/lua/ai/pso/. They also mark a target point
-- at (5, 8). Both values are specific to the original experiment.
function print_particles (particles, iterations, max_iterations)
    local current_iteration = max_iterations - iterations + 1
    local num_particles = getn(particles)
    local num_dimensions = getn(particles[1].p)

    local dat_str = ''
    for i=1, num_particles do
        for j=1, num_dimensions do
            dat_str = dat_str .. particles[i].p[j] .. ' '
        end
        for j=1, num_dimensions do
            dat_str = dat_str .. particles[i].v[j] .. ' '
        end
        dat_str = dat_str .. particles[i].f .. '\n'
    end
    local dat_filename = "iter" .. current_iteration .. ".dat"
    local dat_file, dat_err = openfile(dat_filename, 'w')
    if not dat_file then
        error("print_particles: cannot open " .. dat_filename .. ": "
              .. (dat_err or "unknown error"))
    end
    write(dat_file, dat_str)
    closefile(dat_file)

    -- Plot settings are emitted only on the first iteration. Later
    -- iterations append to the same script, presumably relying on gnuplot
    -- keeping these settings for the rest of the script.
    local plot_str = ''
    if iterations == max_iterations then
        plot_str = plot_str .. "set xrange [-1:15]\n"
        plot_str = plot_str .. "set yrange [-1:15]\n"
        plot_str = plot_str .. "set zrange [0:10]\n"
        plot_str = plot_str .. "set ticslevel 0\n"
        plot_str = plot_str .. "set dgrid3d 30,30\n"
        plot_str = plot_str .. "set terminal png\n"
    end
    -- 3D plot: columns 1 and 2 against the fitness column, which is the
    -- last column (2 * dimensions + 1) of each .dat row.
    plot_str = plot_str .. "set output \"~/Code/lua/ai/pso/iter"
    plot_str = plot_str .. current_iteration .. "-3d.png\"\n"
    plot_str = plot_str .. "splot \"~/Code/lua/ai/pso/" .. dat_filename
    plot_str = plot_str .. "\" using 1:2:"
    plot_str = plot_str .. (2 * num_dimensions + 1)
    plot_str = plot_str .. " with lines ti \"Iteration ".. current_iteration
    plot_str = plot_str .. "\", \"< echo '5 8 0'\" ti \"Target\"\n"
    -- 2D plot: columns 1 and 2 only; the title now matches the 3D plot.
    plot_str = plot_str .. "set output \"~/Code/lua/ai/pso/iter"
    plot_str = plot_str .. current_iteration .. "-2d.png\"\n"
    plot_str = plot_str .. "plot \"~/Code/lua/ai/pso/" .. dat_filename
    plot_str = plot_str .. "\" using 1:2 ti \"Iteration ".. current_iteration
    plot_str = plot_str .. "\", \"< echo '5 8'\" ti \"Target\"\n"

    local plot_filename = 'plot.plt'
    local plot_file, plot_err = openfile(plot_filename, 'a+')
    if not plot_file then
        error("print_particles: cannot open " .. plot_filename .. ": "
              .. (plot_err or "unknown error"))
    end
    write(plot_file, plot_str)
    closefile(plot_file)
end

--local fitness_func = function (particle)
--    local point = {5, 8}
--    local fitness = 0.0
--    for i=1, getn(point) do
--        fitness = fitness + (point[i] - particle.p[i])^2
--    end
--    return sqrt(fitness)
--end
--
--param = {
--    iterations = 10000,
--    success = .01,
--    seed = 1024,
--    rand = 1,
--    pphi = .01,
--    gphi = .01,
--    num = 5,
--    dim = 2,
--}

--best = run_pso(param, fitness_func, print_particles)

--best_str = 'BEST: \n{'
--best_str = best_str..' p{'..best.p[1]..', '..best.p[2]..'}'
--best_str = best_str..' v{'..best.v[1]..', '..best.v[2]..'}'
--best_str = best_str..' f:'..best.f..'\n}'
--print(best_str)
