// Copyright (c) 2018-2019, Sylabs Inc. All rights reserved.
// This software is licensed under a 3-clause BSD license. Please consult the
// LICENSE.md file distributed with the sources of this project regarding your
// rights to use or distribute this software.

package signing

import (
	"bytes"
	"context"
	"crypto/sha512"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"os"

	"github.com/fatih/color"
	"github.com/sylabs/sif/pkg/sif"
	"github.com/sylabs/singularity/internal/pkg/sylog"
	"github.com/sylabs/singularity/pkg/sypgp"
	"golang.org/x/crypto/openpgp"
	"golang.org/x/crypto/openpgp/clearsign"
)

// Sentinel errors used by the signing and verification code.
var (
	// ErrVerificationFail is returned by Verify when at least one signature
	// check did not succeed.
	ErrVerificationFail = errors.New("verification failed")

	// errNotFound reports that the signer's key could not be obtained from
	// the local keyring or from the remote keystore.
	errNotFound = errors.New("key does not exist in local, or remote keystore")

	// errNotFoundLocal reports that the signer's key is not in the local
	// keyring and no remote lookup was attempted.
	errNotFoundLocal = errors.New("key not in local keyring")
)

// Key is the JSON wrapper for one verified signature: with no struct tags,
// encoding/json emits the embedded details under a "Signer" object.
type Key struct {
	// Signer holds the signing key details and the verification outcome.
	Signer KeyEntity
}

// KeyEntity holds all the key info, used for json output. One value is built
// by makeKeyEntity for each signature processed by Verify.
type KeyEntity struct {
	// Partition names what was verified: a datatype label from datatypeStr,
	// or "group: <id>" when a group was verified.
	Partition   string
	// Name is the signer identity name, or "unknown" when none was resolved.
	Name        string
	// Fingerprint is the entity string read from the signature descriptor.
	Fingerprint string
	// KeyLocal is true when the signer's key was found in the local keyring.
	KeyLocal    bool
	// KeyCheck is false when the signature block could not be decoded, and
	// true for every signature block that was decoded.
	KeyCheck    bool
	// DataCheck is true when the recomputed hash matched the signed hash.
	DataCheck   bool
}

// KeyList is a list of one or more keys. It is the top-level document that
// Verify marshals when JSON output is requested.
type KeyList struct {
	// Signatures is the number of signature links selected for verification.
	Signatures int
	// SignerKeys has one entry per signature for which a result was recorded.
	SignerKeys []*Key
}

// signatureLink ties a signature descriptor to the data it signs. All values
// are indexes into FileImage.DescrArr, not descriptor IDs. Links for a single
// descriptor fill dataIndex; links for a group fill groupIndex instead.
type signatureLink struct {
	sigIndex   int   // The index of the descriptor with the signature.
	dataIndex  int   // The index of the descriptor of the signed data.
	groupIndex []int // The indexes of all descriptors covered by a group signature.
}

// computeHashStr feeds the data of every descriptor in descr, in order, into
// one SHA-384 digest and returns the text that is stored in (and later
// compared against) the signature block: the header "SIFHASH:", a newline,
// and the lowercase hex form of the digest.
func computeHashStr(fimg *sif.FileImage, descr []*sif.Descriptor) string {
	digest := sha512.New384()
	for i := range descr {
		// hash.Hash documents that Write never returns an error.
		digest.Write(descr[i].GetData(fimg))
	}

	return fmt.Sprintf("SIFHASH:\n%x", digest.Sum(nil))
}

// sifAddSignature appends signature to fimg as a new signature data object.
// groupid and link are copied into the new descriptor, and fingerprint is
// recorded hex-encoded, together with the SHA-384 hash type, as the
// descriptor's signing extra data. Any error from preparing the descriptor or
// from adding the object is returned unchanged.
func sifAddSignature(fimg *sif.FileImage, groupid, link uint32, fingerprint [20]byte, signature []byte) error {
	// Describe the signature object that is about to be created.
	input := sif.DescriptorInput{
		Datatype: sif.DataSignature,
		Groupid:  groupid,
		Link:     link,
		Fname:    "part-signature",
		Size:     int64(binary.Size(signature)),
		Data:     signature,
	}

	// Record which hash was signed and who signed it.
	if err := input.SetSignExtra(sif.HashSHA384, hex.EncodeToString(fingerprint[:])); err != nil {
		return err
	}

	// Write the new signature data object to the SIF file.
	return fimg.AddObject(input)
}

// datatypeStr returns the short label used for a SIF datatype in log and
// verification output. It is a copy of the equivalent helper in sylabs/sif.
// Datatypes without a label here yield "Unknown data-type".
func datatypeStr(dtype sif.Datatype) string {
	label := "Unknown data-type"

	switch dtype {
	case sif.DataDeffile:
		label = "Def.FILE"
	case sif.DataEnvVar:
		label = "Env.Vars"
	case sif.DataLabels:
		label = "JSON.Labels"
	case sif.DataPartition:
		label = "FS"
	case sif.DataSignature:
		label = "Signature"
	case sif.DataGenericJSON:
		label = "JSON.Generic"
	case sif.DataGeneric:
		label = "Generic/Raw"
	case sif.DataCryptoMessage:
		label = "CryptoMessage"
	}

	return label
}

// getDataPartitionToSign looks up the descriptors of type dataType in fimg so
// that they can be signed. Finding none is not an error: sif.ErrNotFound is
// tolerated and whatever the lookup produced is returned. Any other lookup
// failure is returned as an error that names the datatype being searched for.
func getDataPartitionToSign(fimg *sif.FileImage, dataType sif.Datatype) ([]*sif.Descriptor, error) {
	typeName := datatypeStr(dataType)
	sylog.Debugf("Looking for: %s partition to sign...", typeName)

	// We are using ID 0 (skipping ID), because we are looking for all Datatypes,
	// and ID's will limit the search.
	data, _, err := fimg.GetLinkedDescrsByType(uint32(0), dataType)
	if err != nil && err != sif.ErrNotFound {
		// Report the datatype actually searched for; this helper is called
		// for every signable datatype, not only for definition files.
		return nil, fmt.Errorf("failed to get descr for %s: %s", typeName, err)
	}
	sylog.Debugf("Found %d partitions", len(data))

	return data, nil
}

// descrToSign selects the descriptors that Sign has to cover. The arguments
// are evaluated in this order of precedence:
//
//   - signAll: the primary system partition followed by every descriptor of a
//     signable datatype;
//   - isGroup: every descriptor that belongs to group id;
//   - id != 0: the single descriptor with that ID;
//   - otherwise: the primary system partition only.
//
// An error is returned when the requested descriptor(s) cannot be found.
func descrToSign(fimg *sif.FileImage, id uint32, isGroup, signAll bool) ([]*sif.Descriptor, error) {
	switch {
	case signAll:
		primary, _, err := fimg.GetPartPrimSys()
		if err != nil {
			return nil, fmt.Errorf("no primary partition found")
		}
		selected := []*sif.Descriptor{primary}

		// Every datatype that can be signed; DataSignature is left out on
		// purpose since there is no need to sign a signature.
		signable := []sif.Datatype{
			sif.DataDeffile, sif.DataEnvVar,
			sif.DataLabels, sif.DataGenericJSON,
			sif.DataGeneric, sif.DataCryptoMessage,
		}
		for _, datatype := range signable {
			found, err := getDataPartitionToSign(fimg, datatype)
			if err != nil {
				return nil, err
			}
			selected = append(selected, found...)
		}
		return selected, nil

	case isGroup:
		members, _, err := fimg.GetFromDescr(sif.Descriptor{
			Groupid: id | sif.DescrGroupMask,
		})
		if err != nil {
			return nil, fmt.Errorf("no descriptors found for groupid %d", id)
		}
		return members, nil

	case id != 0:
		target, _, err := fimg.GetFromDescrID(id)
		if err != nil {
			return nil, fmt.Errorf("no descriptor found for id %d", id)
		}
		return []*sif.Descriptor{target}, nil
	}

	// Nothing specific requested: fall back to the primary system partition.
	primary, _, err := fimg.GetPartPrimSys()
	if err != nil {
		return nil, fmt.Errorf("no primary partition found")
	}
	return []*sif.Descriptor{primary}, nil
}

// Sign takes the path of a container and adds OpenPGP clearsign signature
// block(s) for the selected data object(s), using a private key from the
// default keyring location.
//
// Selection (see descrToSign): signAll signs the primary system partition and
// every signable data object, each with its own signature; isGroup signs all
// descriptors of group id with one signature; a non-zero id signs that one
// descriptor; otherwise only the primary system partition is signed.
//
// keyIdx selects the private key by its position in the keyring. -1 means no
// index was given: the only key is used, or, when several keys exist, the
// choice is delegated to sypgp.SelectPrivKey.
//
// An error is returned when the keyring cannot be loaded or is empty, keyIdx
// is out of range, the key cannot be decrypted, the container cannot be
// loaded, nothing signable is found, or a signature cannot be built or added.
func Sign(cpath string, id uint32, isGroup, signAll bool, keyIdx int) error {
	keyring := sypgp.NewHandle("")

	// Load a private key usable for signing
	elist, err := keyring.LoadPrivKeyring()
	if err != nil {
		return fmt.Errorf("could not load private keyring: %s", err)
	}
	// Reject an empty keyring as well as a nil one; otherwise the elist[0]
	// fallback below would index out of range.
	if len(elist) == 0 {
		return fmt.Errorf("no private keys in keyring. use 'key newpair' to generate a key, or 'key import' to import a private key from gpg")
	}

	var entity *openpgp.Entity
	if keyIdx != -1 { // -k <idx> has been specified
		if keyIdx >= 0 && keyIdx < len(elist) {
			entity = elist[keyIdx]
		} else {
			return fmt.Errorf("specified (-k, --keyidx) key index out of range")
		}
	} else if len(elist) > 1 {
		entity, err = sypgp.SelectPrivKey(elist)
		if err != nil {
			return fmt.Errorf("failed while reading selection: %s", err)
		}
	} else {
		entity = elist[0]
	}

	// Decrypt key if needed
	if entity.PrivateKey.Encrypted {
		sylog.Debugf("Decrypting key...")
		if err = sypgp.DecryptKey(entity, ""); err != nil {
			return fmt.Errorf("could not decrypt private key, wrong password?")
		}
	}

	// load the container
	fimg, err := sif.LoadContainer(cpath, false)
	if err != nil {
		return fmt.Errorf("failed to load sif container file: %s", err)
	}
	defer fimg.UnloadContainer()

	// figure out which descriptor has data to sign
	descr, err := descrToSign(&fimg, id, isGroup, signAll)
	if err != nil {
		return fmt.Errorf("unable to find a signable partition: %s", err)
	}

	for _, de := range descr {
		sylog.Debugf("Signing %s partition...", datatypeStr(de.Datatype))

		sifhash := ""
		if isGroup {
			// If we are signing a group, then include all the descriptors.
			sifhash = computeHashStr(&fimg, descr)
		} else {
			// Otherwise, just sign one partition at a time.
			sifhash = computeHashStr(&fimg, []*sif.Descriptor{de})
		}
		sylog.Debugf("Signing hash: %s\n", sifhash)

		// create an ascii armored signature block
		var signedmsg bytes.Buffer
		plaintext, err := clearsign.Encode(&signedmsg, entity.PrivateKey, nil)
		if err != nil {
			return fmt.Errorf("could not build a signature block: %s", err)
		}
		_, err = plaintext.Write([]byte(sifhash))
		if err != nil {
			return fmt.Errorf("failed writing hash value to signature block: %s", err)
		}
		// Closing the writer is what emits the signature into signedmsg.
		if err = plaintext.Close(); err != nil {
			return fmt.Errorf("I/O error while wrapping up signature block: %s", err)
		}

		// finally add the signature block (for descr) as a new SIF data object
		var groupid, link uint32
		if isGroup {
			// A group signature belongs to no group itself and links to the
			// group ID of the signed descriptors.
			groupid = sif.DescrUnusedGroup
			link = de.Groupid
		} else {
			// A single-object signature joins the object's group and links
			// to the object's descriptor ID.
			groupid = de.Groupid
			link = de.ID
		}
		err = sifAddSignature(&fimg, groupid, link, entity.PrimaryKey.Fingerprint, signedmsg.Bytes())
		if err != nil {
			return fmt.Errorf("failed adding signature block to SIF container file: %s", err)
		}

		// If we are signing a group, then only add one signature for all
		// the group partitions.
		if isGroup {
			break
		}
	}

	return nil
}

// getSigsAllPart builds one signatureLink per (data descriptor, signature)
// pair for every used, non-signature descriptor in fimg. A descriptor for
// which no signature can be looked up only produces a warning. An error is
// returned when there is no primary system partition or when no signature was
// found at all.
func getSigsAllPart(fimg *sif.FileImage) ([]signatureLink, error) {
	// Ensure theres a primary partition.
	if _, _, err := fimg.GetPartPrimSys(); err != nil {
		return nil, fmt.Errorf("no primary partition found")
	}

	var links []signatureLink

	for didx := range fimg.DescrArr {
		d := &fimg.DescrArr[didx]

		// Skip free slots, and signatures themselves: there is no need to
		// verify a signature.
		if !d.Used || d.Datatype == sif.DataSignature {
			continue
		}

		_, sigIdxs, err := fimg.GetLinkedDescrsByType(d.ID, sif.DataSignature)
		if err != nil {
			// If a partition is not signed, print a warning.
			sylog.Warningf("Missing signature for SIF descriptor %d (%s)", didx+1, datatypeStr(d.Datatype))
			continue
		}

		for _, sidx := range sigIdxs {
			links = append(links, signatureLink{sigIndex: sidx, dataIndex: didx})
		}
	}

	if len(links) == 0 {
		return nil, fmt.Errorf("no signature(s) found in image")
	}

	return links, nil
}

// getSigsDescr returns one signatureLink for every signature block linked to
// the descriptor with the given ID. An error is returned when no descriptor
// has that ID or when no signature is linked to it.
func getSigsDescr(fimg *sif.FileImage, id uint32) ([]signatureLink, error) {
	// Keep the index reported by the lookup rather than assuming that the
	// descriptor sits at position id-1 in the descriptor table.
	_, dataIdx, err := fimg.GetFromDescrID(id)
	if err != nil {
		return nil, fmt.Errorf("no descriptor found for id %d", id)
	}

	_, sigIdxs, err := fimg.GetLinkedDescrsByType(id, sif.DataSignature)
	if err != nil {
		return nil, fmt.Errorf("no signatures found for id %d", id)
	}

	sigLink := make([]signatureLink, len(sigIdxs))
	for i, s := range sigIdxs {
		sigLink[i].sigIndex = s
		sigLink[i].dataIndex = dataIdx
	}

	return sigLink, nil
}

// getSigsGroup returns one signatureLink for every signature block that links
// to group id. Each link carries its own copy of the indexes of all
// descriptors that belong to the group. An error is returned when the group
// has no descriptors or no signature links to it.
func getSigsGroup(fimg *sif.FileImage, id uint32) ([]signatureLink, error) {
	// find descriptors that are part of a signing group.
	search := sif.Descriptor{
		Groupid: id | sif.DescrGroupMask,
	}
	_, dindex, err := fimg.GetFromDescr(search)
	if err != nil {
		return nil, fmt.Errorf("no descriptors found for groupid %v", id)
	}

	// Find signature blocks pointing to specified group.
	search = sif.Descriptor{
		Datatype: sif.DataSignature,
		Link:     id | sif.DescrGroupMask,
	}
	_, sindex, err := fimg.GetFromDescr(search)
	if err != nil {
		return nil, fmt.Errorf("no signatures found for groupid %v", id)
	}

	// One link per signature. The slice must not be given len(dindex) as its
	// capacity: make panics when the length exceeds the capacity, which is
	// the case whenever a group has more signatures than members.
	sigLink := make([]signatureLink, len(sindex))

	for i, s := range sindex {
		sigLink[i].sigIndex = s
		sigLink[i].groupIndex = append(sigLink[i].groupIndex, dindex...)
	}

	return sigLink, nil
}

// getSigsForSelection dispatches to the signature lookup that matches the
// selection, in this order of precedence: verifyAll (all partitions), isGroup
// (group id), id != 0 (single descriptor id), and finally the primary system
// partition.
func getSigsForSelection(fimg *sif.FileImage, id uint32, isGroup, verifyAll bool) ([]signatureLink, error) {
	switch {
	case verifyAll:
		return getSigsAllPart(fimg)
	case isGroup:
		return getSigsGroup(fimg, id)
	case id != 0:
		return getSigsDescr(fimg, id)
	default:
		return getSigsLinkPrimPart(fimg)
	}
}

// IsSigned takes a container path (cpath) and verifies the signatures of that
// container's primary system partition; remote key lookup is allowed. It
// returns true when verification succeeds, and false together with an error
// when it does not, e.g. because the container is not signed or is signed by
// an unknown signer. On success it also logs whether every signing key was
// found in the local keyring.
func IsSigned(ctx context.Context, cpath, keyServerURI string, authToken string) (bool, error) {
	_, noLocalKey, err := Verify(ctx, cpath, keyServerURI, uint32(0), false, false, authToken, false, false)
	if err != nil {
		// Keep the cause in the message so callers can tell why
		// verification failed.
		return false, fmt.Errorf("unable to verify container: %s: %s", cpath, err)
	}
	if noLocalKey {
		sylog.Warningf("Container might not be trusted; run 'singularity verify %s' to show who signed it", cpath)
	} else {
		sylog.Infof("Container is trusted - run 'singularity key list' to list your trusted keys")
	}
	return true, nil
}

// Verify loads the SIF container at cpath and checks every signature block
// selected by id, isGroup and verifyAll (see getSigsForSelection). For each
// signature it recomputes the hash of the signed data, resolves the signer
// through getSignerIdentity (local public keyring first, then the key service
// at keyServiceURI unless localVerify is true), and compares the recomputed
// hash with the plaintext carried in the signature block.
//
// It returns:
//   - a human-readable report, or an indented JSON document when jsonVerify
//     is true;
//   - true if at least one signer's key was not found in the local keyring
//     and was resolved remotely instead;
//   - ErrVerificationFail if any signature failed any check, or another
//     error (with an empty report) if the container could not be loaded or
//     no signature blocks were found.
func Verify(ctx context.Context, cpath, keyServiceURI string, id uint32, isGroup, verifyAll bool, authToken string, localVerify, jsonVerify bool) (string, bool, error) {
	keyring := sypgp.NewHandle("")

	notLocalKey := false

	fimg, err := sif.LoadContainer(cpath, true)
	if err != nil {
		return "", false, fmt.Errorf("failed to load SIF container file: %s", err)
	}
	defer fimg.UnloadContainer()

	// Get all signature blocks (signatures) for ID/GroupID selected (descr) from SIF file.
	sigsLink, err := getSigsForSelection(&fimg, id, isGroup, verifyAll)
	if err != nil {
		return "", false, fmt.Errorf("error while searching for signature blocks: %s", err)
	}

	// Setup some colors.
	green := color.New(color.FgGreen).SprintFunc()
	yellow := color.New(color.FgYellow).SprintFunc()
	red := color.New(color.FgRed).SprintFunc()

	// fail is sticky: once any signature fails a check, the overall result
	// is ErrVerificationFail, but the remaining signatures are still reported.
	var fail bool
	var errRet error
	// author accumulates the human-readable report.
	var author string

	var keySigner *Key
	keyEntityList := KeyList{}

	// Note that the count is the number of signature links, not of distinct keys.
	author += fmt.Sprintf("Container is signed by %d key(s):\n\n", len(sigsLink))

	// Loop through the signature link, and find the signatures and
	// corresponding partition.
	for _, part := range sigsLink {
		// NOTE(review): getSigsForSelection gives verifyAll precedence over
		// isGroup, but the hashing below keys off isGroup alone; confirm
		// that callers never set both.
		sifhash := ""
		if isGroup {
			// If we are verifying a group, then collect all
			// the group partitions.
			var groupPart []*sif.Descriptor

			for _, d := range part.groupIndex {
				groupPart = append(groupPart, &fimg.DescrArr[d])
			}
			sifhash = computeHashStr(&fimg, groupPart)
		} else {
			sifhash = computeHashStr(&fimg, []*sif.Descriptor{&fimg.DescrArr[part.dataIndex]})
		}
		sylog.Debugf("Verifying hash: %s\n", sifhash)

		dataCheck := true
		// get the entity fingerprint for the signature block
		fingerprint, err := fimg.DescrArr[part.sigIndex].GetEntityString()
		if err != nil {
			// Without a fingerprint nothing is added to the report or to the
			// JSON key list for this signature; it only marks the run as failed.
			sylog.Errorf("could not get the signing entity fingerprint from partition ID: %d: %s", part.sigIndex, err)
			fail = true
			continue
		}

		// Label of what is being verified: the group, or the datatype of the
		// single signed descriptor.
		verifyPartition := ""
		if isGroup {
			verifyPartition = fmt.Sprintf("group: %d", id)
		} else {
			verifyPartition = datatypeStr(fimg.DescrArr[part.dataIndex].Datatype)
		}
		author += fmt.Sprintf("Verifying partition: %s:\n", verifyPartition)
		author += fingerprint + "\n"

		// Extract hash string from signature block
		data := fimg.DescrArr[part.sigIndex].GetData(&fimg)
		block, _ := clearsign.Decode(data)
		if block == nil {
			// The signature data is not a decodable clearsign block: record
			// the entry with every check false and move on.
			sylog.Verbosef("%s signature key (%s) corrupted, unable to read data", red("error:"), fingerprint)
			author += fmt.Sprintf("%-18s Signature corrupted, unable to read data\n\n", red("[FAIL]"))

			keySigner = makeKeyEntity("", verifyPartition, fingerprint, false, false, false)
			keyEntityList.SignerKeys = append(keyEntityList.SignerKeys, keySigner)

			fail = true
			continue
		}

		// (1) try to get identity of signer
		i, local, err := getSignerIdentity(ctx, keyring, &fimg.DescrArr[part.sigIndex], block, data, fingerprint, keyServiceURI, authToken, localVerify)
		if err != nil {
			// use [MISSING] if we get an error we expect
			if err == errNotFound || err == errNotFoundLocal {
				author += fmt.Sprintf("%-18s %s\n", red("[MISSING]"), err)
			} else {
				author += fmt.Sprintf("%-18s %s\n", red("[FAIL]"), err)
			}
			fail = true
		} else {
			prefix := green("[LOCAL]")
			if !local {
				// The key came from the remote keystore; remember it so the
				// caller can warn that it is not in the local keyring.
				prefix = yellow("[REMOTE]")
				notLocalKey = true
			}

			author += fmt.Sprintf("%-18s %s\n", prefix, i)
		}

		// (2) Verify data integrity by comparing hashes
		// Trailing newlines of the signed plaintext are ignored in the comparison.
		if !bytes.Equal(bytes.TrimRight(block.Plaintext, "\n"), []byte(sifhash)) {
			sylog.Verbosef("%s key (%s) hash differs, data may be corrupted", red("error:"), fingerprint)
			author += fmt.Sprintf("%-18s system partition hash differs, data may be corrupted\n", red("[FAIL]"))
			dataCheck = false
			fail = true
		} else {
			author += fmt.Sprintf("%-18s Data integrity verified\n", green("[OK]"))
		}
		author += fmt.Sprintf("\n")

		// The block was decodable, so the key-check flag is true here even
		// when the signer could not be identified above.
		keySigner = makeKeyEntity(i, verifyPartition, fingerprint, local, true, dataCheck)
		keyEntityList.SignerKeys = append(keyEntityList.SignerKeys, keySigner)

	}

	keyEntityList.Signatures = len(sigsLink)

	// JSON output replaces the human-readable report entirely.
	if jsonVerify {
		jsonData, err := json.MarshalIndent(keyEntityList, "", "  ")
		if err != nil {
			return "", notLocalKey, fmt.Errorf("unable to parse json: %s", err)
		}
		author = string(jsonData) + "\n"
	}

	if fail {
		errRet = ErrVerificationFail
	}

	return author, notLocalKey, errRet
}

// makeKeyEntity builds the JSON record for one verified signature. An empty
// name is replaced by "unknown". Despite its name, the corrupted argument is
// stored unchanged in KeyCheck: Verify passes false when the signature block
// could not be decoded and true otherwise.
func makeKeyEntity(name, partition, fingerprint string, local, corrupted, dataCheck bool) *Key {
	entity := KeyEntity{
		Partition:   partition,
		Name:        name,
		Fingerprint: fingerprint,
		KeyLocal:    local,
		KeyCheck:    corrupted,
		DataCheck:   dataCheck,
	}
	if entity.Name == "" {
		entity.Name = "unknown"
	}

	return &Key{Signer: entity}
}

// getFirstIdentity returns the Name of the first identity that ranging over
// e.Identities yields, or "" when the entity has no identities.
// NOTE(review): Identities is a map in golang.org/x/crypto/openpgp, so with
// several identities which one comes "first" is not deterministic.
func getFirstIdentity(e *openpgp.Entity) string {
	name := ""
	for _, ident := range e.Identities {
		name = ident.Name
		break
	}
	return name
}

// getSignerIdentity checks the clearsign block against known public keys and
// returns the name of the signer.
//
// The local public keyring is tried first. If that fails and local is false,
// the key identified by fingerprint is fetched from keyServiceURI (using
// authToken) and the check is repeated with the fetched key(s). data must be
// the raw signature data that block was decoded from; it is decoded again
// before the second check. The descriptor v is currently not used.
//
// It returns the signer's identity name, whether the key was found in the
// local keyring, and an error: errNotFoundLocal when the key is not local and
// local is true, errNotFound when fetching the key fails, or the error from
// the signature check against the fetched key(s).
func getSignerIdentity(ctx context.Context, keyring *sypgp.Handle, v *sif.Descriptor, block *clearsign.Block, data []byte, fingerprint, keyServiceURI, authToken string, local bool) (string, bool, error) {
	// Start with the public keys cached locally.
	localKeys, err := keyring.LoadPubKeyring()
	if err != nil {
		return "", false, fmt.Errorf("could not load public keyring: %s", err)
	}

	// A successful check against the local keyring settles it.
	if signer, err := openpgp.CheckDetachedSignature(localKeys, bytes.NewBuffer(block.Bytes), block.ArmoredSignature.Body); err == nil {
		return getFirstIdentity(signer), true, nil
	}

	// The check most likely failed because the key is not in the local
	// keyring. When restricted to local keys, give up instead of downloading.
	if local {
		return "", false, errNotFoundLocal
	}

	// The first check consumed the reader inside block, so decode the
	// signature again to get a fresh one.
	fresh, _ := clearsign.Decode(data)
	if fresh == nil {
		return "", false, fmt.Errorf("failed to parse signature block")
	}

	// Download the key.
	sylog.Verbosef("Key not found in local keyring, checking remote keystore: %s\n", fingerprint[32:])
	remoteKeys, err := sypgp.FetchPubkey(ctx, http.DefaultClient, fingerprint, keyServiceURI, authToken, true)
	if err != nil {
		return "", false, errNotFound
	}

	sylog.Verbosef("Found key in remote keystore: %s", fingerprint[32:])

	// Repeat the check with the key(s) that were just fetched.
	signer, err := openpgp.CheckDetachedSignature(remoteKeys, bytes.NewBuffer(fresh.Bytes), fresh.ArmoredSignature.Body)
	if err != nil {
		return "", false, err
	}

	return getFirstIdentity(signer), false, nil
}

// getSigsLinkPrimPart is just like getSigsPrimPart, but returns a []signatureLink
// instead of descriptors: one link per signature block linked to the primary
// system partition. An error is returned when there is no primary system
// partition or when no signature is linked to it.
func getSigsLinkPrimPart(fimg *sif.FileImage) ([]signatureLink, error) {
	primary, primaryIdx, err := fimg.GetPartPrimSys()
	if err != nil {
		return nil, fmt.Errorf("no primary partition found")
	}

	// Look signatures up by the partition's own descriptor ID, as
	// getSigsPrimPart does, rather than deriving an ID from its table index.
	_, sigIdx, err := fimg.GetLinkedDescrsByType(primary.ID, sif.DataSignature)
	if err != nil {
		return nil, fmt.Errorf("no signatures found for system partition")
	}

	sigLink := make([]signatureLink, len(sigIdx))

	for i, s := range sigIdx {
		sigLink[i].sigIndex = s
		sigLink[i].dataIndex = primaryIdx
	}

	return sigLink, nil
}

// getSigsPrimPart returns the signature descriptors linked to the primary
// system partition (sigs), together with a one-element slice holding that
// partition's descriptor (descr). Both are nil when the partition or its
// signatures cannot be found.
func getSigsPrimPart(fimg *sif.FileImage) (sigs []*sif.Descriptor, descr []*sif.Descriptor, err error) {
	primary, _, err := fimg.GetPartPrimSys()
	if err != nil {
		return nil, nil, fmt.Errorf("no primary partition found")
	}

	sigs, _, err = fimg.GetLinkedDescrsByType(primary.ID, sif.DataSignature)
	if err != nil {
		return nil, nil, fmt.Errorf("no signatures found for system partition")
	}

	return sigs, []*sif.Descriptor{primary}, nil
}

// getSignEntities returns the entity (fingerprint) string of every signature
// block linked to the primary system partition of fimg, in lookup order. The
// first lookup or extraction error aborts and is returned as is.
func getSignEntities(fimg *sif.FileImage) ([]string, error) {
	sigDescrs, _, err := getSigsPrimPart(fimg)
	if err != nil {
		return nil, err
	}

	fingerprints := make([]string, len(sigDescrs))
	for i, sig := range sigDescrs {
		if fingerprints[i], err = sig.GetEntityString(); err != nil {
			return nil, err
		}
	}

	return fingerprints, nil
}

// GetSignEntities loads the container at cpath read-only and returns the
// signing entities (fingerprints) of the signatures linked to its primary
// system partition. The image is unloaded on every path, including when
// loading fails.
// NOTE(review): unloading after a failed load presumably releases whatever
// LoadContainer had already opened; confirm against the sif package.
func GetSignEntities(cpath string) ([]string, error) {
	fimg, err := sif.LoadContainer(cpath, true)
	// Registered before the error check so that both outcomes unload once.
	defer fimg.UnloadContainer()
	if err != nil {
		return nil, err
	}

	return getSignEntities(&fimg)
}

// GetSignEntitiesFp is the variant of GetSignEntities for an already opened
// file: it loads the container from fp read-only and returns the signing
// entities (fingerprints) of the signatures linked to its primary system
// partition. The image is not unloaded here.
// NOTE(review): presumably because fp stays owned by the caller; confirm.
func GetSignEntitiesFp(fp *os.File) ([]string, error) {
	img, loadErr := sif.LoadContainerFp(fp, true)
	if loadErr != nil {
		return nil, loadErr
	}

	return getSignEntities(&img)
}
