// Package sshd provides functionality for handling SSH sessions.
package sshd

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"time"

	"gitlab.com/gitlab-org/labkit/fields"
	"gitlab.com/gitlab-org/labkit/v2/log"
	"golang.org/x/crypto/ssh"
	grpccodes "google.golang.org/grpc/codes"
	grpcstatus "google.golang.org/grpc/status"

	shellCmd "gitlab.com/gitlab-org/gitlab-shell/v14/cmd/gitlab-shell/command"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/command"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/disallowedcommand"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/console"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/metrics"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/sshenv"
)

// session holds the state of a single SSH session served over channel: the
// identity and configuration handed over by the connection, plus what the
// client sends through its channel requests before a command is run.
type session struct {
	// State set up by the connection
	cfg *config.Config
	// channel carries the command's stdin/stdout/stderr and is closed once
	// request handling stops.
	channel ssh.Channel
	// Identity used to build the command. getCommand prefers the Kerberos
	// principal, then the username, and falls back to the key ID.
	gitlabKeyID         string
	gitlabKrb5Principal string
	gitlabUsername      string
	// namespace is passed to the command environment as NamespacePath.
	namespace string
	// remoteAddr is passed to the command environment as RemoteAddr.
	remoteAddr string

	// State managed by the session
	// execCmd is the command from an "exec" request; it stays empty for "shell".
	execCmd string
	// gitProtocolVersion is the value accepted from an "env" request for
	// sshenv.GitProtocolEnv.
	gitProtocolVersion string
	// started is the reference time for the established_session_duration_s
	// value logged and observed in handleShell.
	// NOTE(review): not assigned in the code visible here — confirm where it is set.
	started time.Time
}

// execRequest is the payload of an "exec" channel request, decoded with
// ssh.Unmarshal in handleExec. ssh.Unmarshal reads fields in declaration
// order, so the field order here is the wire layout.
type execRequest struct {
	Command string
}

// envRequest is the payload of an "env" channel request: a variable name and
// its value, decoded with ssh.Unmarshal in handleEnv. Field order is the wire
// layout; do not reorder.
type envRequest struct {
	Name  string
	Value string
}

// exitStatusReq is the payload of the "exit-status" channel request that exit
// sends to report the command's exit code. It is encoded with ssh.Marshal, so
// the field order is the wire layout.
type exitStatusReq struct {
	ExitStatus uint32
}

// handle services the requests arriving on the session channel.
//
// "env" requests are recorded and the loop keeps going. Unknown request types
// are refused (when the client wants a reply) without ending the session. An
// "exec" or "shell" request runs the command and ends the loop, as does an
// "env" request whose payload cannot be decoded. In those cases the SSH channel
// is closed. The loop also ends when the requests channel is closed.
//
// It returns the context carrying the executed command's log data (ctx itself
// if no command ran). It also returns the error from the request that ended
// the loop, or nil.
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) {
	ctxWithLogData := ctx

	slog.DebugContext(ctx, "session: handle: entering request loop")

	var err error
	for req := range requests {
		// Derive the request-scoped context from the session ctx instead of
		// reassigning ctx. Otherwise one request's fields would carry over into
		// the next request and into the final "exiting request loop" line.
		reqCtx := log.WithFields(ctx,
			slog.Int("bytesize", len(req.Payload)),
			slog.String("type", req.Type),
			slog.Bool("want_reply", req.WantReply),
		)

		slog.DebugContext(reqCtx, "session: handle: request received")

		var shouldContinue bool
		switch req.Type {
		case "env":
			shouldContinue, err = s.handleEnv(reqCtx, req)
		case "exec":
			// The command has been executed as `ssh user@host command` or `exec` channel has been used
			// in the app implementation. handleExec sends the exit status itself.
			shouldContinue = false
			ctxWithLogData, err = s.handleExec(reqCtx, req)
		case "shell":
			// The command has been entered into the shell or `shell` channel has been used
			// in the app implementation
			shouldContinue = false
			var status uint32
			ctxWithLogData, status, err = s.handleShell(reqCtx, req)
			// Log the exit with the command's log data, as handleExec does.
			s.exit(ctxWithLogData, status)
		default:
			// Ignore unknown requests but don't terminate the session.
			shouldContinue = true

			if req.WantReply {
				// replyErr is deliberately local: a failed reply is only
				// debug-logged and must not become the session's returned error.
				if replyErr := req.Reply(false, []byte{}); replyErr != nil {
					slog.DebugContext(reqCtx, "session: handle: Failed to reply", slog.String(fields.ErrorMessage, replyErr.Error()))
				}
			}
		}

		slog.DebugContext(reqCtx, "session: handle: request processed", slog.Bool("should_continue", shouldContinue))

		if !shouldContinue {
			// Best effort: the session is over either way.
			_ = s.channel.Close()
			break
		}
	}

	slog.DebugContext(ctx, "session: handle: exiting request loop")

	return ctxWithLogData, err
}

// handleEnv processes an "env" request.
//
// Only the variable named by sshenv.GitProtocolEnv is accepted. Its value is
// stored on the session and later passed to the command environment as
// GitProtocolVersion. Every other variable is refused. If the client asked
// for a reply, it is told whether the variable was accepted.
//
// It returns (true, nil) so the request loop continues, or (false, err) when
// the payload cannot be decoded.
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
	var payload envRequest
	if decodeErr := ssh.Unmarshal(req.Payload, &payload); decodeErr != nil {
		slog.ErrorContext(ctx, "session: handleEnv: failed to unmarshal request", slog.String(fields.ErrorMessage, decodeErr.Error()))
		return false, decodeErr
	}

	// Anything other than the Git protocol variable is forbidden and simply refused.
	accepted := payload.Name == sshenv.GitProtocolEnv
	if accepted {
		s.gitProtocolVersion = payload.Value
	}

	if req.WantReply {
		if replyErr := req.Reply(accepted, []byte{}); replyErr != nil {
			slog.DebugContext(ctx, "session: handleEnv: Failed to reply", slog.String(fields.ErrorMessage, replyErr.Error()))
		}
	}

	slog.DebugContext(ctx, "session: handleEnv: processed",
		slog.Bool("accepted", accepted),
		slog.String("env_request", fmt.Sprintf("%+v", payload)),
	)

	return true, nil
}

// handleExec decodes the command carried by an "exec" request, runs it through
// handleShell and reports the resulting exit status to the client.
//
// If the payload cannot be decoded, the error is returned with ctx unchanged.
// In that case no reply and no exit status are sent.
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, error) {
	var payload execRequest
	if decodeErr := ssh.Unmarshal(req.Payload, &payload); decodeErr != nil {
		return ctx, decodeErr
	}
	s.execCmd = payload.Command

	resultCtx, exitCode, runErr := s.handleShell(ctx, req)
	s.exit(resultCtx, exitCode)

	return resultCtx, runErr
}

// handleShell acknowledges req, builds the gitlab-shell command for this
// session and runs it, using the SSH channel as stdin/stdout/stderr.
// The command is built from s.execCmd, which is empty for a plain "shell"
// request.
//
// It returns three values:
//   - a context carrying the command's log data, including the number of bytes
//     written to the client;
//   - the exit status to report to the client;
//   - the command's error, if any.
//
// Command construction failures are delegated to handleCommandError.
// Execution failures yield status 1.
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) {
	if req.WantReply {
		if err := req.Reply(true, []byte{}); err != nil {
			slog.DebugContext(ctx, "session: handleShell: Failed to reply", slog.String(fields.ErrorMessage, err.Error()))
		}
	}

	env := sshenv.Env{
		IsSSHConnection:    true,
		OriginalCommand:    s.execCmd,
		GitProtocolVersion: s.gitProtocolVersion,
		RemoteAddr:         s.remoteAddr,
		NamespacePath:      s.namespace,
	}

	// Count the bytes sent to the client so they can be reported in the log data.
	countingWriter := &readwriter.CountingWriter{W: s.channel}

	rw := &readwriter.ReadWriter{
		Out:    countingWriter,
		In:     s.channel,
		ErrOut: s.channel.Stderr(),
	}

	cmd, err := s.getCommand(env, rw)
	if err != nil {
		return s.handleCommandError(ctx, err)
	}

	cmdName := reflect.TypeOf(cmd).String()

	establishSessionDuration := time.Since(s.started).Seconds()
	slog.InfoContext(ctx, "session: handleShell: executing command",
		slog.Any("env", env),
		slog.String("command", cmdName),
		slog.Float64("established_session_duration_s", establishSessionDuration),
	)
	metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)

	execCtx, err := cmd.Execute(ctx)

	logData := extractLogDataFromContext(execCtx)
	logData.WrittenBytes = countingWriter.N

	ctxWithLogData := context.WithValue(ctx, logInfo{}, logData)

	if err != nil {
		// gRPC Internal errors are not echoed to the client; any other error's
		// message is written to the client's stderr.
		grpcStatus := grpcstatus.Convert(err)
		if grpcStatus.Code() != grpccodes.Internal {
			s.toStderr(ctx, "ERROR: %v\n", grpcStatus.Message())
		}

		// Return the log data on failure too, so the caller can still report
		// what the command did (e.g. bytes written before it failed).
		return ctxWithLogData, 1, err
	}

	slog.InfoContext(ctx, "session: handleShell: command executed successfully")

	return ctxWithLogData, 0, nil
}

// handleCommandError reports a failure to build the command on the client's
// stderr. It returns ctx unchanged, exit status 128 and the original error.
//
// A disallowedcommand.Error is reported as an unknown command, echoing
// s.execCmd. Any other error is reported as a parse failure with its message.
func (s *session) handleCommandError(ctx context.Context, err error) (context.Context, uint32, error) {
	const commandErrorStatus uint32 = 128

	switch {
	case errors.Is(err, disallowedcommand.Error):
		s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
	default:
		s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error())
	}

	return ctx, commandErrorStatus, err
}

// getCommand builds the gitlab-shell command for the session's identity.
//
// A Kerberos principal takes precedence over a username, which in turn takes
// precedence over the key ID. The key ID is used when neither of the others
// is set.
func (s *session) getCommand(env sshenv.Env, rw *readwriter.ReadWriter) (command.Command, error) {
	if principal := s.gitlabKrb5Principal; principal != "" {
		return shellCmd.NewWithKrb5Principal(principal, env, s.cfg, rw)
	}

	if username := s.gitlabUsername; username != "" {
		return shellCmd.NewWithUsername(username, env, s.cfg, rw)
	}

	return shellCmd.NewWithKey(s.gitlabKeyID, env, s.cfg, rw)
}

// toStderr formats a message the way fmt.Sprintf does and logs it at debug
// level. It then writes the message to the client over the channel's stderr
// stream via console.DisplayWarningMessage.
func (s *session) toStderr(ctx context.Context, format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	slog.DebugContext(ctx, "session: toStderr: output", slog.String("stderr", msg))
	console.DisplayWarningMessage(msg, s.channel.Stderr())
}

// exit reports status to the client. It logs the status, closes the write
// side of the channel, then sends an "exit-status" channel request that asks
// for no reply. Errors from both channel operations are discarded (best
// effort).
func (s *session) exit(ctx context.Context, status uint32) {
	slog.InfoContext(ctx, "session: exit: exiting", slog.Int("exit_status", int(status)))

	_ = s.channel.CloseWrite()

	payload := ssh.Marshal(exitStatusReq{ExitStatus: status})
	_, _ = s.channel.SendRequest("exit-status", false, payload)
}
