package logrus

import (
	"bytes"
	"encoding/json"
	"io/ioutil"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// LogAndAssertJSON runs log against a fresh Logger that writes JSON into an
// in-memory buffer, decodes what was written, and passes the decoded fields
// to assertions.
//
// The buffer must hold exactly one JSON object (json.Unmarshal rejects
// trailing data). If decoding fails, the error is reported and assertions is
// not called, so callers never inspect a nil map and produce misleading
// follow-on failures.
func LogAndAssertJSON(t *testing.T, log func(*Logger), assertions func(fields Fields)) {
	t.Helper()

	var buffer bytes.Buffer
	var fields Fields

	logger := New()
	logger.Out = &buffer
	logger.Formatter = new(JSONFormatter)

	log(logger)

	if err := json.Unmarshal(buffer.Bytes(), &fields); !assert.NoError(t, err, "decoding JSON log output %q", buffer.String()) {
		return
	}

	assertions(fields)
}

// LogAndAssertText runs log against a fresh Logger using an uncolored
// TextFormatter, parses the key=value pairs out of everything written, and
// passes them to assertions.
//
// The parser is deliberately simple:
//   - Tokens are split on whitespace, so quoted values containing spaces are
//     not supported.
//   - Tokens without '=' are ignored.
//   - Values that start with a double quote are unquoted with
//     strconv.Unquote.
//   - When several lines are logged, later values for a key overwrite
//     earlier ones.
func LogAndAssertText(t *testing.T, log func(*Logger), assertions func(fields map[string]string)) {
	t.Helper()

	var buffer bytes.Buffer

	logger := New()
	logger.Out = &buffer
	logger.Formatter = &TextFormatter{
		DisableColors: true,
	}

	log(logger)

	fields := make(map[string]string)
	// strings.Fields also breaks on the newline that ends each entry. That
	// way the last value of a line is neither glued to the next line nor
	// left with a trailing "\n" that would make strconv.Unquote fail.
	for _, kv := range strings.Fields(buffer.String()) {
		// SplitN keeps any '=' that appears inside the value.
		kvArr := strings.SplitN(kv, "=", 2)
		if len(kvArr) != 2 {
			continue
		}
		key := kvArr[0]
		val := kvArr[1]
		// HasPrefix (rather than val[0]) so an empty value ("key=") cannot panic.
		if strings.HasPrefix(val, `"`) {
			var err error
			val, err = strconv.Unquote(val)
			assert.NoError(t, err)
		}
		fields[key] = val
	}
	assertions(fields)
}

// TestPrint checks that Print emits its message with level "info".
func TestPrint(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Print("test")
	}, func(fields Fields) {
		assert.Equal(t, "test", fields["msg"])
		assert.Equal(t, "info", fields["level"])
	})
}

// TestInfo checks that Info emits its message with level "info".
func TestInfo(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Info("test")
	}, func(fields Fields) {
		assert.Equal(t, "test", fields["msg"])
		assert.Equal(t, "info", fields["level"])
	})
}

// TestWarn checks that Warn emits its message with level "warning".
func TestWarn(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Warn("test")
	}, func(fields Fields) {
		assert.Equal(t, "test", fields["msg"])
		assert.Equal(t, "warning", fields["level"])
	})
}

// TestInfolnShouldAddSpacesBetweenStrings checks that Infoln separates two
// string operands with a space.
func TestInfolnShouldAddSpacesBetweenStrings(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Infoln("test", "test")
	}, func(fields Fields) {
		assert.Equal(t, "test test", fields["msg"])
	})
}

// TestInfolnShouldAddSpacesBetweenStringAndNonstring checks that Infoln
// separates a string operand from a non-string operand with a space.
func TestInfolnShouldAddSpacesBetweenStringAndNonstring(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Infoln("test", 10)
	}, func(fields Fields) {
		assert.Equal(t, "test 10", fields["msg"])
	})
}

// TestInfolnShouldAddSpacesBetweenTwoNonStrings checks that Infoln separates
// two non-string operands with a space.
func TestInfolnShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Infoln(10, 10)
	}, func(fields Fields) {
		assert.Equal(t, "10 10", fields["msg"])
	})
}

// TestInfoShouldAddSpacesBetweenTwoNonStrings checks that Info, although it
// does not space string operands, does separate two non-string operands with
// a space (the same rule fmt.Sprint applies).
//
// This must call Info, not Infoln. Infoln inserts spaces unconditionally, so
// calling it here would only duplicate the Infoln variant of this test.
func TestInfoShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Info(10, 10)
	}, func(fields Fields) {
		assert.Equal(t, "10 10", fields["msg"])
	})
}

// TestInfoShouldNotAddSpacesBetweenStringAndNonstring checks that Info
// concatenates a string operand and a non-string operand without a space.
func TestInfoShouldNotAddSpacesBetweenStringAndNonstring(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Info("test", 10)
	}, func(fields Fields) {
		assert.Equal(t, "test10", fields["msg"])
	})
}

// TestInfoShouldNotAddSpacesBetweenStrings checks that Info concatenates two
// string operands without a space.
func TestInfoShouldNotAddSpacesBetweenStrings(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.Info("test", "test")
	}, func(fields Fields) {
		assert.Equal(t, "testtest", fields["msg"])
	})
}

func TestWithFieldsShouldAllowAssignments(t *testing.T) {
	var buffer bytes.Buffer
	var fields Fields

	logger := New()
	logger.Out = &buffer
	logger.Formatter = new(JSONFormatter)

	localLog := logger.WithFields(Fields{
		"key1": "value1",
	})

	localLog.WithField("key2", "value2").Info("test")
	err := json.Unmarshal(buffer.Bytes(), &fields)
	assert.Nil(t, err)

	assert.Equal(t, "value2", fields["key2"])
	assert.Equal(t, "value1", fields["key1"])

	buffer = bytes.Buffer{}
	fields = Fields{}
	localLog.Info("test")
	err = json.Unmarshal(buffer.Bytes(), &fields)
	assert.Nil(t, err)

	_, ok := fields["key2"]
	assert.Equal(t, false, ok)
	assert.Equal(t, "value1", fields["key1"])
}

// TestUserSuppliedFieldDoesNotOverwriteDefaults checks that a user field
// named "msg" does not replace the message passed to Info.
func TestUserSuppliedFieldDoesNotOverwriteDefaults(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.WithField("msg", "hello").Info("test")
	}, func(fields Fields) {
		assert.Equal(t, "test", fields["msg"])
	})
}

// TestUserSuppliedMsgFieldHasPrefix checks that a user field named "msg" is
// emitted under "fields.msg", while "msg" keeps the logged message.
func TestUserSuppliedMsgFieldHasPrefix(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.WithField("msg", "hello").Info("test")
	}, func(fields Fields) {
		assert.Equal(t, "test", fields["msg"])
		assert.Equal(t, "hello", fields["fields.msg"])
	})
}

// TestUserSuppliedTimeFieldHasPrefix checks that a user field named "time"
// is emitted under "fields.time".
func TestUserSuppliedTimeFieldHasPrefix(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.WithField("time", "hello").Info("test")
	}, func(fields Fields) {
		assert.Equal(t, "hello", fields["fields.time"])
	})
}

// TestUserSuppliedLevelFieldHasPrefix checks that a user field named "level"
// does not replace the entry's level and is emitted under "fields.level".
func TestUserSuppliedLevelFieldHasPrefix(t *testing.T) {
	LogAndAssertJSON(t, func(log *Logger) {
		log.WithField("level", 1).Info("test")
	}, func(fields Fields) {
		assert.Equal(t, "info", fields["level"])
		// encoding/json decodes JSON numbers into float64 when the target is
		// an interface value, so the int 1 comes back as 1.0.
		assert.Equal(t, 1.0, fields["fields.level"])
	})
}

// TestDefaultFieldsAreNotPrefixed logs twice through the same entry with the
// text formatter. It checks that the built-in level/time/msg keys never
// appear with a "fields." prefix.
func TestDefaultFieldsAreNotPrefixed(t *testing.T) {
	prefixed := []string{"fields.level", "fields.time", "fields.msg"}

	LogAndAssertText(t, func(log *Logger) {
		entry := log.WithField("herp", "derp")
		entry.Info("hello")
		entry.Info("bye")
	}, func(fields map[string]string) {
		for _, name := range prefixed {
			_, found := fields[name]
			if !found {
				continue
			}
			t.Fatalf("should not have prefixed %q: %v", name, fields)
		}
	})
}

// TestWithTimeShouldOverrideTime checks that WithTime sets the timestamp
// that is emitted. The time is shifted a day ahead so it cannot coincide
// with the current time.
func TestWithTimeShouldOverrideTime(t *testing.T) {
	now := time.Now().Add(24 * time.Hour)

	LogAndAssertJSON(t, func(log *Logger) {
		log.WithTime(now).Info("foobar")
	}, func(fields Fields) {
		assert.Equal(t, now.Format(defaultTimestampFormat), fields["time"])
	})
}

// TestWithTimeShouldNotOverrideFields checks that calling WithTime after
// WithField keeps the previously added field.
func TestWithTimeShouldNotOverrideFields(t *testing.T) {
	now := time.Now().Add(24 * time.Hour)

	LogAndAssertJSON(t, func(log *Logger) {
		log.WithField("herp", "derp").WithTime(now).Info("blah")
	}, func(fields Fields) {
		assert.Equal(t, now.Format(defaultTimestampFormat), fields["time"])
		assert.Equal(t, "derp", fields["herp"])
	})
}

// TestWithFieldShouldNotOverrideTime checks that calling WithField after
// WithTime keeps the overridden timestamp.
func TestWithFieldShouldNotOverrideTime(t *testing.T) {
	now := time.Now().Add(24 * time.Hour)

	LogAndAssertJSON(t, func(log *Logger) {
		log.WithTime(now).WithField("herp", "derp").Info("blah")
	}, func(fields Fields) {
		assert.Equal(t, now.Format(defaultTimestampFormat), fields["time"])
		assert.Equal(t, "derp", fields["herp"])
	})
}

func TestTimeOverrideMultipleLogs(t *testing.T) {
	var buffer bytes.Buffer
	var firstFields, secondFields Fields

	logger := New()
	logger.Out = &buffer
	formatter := new(JSONFormatter)
	formatter.TimestampFormat = time.StampMilli
	logger.Formatter = formatter

	llog := logger.WithField("herp", "derp")
	llog.Info("foo")

	err := json.Unmarshal(buffer.Bytes(), &firstFields)
	assert.NoError(t, err, "should have decoded first message")

	buffer.Reset()

	time.Sleep(10 * time.Millisecond)
	llog.Info("bar")

	err = json.Unmarshal(buffer.Bytes(), &secondFields)
	assert.NoError(t, err, "should have decoded second message")

	assert.NotEqual(t, firstFields["time"], secondFields["time"], "timestamps should not be equal")
}

func TestDoubleLoggingDoesntPrefixPreviousFields(t *testing.T) {

	var buffer bytes.Buffer
	var fields Fields

	logger := New()
	logger.Out = &buffer
	logger.Formatter = new(JSONFormatter)

	llog := logger.WithField("context", "eating raw fish")

	llog.Info("looks delicious")

	err := json.Unmarshal(buffer.Bytes(), &fields)
	assert.NoError(t, err, "should have decoded first message")
	assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
	assert.Equal(t, fields["msg"], "looks delicious")
	assert.Equal(t, fields["context"], "eating raw fish")

	buffer.Reset()

	llog.Warn("omg it is!")

	err = json.Unmarshal(buffer.Bytes(), &fields)
	assert.NoError(t, err, "should have decoded second message")
	assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
	assert.Equal(t, fields["msg"], "omg it is!")
	assert.Equal(t, fields["context"], "eating raw fish")
	assert.Nil(t, fields["fields.msg"], "should not have prefixed previous `msg` entry")

}

func TestConvertLevelToString(t *testing.T) {
	assert.Equal(t, "debug", DebugLevel.String())
	assert.Equal(t, "info", InfoLevel.String())
	assert.Equal(t, "warning", WarnLevel.String())
	assert.Equal(t, "error", ErrorLevel.String())
	assert.Equal(t, "fatal", FatalLevel.String())
	assert.Equal(t, "panic", PanicLevel.String())
}

func TestParseLevel(t *testing.T) {
	l, err := ParseLevel("panic")
	assert.Nil(t, err)
	assert.Equal(t, PanicLevel, l)

	l, err = ParseLevel("PANIC")
	assert.Nil(t, err)
	assert.Equal(t, PanicLevel, l)

	l, err = ParseLevel("fatal")
	assert.Nil(t, err)
	assert.Equal(t, FatalLevel, l)

	l, err = ParseLevel("FATAL")
	assert.Nil(t, err)
	assert.Equal(t, FatalLevel, l)

	l, err = ParseLevel("error")
	assert.Nil(t, err)
	assert.Equal(t, ErrorLevel, l)

	l, err = ParseLevel("ERROR")
	assert.Nil(t, err)
	assert.Equal(t, ErrorLevel, l)

	l, err = ParseLevel("warn")
	assert.Nil(t, err)
	assert.Equal(t, WarnLevel, l)

	l, err = ParseLevel("WARN")
	assert.Nil(t, err)
	assert.Equal(t, WarnLevel, l)

	l, err = ParseLevel("warning")
	assert.Nil(t, err)
	assert.Equal(t, WarnLevel, l)

	l, err = ParseLevel("WARNING")
	assert.Nil(t, err)
	assert.Equal(t, WarnLevel, l)

	l, err = ParseLevel("info")
	assert.Nil(t, err)
	assert.Equal(t, InfoLevel, l)

	l, err = ParseLevel("INFO")
	assert.Nil(t, err)
	assert.Equal(t, InfoLevel, l)

	l, err = ParseLevel("debug")
	assert.Nil(t, err)
	assert.Equal(t, DebugLevel, l)

	l, err = ParseLevel("DEBUG")
	assert.Nil(t, err)
	assert.Equal(t, DebugLevel, l)

	l, err = ParseLevel("invalid")
	assert.Equal(t, "not a valid logrus Level: \"invalid\"", err.Error())
}

func TestGetSetLevelRace(t *testing.T) {
	wg := sync.WaitGroup{}
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if i%2 == 0 {
				SetLevel(InfoLevel)
			} else {
				GetLevel()
			}
		}(i)

	}
	wg.Wait()
}

// TestLoggingRace calls Info on one Logger from 100 goroutines at once. It
// makes no assertions of its own, so it only detects problems when run under
// the race detector (`go test -race`).
func TestLoggingRace(t *testing.T) {
	logger := New()
	// Discard output so the test does not dump 100 lines into the test log;
	// only the concurrent calls matter here, not what they print.
	logger.SetOutput(ioutil.Discard)

	const goroutines = 100
	var wg sync.WaitGroup
	wg.Add(goroutines)

	for i := 0; i < goroutines; i++ {
		go func() {
			defer wg.Done()
			logger.Info("info")
		}()
	}
	wg.Wait()
}

// TestLoggingRaceWithHooksOnEntry calls Info from 100 goroutines through a
// single shared entry whose logger has a ModifyHook installed. It makes no
// assertions of its own, so it only detects problems when run under the race
// detector (`go test -race`).
func TestLoggingRaceWithHooksOnEntry(t *testing.T) {
	logger := New()
	// Discard output: only the concurrent calls matter, not what they print.
	logger.SetOutput(ioutil.Discard)
	hook := new(ModifyHook)
	logger.AddHook(hook)
	entry := logger.WithField("context", "clue")

	const goroutines = 100
	var wg sync.WaitGroup
	wg.Add(goroutines)

	for i := 0; i < goroutines; i++ {
		go func() {
			defer wg.Done()
			entry.Info("info")
		}()
	}
	wg.Wait()
}

// TestReplaceHooks checks three things:
//   - ReplaceHooks installs a new hook set,
//   - only the new hooks fire afterwards,
//   - the hook set it returns can be reinstalled.
func TestReplaceHooks(t *testing.T) {
	old, cur := &TestHook{}, &TestHook{}

	logger := New()
	logger.SetOutput(ioutil.Discard)
	logger.AddHook(old)

	hooks := make(LevelHooks)
	hooks.Add(cur)
	replaced := logger.ReplaceHooks(hooks)

	logger.Info("test")

	assert.False(t, old.Fired, "replaced hook should not fire")
	assert.True(t, cur.Fired, "installed hook should fire")

	logger.ReplaceHooks(replaced)
	logger.Info("test")
	assert.True(t, old.Fired, "restored hook should fire")
}

// TestLogrusInterface is mainly a compile-time check that both a *Logger and
// the value returned by Logger.WithField can be used as a FieldLogger. At
// runtime it only issues one Debug call through each.
func TestLogrusInterface(t *testing.T) {
	var out bytes.Buffer
	useFieldLogger := func(fl FieldLogger) {
		fl.WithField("key", "value").Debug("Test")
	}

	logger := New()
	logger.Out = &out

	// The logger itself as a FieldLogger.
	useFieldLogger(logger)

	// An entry derived from the logger as a FieldLogger.
	useFieldLogger(logger.WithField("another", "value"))
}

// channelWriter implements io.Writer by sending each write on the channel.
// A test can then block until a goroutine (such as the one behind
// Entry.Writer) has written, without racing on a shared buffer. It assumes
// one call to Logger.Out's Write per logged message.
type channelWriter chan []byte

// Write sends a copy of p on the channel and reports all of p as written.
// It blocks until the channel accepts the value.
//
// The copy is required: the io.Writer contract forbids retaining p, and the
// caller may reuse that memory as soon as Write returns. Sending p itself
// could let the receiver observe a later message's bytes.
func (cw channelWriter) Write(p []byte) (int, error) {
	buf := make([]byte, len(p))
	copy(buf, p)
	cw <- buf
	return len(p), nil
}

// TestEntryWriter checks that a line written to an entry's level writer is
// logged at the requested level and carries the entry's fields.
func TestEntryWriter(t *testing.T) {
	cw := channelWriter(make(chan []byte, 1))
	log := New()
	log.Out = cw
	log.Formatter = new(JSONFormatter)

	w := log.WithField("foo", "bar").WriterLevel(WarnLevel)
	// Close the writer when the test ends so whatever is reading from it is
	// released rather than left running. Assumes WriterLevel returns a
	// closable pipe writer, as in logrus.
	defer w.Close()

	_, err := w.Write([]byte("hello\n"))
	assert.NoError(t, err)

	bs := <-cw
	var fields Fields
	err = json.Unmarshal(bs, &fields)
	assert.NoError(t, err)
	assert.Equal(t, "bar", fields["foo"])
	assert.Equal(t, "warning", fields["level"])
}

func TestLogLevelEnabled(t *testing.T) {
	log := New()
	log.SetLevel(PanicLevel)
	assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
	assert.Equal(t, false, log.IsLevelEnabled(FatalLevel))
	assert.Equal(t, false, log.IsLevelEnabled(ErrorLevel))
	assert.Equal(t, false, log.IsLevelEnabled(WarnLevel))
	assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
	assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))

	log.SetLevel(FatalLevel)
	assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
	assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
	assert.Equal(t, false, log.IsLevelEnabled(ErrorLevel))
	assert.Equal(t, false, log.IsLevelEnabled(WarnLevel))
	assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
	assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))

	log.SetLevel(ErrorLevel)
	assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
	assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
	assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
	assert.Equal(t, false, log.IsLevelEnabled(WarnLevel))
	assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
	assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))

	log.SetLevel(WarnLevel)
	assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
	assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
	assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
	assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
	assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
	assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))

	log.SetLevel(InfoLevel)
	assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
	assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
	assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
	assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
	assert.Equal(t, true, log.IsLevelEnabled(InfoLevel))
	assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))

	log.SetLevel(DebugLevel)
	assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
	assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
	assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
	assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
	assert.Equal(t, true, log.IsLevelEnabled(InfoLevel))
	assert.Equal(t, true, log.IsLevelEnabled(DebugLevel))
}
