# dystopian_waltz.py
#
# Copyright (C) 2013, 2014 Guillaume Tucker <guillaume@mangoz.org>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import argparse
import random
import splat.data
import splat.gen
import splat.filters
import splat.interpol
import splat.scales
from splat import lin2dB, dB2lin

# Version of the splat library this script was written against.  A mismatch
# is only reported; execution continues regardless.
SPLAT_VER = '1.1'
if splat.__version__ != SPLAT_VER:
    print("WARNING - Splat version mismatch: {}, expected: {}"
          .format(splat.__version__, SPLAT_VER))

# default argument values
DEF_TEMPO = 164.0  # beats per minute (used as the --tempo default)
DEF_OUTPUT = "dystopian_waltz.wav"

# -----------------------------------------------------------------------------
# Sequencer classes
#

class SampleSet(object):
    """Weighted collection of audio samples to pick from at random.

    :param entries: iterable of ``(path, weight, gain)`` tuples.  Each file is
        opened with ``splat.data.Fragment.open``, has ``gain`` applied through
        ``Fragment.amp()``, and is added ``weight`` times to the pool used by
        pick(), so heavier entries are chosen more often.  Samples can also be
        fetched by name: the file's base name without its extension.
    """

    def __init__(self, entries):
        self._samples = []
        self._dict = {}
        for path, weight, gain in entries:
            sample = splat.data.Fragment.open(path)
            sample.amp(gain)
            key = os.path.splitext(os.path.basename(path))[0]
            self._dict[key] = sample
            # Repeating the sample in the pool implements the weighting.
            self._samples.extend([sample] * weight)

    def pick(self):
        """Return a random sample, honouring the per-entry weights."""
        return random.choice(self._samples)

    def get(self, name):
        """Return the sample loaded from the file whose base name is ``name``.

        Raises KeyError if no such sample was loaded.
        """
        return self._dict[name]


class Pattern(object):
    """One bar of music, rendered beat by beat by a Sequencer.

    Subclasses implement play(); mix() is the helper they use to place a
    sample into the output fragment.
    """

    def __init__(self, beats=4):
        # Bar length in beats, exposed read-only through `beats`.
        self._beats = beats

    @property
    def beats(self):
        """Number of beats in one bar of this pattern."""
        return self._beats

    def play(self, frag, bar, beat, t, T):
        """Render one beat; must be overridden by subclasses.

        :param frag: output fragment to mix into
        :param bar: index of the bar in the sequence
        :param beat: index of the beat within the bar
        :param t: start time of the beat in seconds
        :param T: duration of one beat in seconds
        """
        raise NotImplementedError("Pattern.play() must be implemented")

    def mix(self, master, sample, t, g=None):
        """Mix ``sample`` into ``master`` at time ``t`` with levels ``g``."""
        master.mix(sample, t, 0.0, g)


class Sequencer(object):
    """Play a list of bars of Pattern objects into one output fragment.

    Each entry of the list given to run() is one bar: a tuple of patterns
    played at the same time.  The bar length in beats is taken from the first
    pattern of the tuple and applies to every pattern in it.

    :param tempo: tempo in beats per minute
    :param frag: fragment the patterns are played into; grown by run() when
        it is too short to hold the whole sequence
    :param verbose: when truthy, print one line per beat and pattern
    :param silence: offset in seconds before the first beat
    """

    def __init__(self, tempo, frag, verbose=False, silence=0.0):
        self.T = tempo
        self.frag = frag
        self.verbose = verbose
        self._silence = silence

    @property
    def T(self):
        """Duration of one beat in seconds."""
        return self._T

    @T.setter
    def T(self, value):
        # NOTE: the setter takes a tempo in BPM and stores the beat period,
        # so the getter does not return what was assigned.
        self._T = 60.0 / value

    def run(self, patterns):
        """Play every bar in ``patterns`` in order.

        Calls ``pattern.play(frag, bar, beat, t, T)`` for each beat of each
        pattern, where ``t`` is the beat's start time in seconds (including
        the initial silence) and ``T`` the beat duration.
        """
        total_beats = sum(group[0].beats for group in patterns)
        new_size = self._silence + (total_beats * self.T)
        if self.frag.duration < new_size:
            self.frag.resize(new_size)
        n_beat = 0  # beats elapsed before the current bar
        for bar, group in enumerate(patterns):
            beats = group[0].beats
            for pattern in group:
                for beat in range(beats):
                    bar_beat = n_beat + beat  # absolute beat index
                    t = self._silence + (bar_beat * self.T)
                    if self.verbose:
                        print("{}/{} {} {:.3f}".format(
                            (bar_beat + 1), total_beats,
                            (bar, beat, bar_beat), t))
                    pattern.play(self.frag, bar, beat, t, self.T)
            n_beat += beats

# -----------------------------------------------------------------------------
# Dystopian Waltz drum patterns
#

class FlakyPattern(Pattern):
    """Pattern whose hits are humanised with random timing and gain errors.

    :param time_error: maximum timing deviation in seconds; each hit is moved
        by a uniform random offset in ``[-time_error, time_error]``.
    :param gain_error: maximum gain deviation, drawn uniformly in
        ``[-gain_error, gain_error]`` independently for every channel and
        added to the requested gain (same units as ``g``, presumably dB).
    """

    def __init__(self, time_error=0.02, gain_error=3.0, *args, **kw):
        super(FlakyPattern, self).__init__(*args, **kw)
        self._te = (-time_error, time_error)
        self._ge = (-gain_error, gain_error)

    def mix(self, frag, sample, t, g=0.0):
        """Mix ``sample`` into ``frag`` around time ``t`` with gain ``g``.

        ``g=None`` is accepted for compatibility with Pattern.mix() and is
        treated as 0.0, the default of this override.
        """
        if g is None:
            # Pattern.mix() allows g=None; without this, None + float raises.
            g = 0.0
        t += random.uniform(*self._te)
        levels = tuple(g + random.uniform(*self._ge)
                       for _ in range(frag.channels))
        frag.mix(sample, t, 0.0, levels)


class DystopianPattern(FlakyPattern):
    """Humanised drum pattern holding the four drum sample sets.

    ``bass``, ``snare``, ``ride`` and ``crash`` are SampleSet-like objects
    whose pick() supplies the sample for each hit.  Remaining arguments are
    forwarded to FlakyPattern.
    """

    def __init__(self, bass, snare, ride, crash, *args, **kw):
        super(DystopianPattern, self).__init__(*args, **kw)
        (self._bass, self._snare,
         self._ride, self._crash) = (bass, snare, ride, crash)


class Silence(Pattern):
    """Pattern that mixes nothing; fills a bar with rests."""

    def play(self, frag, bar, beat, t, T):
        """Leave the beat empty."""
        return None


class Tick(Pattern):
    """Count-in pattern: a short triangle-wave tick on every beat.

    The low (2200 Hz) tick is used on every beat except the last one of the
    bar, which gets the high (3300 Hz) tick.  Ticks get 3.0 louder per beat,
    ending at -3.0 on the last beat.
    """

    def __init__(self, *args, **kw):
        super(Tick, self).__init__(*args, **kw)
        tick1 = splat.data.Fragment()
        tick2 = splat.data.Fragment()
        gen = splat.gen.TriangleGenerator(tick1)
        gen.filters = [splat.filters.linear_fade]
        gen.run(0.0, 0.05, 2200.0)
        gen.frag = tick2
        gen.run(0.0, 0.05, 3300.0)
        # One entry per beat, sized from the bar length so bars other than 4
        # beats neither index past the end nor lose the accented last tick.
        self._tick = [tick1] * (self.beats - 1) + [tick2]

    def play(self, frag, bar, beat, t, T):
        self.mix(frag, self._tick[beat], t, -3.0 * (self.beats - beat))


class PatternI(DystopianPattern):
    """Intro bar: snare on every beat, bass pickups on the last beat.

    On beat 0 of bar index 1 the 'arriving' sample from ``extra`` is mixed so
    that it ends exactly on that beat.
    """

    def __init__(self, extra, *args, **kw):
        super(PatternI, self).__init__(*args, **kw)
        self._extra = extra
        self._arriving = extra.get('arriving')

    def play(self, frag, bar, beat, t, T):
        self.mix(frag, self._snare.pick(), t)
        if bar == 1 and beat == 0:
            arriving = self._arriving
            self.mix(frag, arriving, t - arriving.duration)
        elif beat == 3:
            # Two bass hits on the second half of the beat.
            for when, gain in ((t + T / 2, 3.0), (t + T * 3 / 4, 0.0)):
                self.mix(frag, self._bass.pick(), when, gain)


class PatternA(DystopianPattern):
    """Verse groove: bass/snare/ride hits on and between the four beats."""

    def play(self, frag, bar, beat, t, T):
        offbeat = t + (T / 2.0)
        # (sample set, time, gain) for each hit of this beat, in play order.
        if beat == 0:
            hits = ((self._bass, t, 4.0), (self._bass, offbeat, 0.0))
        elif beat == 1:
            hits = ((self._snare, t, 0.0),)
        elif beat == 2:
            hits = ((self._ride, t, -6.0), (self._bass, offbeat, 1.5))
        else:
            hits = ((self._snare, t, -0.5), (self._ride, offbeat, -4.0))
        for samples, when, gain in hits:
            self.mix(frag, samples.pick(), when, gain)


class PatternB(DystopianPattern):
    """Chorus groove: bass on every beat, snare on every off-beat."""

    def play(self, frag, bar, beat, t, T):
        for samples, when, gain in ((self._bass, t, 3.0),
                                    (self._snare, t + T / 2, 2.0)):
            self.mix(frag, samples.pick(), when, gain)


class PatternB1(DystopianPattern):
    """Crash half a beat before the start of the bar."""

    def play(self, frag, bar, beat, t, T):
        if beat != 0:
            return
        self.mix(frag, self._crash.pick(), t - T / 2, -1.0)


class PatternB2(DystopianPattern):
    """Off-beat ride plus a bass/snare figure counted across bars.

    A private counter, wrapping every 8 beats, picks bass when it is a
    multiple of 3 and snare otherwise, independent of the bar boundaries.
    """

    def __init__(self, *args, **kw):
        super(PatternB2, self).__init__(*args, **kw)
        self._n = 0  # running beat counter in [0, 8)

    def play(self, frag, bar, beat, t, T):
        self.mix(frag, self._ride.pick(), t + T / 2, -3.0)
        if self._n % 3:
            samples, gain = self._snare, 1.5
        else:
            samples, gain = self._bass, 3.0
        self.mix(frag, samples.pick(), t, gain)
        self._n = (self._n + 1) % 8


class PatternC(DystopianPattern):
    """Instrumental groove: bass on the beat, two ride hits after it."""

    def play(self, frag, bar, beat, t, T):
        hits = ((self._bass, t, 3.0),
                (self._ride, t + T / 2, 0.0),
                (self._ride, t + 3 * T / 4, -6.0))
        for samples, when, gain in hits:
            self.mix(frag, samples.pick(), when, gain)


class PatternC1(DystopianPattern):
    """Single snare hit on beat 3 of the bar."""

    def play(self, frag, bar, beat, t, T):
        if beat != 3:
            return
        self.mix(frag, self._snare.pick(), t, 0.0)


class PatternC2(DystopianPattern):
    """Snare on every beat followed by a quieter one a quarter-beat later."""

    def play(self, frag, bar, beat, t, T):
        for when, gain in ((t, 0.0), (t + T / 4, -3.0)):
            self.mix(frag, self._snare.pick(), when, gain)


class PatternPow(DystopianPattern):
    """Burst of four 'pow3' hits, fading out, on selected beats.

    :param extra: SampleSet providing the 'pow3' sample
    :param beats: beat indices (any iterable of ints) that trigger the burst
    """

    def __init__(self, extra, beats, *args, **kw):
        super(PatternPow, self).__init__(*args, **kw)
        self._extra = extra
        self._pow = self._extra.get('pow3')
        # Stored under its own name: self._beats is the bar length that
        # Pattern.beats exposes to the Sequencer and must not be overwritten.
        # A frozenset also lets one-shot iterables be tested repeatedly.
        self._pow_beats = frozenset(beats)

    def play(self, frag, bar, beat, t, T):
        if beat in self._pow_beats:
            # Four hits per beat, each 3.5 quieter than the previous one.
            for x in range(4):
                self.mix(frag, self._pow, (t + T * x / 4), (-x * 3.5))


class PatternEnding(Pattern):
    """Final bar: a single 'bang1' sample placed before the bar starts.

    The sample is mixed on beat 0, starting one beat plus 2.6 s before it.
    """

    def __init__(self, extra, *args, **kw):
        super(PatternEnding, self).__init__(*args, **kw)
        self._extra = extra
        self._bang = extra.get('bang1')

    def play(self, frag, bar, beat, t, T):
        if beat != 0:
            return
        self.mix(frag, self._bang, t - T - 2.6, 3.0)

# -----------------------------------------------------------------------------
# Dystopian Waltz bass patterns
#

class BassPattern(Pattern):
    """Base class for bass lines rendered through a shared generator.

    :param gen: generator whose run() renders each note
    :param notes: per-beat note data, interpreted by subclasses
    :param nlen: note length as a fraction of the time slot it occupies
    """

    def __init__(self, gen, notes, nlen=0.9, *args, **kw):
        super(BassPattern, self).__init__(*args, **kw)
        self._gen, self._notes, self._nlen = gen, notes, nlen
        # Note-name to frequency lookup, rooted at 55 Hz.
        self._s = splat.scales.LogScale(fund=55.0)


class BassPatternA(BassPattern):
    """One optional note per beat; ``notes[beat]`` is a note name or None."""

    def play(self, frag, bar, beat, t, T):
        note = self._notes[beat]
        if note is None:
            return
        self._gen.run(t, t + (T * self._nlen), self._s[note])


class BassPatternB(BassPattern):
    """Beats split into equal subdivisions, each with its own note.

    ``notes[beat]`` is None or a sequence of ``(note, length)`` pairs; the
    beat is divided evenly between them, ``length`` scales the note within
    its slot and a None note leaves that slot silent.
    """

    def play(self, frag, bar, beat, t, T):
        subdivisions = self._notes[beat]
        if subdivisions is None:
            return
        count = len(subdivisions)
        for idx, (note, length) in enumerate(subdivisions):
            if note is None:
                continue
            start = t + idx * T / count
            end = start + length * self._nlen * T / count
            self._gen.run(start, end, self._s[note])


class BassPatternSpline(BassPattern):
    """Bass line that glides between notes along a spline.

    ``notes`` holds one note name (or None) per beat plus one extra final
    entry: the pitch the glide reaches one beat after the last beat.  Pitches
    are collected beat by beat (converted with lin2dB) and, on the last beat
    of the bar, a spline is fitted through them and sampled ``sub_beats``
    times across the bar.  Each sample renders a short note (0.28 beat) at
    levels ``sub_gain`` through the shared generator.
    """

    def __init__(self, sub_beats, sub_gain, *args, **kw):
        super(BassPatternSpline, self).__init__(*args, **kw)
        self._sub_beats = sub_beats
        self._sub_gain = sub_gain
        self._pts = []
        self._t0 = None

    def play(self, frag, bar, beat, t, T):
        if beat == 0:
            self._t0 = t
        note_name = self._notes[beat]
        if note_name is not None:
            self._pts.append((t, lin2dB(self._s[note_name])))
        # Render on the last beat even when its own note is None, so the
        # collected points never leak into the next bar.
        if beat == self.beats - 1:
            self._render(t, T)

    def _render(self, t, T):
        """Fit the spline for the bar whose last beat starts at ``t``."""
        if not self._pts:
            # No note at all in this bar: nothing to glide through.
            return
        last_freq = lin2dB(self._s[self._notes[-1]])
        # End point one beat later, with a zero slope.
        self._pts.append((t + T, last_freq, 0.0))
        spline = splat.interpol.Spline(self._pts)
        T2 = T * 0.28
        sub = float(self._sub_beats) / self.beats  # sub-notes per beat
        for i in range(self._sub_beats):
            t1 = self._t0 + float(i) * T / sub
            self._gen.run(t1, t1 + T2, dB2lin(spline.value(t1)),
                          levels=self._sub_gain)
        self._pts = []


# -----------------------------------------------------------------------------
# Main function
#

def main(argv):
    """Render the Dystopian Waltz and save it to a WAV file.

    Parses the command line in ``argv`` (``argv[0]`` is the program name).
    It then loads the drum and effect samples from paths relative to the
    current directory and sequences the drum patterns into a stereo master
    fragment. The bass line is rendered with a triangle generator, optionally
    followed by reverb, and mixed in. Finally the master is normalized and
    saved to ``--output``.

    Returns True once the output has been saved.
    """
    parser = argparse.ArgumentParser(description="Dystopian Waltz")
    parser.add_argument('--tempo', type=float, default=DEF_TEMPO,
                        help="Tempo in BPM, default is {}".format(DEF_TEMPO))
    parser.add_argument('-o', '--output', default=DEF_OUTPUT,
                        help="Output file, default is {}".format(DEF_OUTPUT))
    parser.add_argument('-v', '--verbose', action='store_true',
                        help="Verbose output")
    parser.add_argument('--no-reverb', action='store_true',
                        help="Disable reverb to reduce computing time")
    parser.add_argument('--no-ticks', action='store_true',
                        help="Disable ticks during intro")
    parser.add_argument('--no-drums', action='store_true',
                        help="Disable drums (including ticks)")
    parser.add_argument('--no-bass', action='store_true',
                        help="Disable bass")
    parser.add_argument('--hack', action='store_true',
                        help="Quick hack to try only a small part")
    args = parser.parse_args(argv[1:])

    # Each entry is (path, weight in random picks, gain for Fragment.amp()).
    bass = SampleSet([('train/bass-002.wav', 7, 0.0),
                      ('fireworks/boom-1.wav', 12, 0.0)])
    snare = SampleSet([('train/snare-001.wav', 12, (1.5, 0.0)),
                       ('train/snare-002.wav', 9, (0.0, 1.5)),
                       ('fireworks/pow1.wav', 1, -6.0)])
    snare1 = SampleSet([('train/snare-001.wav', 12, (1.5, 0.0)),
                        ('train/snare-002.wav', 9, (0.0, 1.5))])
    ride = SampleSet([('train/ride-001.wav', 12, (2.0, 1.5)),
                      ('train/ride-002.wav', 3, (-1.5, -1.0)),
                      ('fireworks/clap-rocket1.wav', 3, 0.0)])
    crash = SampleSet([('train/crash-001b.wav', 12, -3.0),
                       ('fireworks/whistling-rocket4.wav', 1, 0.0),
                       ('fireworks/clap-rocket2.wav', 1, 0.0)])
    extra = SampleSet([('fireworks/pow3.wav', 1, 0.0),
                       ('train/arriving.wav', 1, 6.0),
                       ('thunderstorm/bang1.wav', 1, 0.0)])
    stuff = (bass, snare, ride, crash)

    # Leading silence. PatternI mixes the 'arriving' sample so that it ends
    # on the first intro beat, which comes after the 4-beat tick bar. With
    # this gap, that sample starts 0.5 s into the output.
    gap = extra.get('arriving').duration + 0.5 - (4.0 * 60.0 / args.tempo)
    if args.hack:
        gap = 0.0
    master = splat.data.Fragment(channels=2)
    seq = Sequencer(args.tempo, master, args.verbose, gap)

    S = Silence()
    I = PatternI(extra, bass, snare1, ride, crash)
    A1 = PatternA(*stuff)
    A2 = PatternA(*stuff)
    B = PatternB(*stuff)
    B1 = PatternB1(*stuff)
    B2 = PatternB2(*stuff)
    C = PatternC(*stuff)
    C1 = PatternC1(*stuff)
    C2 = PatternC2(*stuff)
    P1 = PatternPow(extra, range(2, 4), *stuff)
    P2 = PatternPow(extra, range(4), *stuff)
    E = PatternEnding(extra)

    # Drum arrangement: one list entry per bar, each a tuple of patterns
    # played together.
    ticks = [(Silence() if args.no_ticks else Tick(),)]
    intro = [(I,)] * 4
    verse = [(A1,)] * 8 + [(A2,)] * 2
    chorus = [(B, B1)] * 3 + [(B, B1, P1)]
    chorus += [(B2,)] * 2
    chorus += [(B,), (B, P2)]
    chorus += [(B,)] * 2
    instr = [(C, C1)] * 8 + [(C, C2, P1)] * 2
    ending = [(B,), (B, P2)] + [(B,)] * 2 + [(E,)]

    patterns = []
    patterns += ticks + intro
    patterns += (verse + chorus) * 2
    patterns += instr + chorus
    patterns += verse + chorus
    patterns += ending

    if args.hack:
        patterns = ticks + instr + chorus

    if not args.no_drums:
        print("Generating drums...")
        seq.run(patterns)

    # NOTE: splat.filters.reverb appears to misbehave when installed as one
    # of the generator's filters. The bass is therefore rendered dry, and
    # reverb is applied to the generator's fragment once it is complete
    # (see below).
    bgen = splat.gen.TriangleGenerator()
    all_filters = [(splat.filters.linear_fade, (0.05,)),]
    bgen.filters = all_filters
    bgen.levels = (-9.0, -9.0)

    BA_A = BassPatternA(bgen, ('A', None, None, None), 1.75)
    BA_C = BassPatternA(bgen, ('C', None, 'C', None))
    BA_E1 = BassPatternA(bgen, ('E', None, 'Eb', 'Eb'))
    BA_E2 = BassPatternA(bgen, ('E', 'E', 'E', 'E'))
    BA_E3 = BassPatternA(bgen, ('E', 'E', 'E', 'Eb'))
    BA_D = BassPatternA(bgen, ('D', None, 'D', None))
    BB_1 = BassPatternB(bgen, ((('A', 1.0),),
                               (('A', 1.0), ('E', 1.0)),
                               ((None, 1.0), ('A', 1.0)),
                               (('C', 1.0), ('A', 1.0))))
    BB_2 = BassPatternB(bgen, ((('A', 1.0), ('A', 0.75)),
                               (('E', 1.0),),
                               ((None, 1.0), ('A', 1.0)),
                               (('C', 0.75), ('G', 0.75))))
    BB_3 = BassPatternB(bgen, ((('E', 1.0), ('E', 1.0)),
                               (('E', 1.0), ('E', 1.0)),
                               (('G#', 1.0),),
                               (('G#', 1.0),)))
    BC = BassPatternSpline(16, -6.0, bgen, ('B', 'B', 'Bb', 'Bb', 'A'))
    BI_1 = BassPatternSpline(12, -14.0, bgen,
                             ('Bb2', 'D2', 'A2', 'Db2', 'Ab1'))
    BI_2 = BassPatternSpline(12, -14.0, bgen,
                             ('Ab1', 'C2', 'G1', 'B1', 'E'))
    BI_3 = BassPatternSpline(8, -10.0, bgen,
                             ('Bb1', 'D1', 'A1', 'Db1', 'Ab0'))
    BI_4 = BassPatternSpline(8, -10.0, bgen,
                             ('Ab0', 'C1', 'G0', 'B0', 'E'))
    bintro = [(S,), (BA_E2,), (BA_E3,), (BA_D,), (BA_C,)]
    bverse = 2 * ([(BA_A,)] * 2 + [(BA_C,), (BA_E1,)])
    bverse += [(BA_A,)] * 2
    bchorus = [(BA_E2,)] * 3 + [(BA_E3,)]
    bchorus += [(BA_D,)] * 2 + [(BA_C,), (BC,)]
    bchorus += [(BA_A,)] * 2
    binstr = [(BB_1,)] * 4
    binstr += [(BB_2,)] * 3 + [(BB_3,)]
    binstr += [(BI_1, BI_3), (BI_2, BI_4)]
    bending = [(BA_C,), (BC,)] + [(BA_A,)] * 2

    # Copy so the += below extends a fresh list instead of mutating bintro.
    bpatterns = list(bintro)
    bpatterns += (bverse + bchorus) * 2
    bpatterns += binstr + bchorus
    bpatterns += bverse + bchorus
    bpatterns += bending

    if args.hack:
        print("Hack enabled on bass patterns")
        bpatterns = [(S,)] + binstr + bchorus

    if not args.no_bass:
        print("Generating bass")
        seq.run(bpatterns)

        # Reverb is applied to the finished bass fragment rather than as a
        # generator filter (see the note above).
        if not args.no_reverb:
            print("Reverb...")
            breverbd = splat.filters.reverb_delays()
            splat.filters.reverb(bgen.frag, breverbd)
        print("Mixing...")
        master.mix(bgen.frag)

    master.normalize()
    print("Saving output: {}".format(args.output))
    master.save(args.output)

    return True

if __name__ == '__main__':
    # main() returns True on success; any other value maps to exit status 1.
    ret = main(sys.argv)
    sys.exit(int(ret is not True))
