// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.18
// +build go1.18

package vulncheck

import (
	"context"
	"log"
	"os"
	"sort"
	"strings"

	"golang.org/x/tools/go/packages"
	gvc "golang.org/x/tools/gopls/internal/govulncheck"
	"golang.org/x/tools/internal/lsp/command"
	"golang.org/x/vuln/client"
	"golang.org/x/vuln/osv"
	"golang.org/x/vuln/vulncheck"
)

// init installs govulncheck as the implementation of the package-level
// Govulncheck hook. The variable is presumably declared in another file of
// this package (it is not declared in this one).
func init() {
	Govulncheck = govulncheck
}

// govulncheck runs an in-process vulnerability check over the packages
// matched by patterns, loaded with cfg, and returns the findings in a
// command.VulncheckResult.
//
// An empty patterns string is treated as ".". The vulnerability database
// sources come from findGOVULNDB, and the database client uses
// gvc.DefaultCache() as its HTTP cache.
//
// NOTE(review): patterns is forwarded to cmd.Run as a single pattern; it is
// not split on whitespace or commas here — confirm this matches callers.
func govulncheck(ctx context.Context, cfg *packages.Config, patterns string) (res command.VulncheckResult, _ error) {
	if patterns == "" {
		patterns = "."
	}

	dbClient, err := client.NewClient(findGOVULNDB(cfg), client.Options{HTTPCache: gvc.DefaultCache()})
	if err != nil {
		return res, err
	}

	c := cmd{Client: dbClient}
	vulns, err := c.Run(ctx, cfg, patterns)
	if err != nil {
		return res, err
	}

	res.Vuln = vulns
	// Return a literal nil on success rather than the (necessarily nil) err
	// variable, so the success path does not hinge on a reused variable.
	return res, nil
}

// findGOVULNDB returns the vulnerability database sources to query.
//
// GOVULNDB is a comma-separated list. It is looked up first in cfg.Env and
// then in the process environment. If neither provides it, the public
// database https://vuln.go.dev is used.
//
// When cfg.Env contains GOVULNDB more than once, the last entry wins. This
// matches how os/exec resolves duplicate keys, so the common pattern
// append(os.Environ(), "GOVULNDB=...") takes effect. An explicitly empty
// GOVULNDB in cfg.Env selects the default database instead of producing an
// unusable empty source.
func findGOVULNDB(cfg *packages.Config) []string {
	const (
		key       = "GOVULNDB="
		defaultDB = "https://vuln.go.dev"
	)
	// Scan backwards: later entries override earlier ones.
	for i := len(cfg.Env) - 1; i >= 0; i-- {
		kv := cfg.Env[i]
		if !strings.HasPrefix(kv, key) {
			continue
		}
		if val := kv[len(key):]; val != "" {
			return strings.Split(val, ",")
		}
		return []string{defaultDB}
	}
	if GOVULNDB := os.Getenv("GOVULNDB"); GOVULNDB != "" {
		return strings.Split(GOVULNDB, ",")
	}
	return []string{defaultDB}
}

// Local aliases for the command package's result types, so the rest of this
// package can refer to them unqualified.
type (
	Vuln       = command.Vuln
	CallStack  = command.CallStack
	StackEntry = command.StackEntry
)

// cmd is an in-process govulncheck command runner
// that uses the provided client.Client.
//
// Client is the vulnerability database client; Run hands it to
// vulncheck.Source through vulncheck.Config.
type cmd struct {
	Client client.Client
}

// Run runs govulncheck in process. It loads the packages matching patterns,
// analyzes them with vulncheck.Source using c.Client, and converts the
// results with toVulns.
//
// Packages are loaded with a shallow copy of cfg whose Mode is widened to
// include modules, names, files, imports, types, syntax, type info and
// deps. The caller's cfg is not modified.
//
// The result lists the vulnerabilities whose symbols are called, followed by
// those reported for modules where no vulnerable symbol is called.
// Progress messages are written to the default logger's writer.
func (c *cmd) Run(ctx context.Context, cfg *packages.Config, patterns ...string) ([]Vuln, error) {
	logger := log.New(log.Default().Writer(), "", 0)

	// Copy the config so that widening Mode does not leak back into the
	// caller's configuration.
	loadCfg := *cfg
	loadCfg.Mode |= packages.NeedModule | packages.NeedName | packages.NeedFiles |
		packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedTypes |
		packages.NeedTypesSizes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps

	logger.Println("loading packages...")
	loadedPkgs, err := gvc.LoadPackages(&loadCfg, patterns...)
	if err != nil {
		logger.Printf("package load failed: %v", err)
		return nil, err
	}

	logger.Printf("analyzing %d packages...\n", len(loadedPkgs))

	r, err := vulncheck.Source(ctx, loadedPkgs, &vulncheck.Config{Client: c.Client, SourceGoVersion: goVersion()})
	if err != nil {
		return nil, err
	}

	logger.Printf("selecting affecting vulnerabilities from %d findings...\n", len(r.Vulns))
	// Compute the not-called set before r.Vulns is narrowed to called ones.
	unaffectedMods := filterUnaffected(r.Vulns)
	r.Vulns = filterCalled(r)

	logger.Printf("found %d vulnerabilities.\n", len(r.Vulns))
	callInfo := gvc.GetCallInfo(r, loadedPkgs)
	// TODO: add import graphs.
	return toVulns(callInfo, unaffectedMods)
}

// filterCalled selects the entries of r.Vulns that have a non-zero CallSink,
// i.e. the vulnerabilities whose symbols are actually called. r.Vulns itself
// is left untouched. The result is nil when nothing qualifies.
func filterCalled(r *vulncheck.Result) []*vulncheck.Vuln {
	var called []*vulncheck.Vuln
	for i := range r.Vulns {
		if r.Vulns[i].CallSink == 0 {
			continue
		}
		called = append(called, r.Vulns[i])
	}
	return called
}

// filterUnaffected gathers the OSV entries of vulnerabilities none of whose
// symbols are called, grouped by module path.
//
// One OSV ID may occur in several vulncheck.Vuln values, one per symbol. If
// any of them has a non-zero CallSink, the whole ID counts as called and is
// excluded. Within each module, entries are deduplicated by ID (the last
// *osv.Entry seen for an ID is kept) and sorted by ID.
func filterUnaffected(vulns []*vulncheck.Vuln) map[string][]*osv.Entry {
	calledIDs := make(map[string]bool)
	for _, v := range vulns {
		if v.CallSink != 0 {
			calledIDs[v.OSV.ID] = true
		}
	}

	byModule := make(map[string]map[string]*osv.Entry)
	for _, v := range vulns {
		if calledIDs[v.OSV.ID] {
			continue
		}
		ids := byModule[v.ModPath]
		if ids == nil {
			ids = make(map[string]*osv.Entry)
			byModule[v.ModPath] = ids
		}
		ids[v.OSV.ID] = v.OSV
	}

	result := make(map[string][]*osv.Entry, len(byModule))
	for mod, ids := range byModule {
		entries := make([]*osv.Entry, 0, len(ids))
		for _, e := range ids {
			entries = append(entries, e)
		}
		sort.Slice(entries, func(i, j int) bool { return entries[i].ID < entries[j].ID })
		result[mod] = entries
	}
	return result
}

// fixed returns gvc.LatestFixed(v.Affected) with a leading "v" added when it
// is missing. An empty result from gvc.LatestFixed is returned unchanged.
func fixed(v *osv.Entry) string {
	latest := gvc.LatestFixed(v.Affected)
	if latest == "" || strings.HasPrefix(latest, "v") {
		return latest
	}
	return "v" + latest
}

// toVulns converts analysis results into the Vuln values reported to the
// client.
//
// For each group in ci.VulnGroups, it emits one Vuln built from the group's
// first member. That Vuln carries the first call stack, and its summary, of
// every member that has call stacks.
//
// It then appends one Vuln per OSV entry in unaffectedMods. Those entries
// carry no symbol, call stacks, or current version.
//
// Unaffected modules are visited in sorted module-path order. Go map
// iteration order is randomized, so iterating the map directly would make
// the output order change from run to run. The error result is always nil.
func toVulns(ci *gvc.CallInfo, unaffectedMods map[string][]*osv.Entry) ([]Vuln, error) {
	var vulns []Vuln

	for _, vg := range ci.VulnGroups {
		if len(vg) == 0 {
			// Defensive: an empty group has no entry to describe, and vg[0]
			// would panic.
			continue
		}
		v0 := vg[0]
		vuln := Vuln{
			ID:             v0.OSV.ID,
			PkgPath:        v0.PkgPath,
			CurrentVersion: ci.ModuleVersions[v0.ModPath],
			FixedVersion:   fixed(v0.OSV),
			Details:        v0.OSV.Details,

			Aliases: v0.OSV.Aliases,
			Symbol:  v0.Symbol,
			ModPath: v0.ModPath,
			URL:     href(v0.OSV),
		}

		// Keep first call stack for each vuln.
		for _, v := range vg {
			if css := ci.CallStacks[v]; len(css) > 0 {
				vuln.CallStacks = append(vuln.CallStacks, toCallStack(css[0]))
				vuln.CallStackSummaries = append(vuln.CallStackSummaries, gvc.SummarizeCallStack(css[0], ci.TopPackages, v.PkgPath))
			}
		}
		vulns = append(vulns, vuln)
	}

	mods := make([]string, 0, len(unaffectedMods))
	for m := range unaffectedMods {
		mods = append(mods, m)
	}
	sort.Strings(mods)
	for _, m := range mods {
		for _, v0 := range unaffectedMods[m] {
			vulns = append(vulns, Vuln{
				ID:             v0.ID,
				Details:        v0.Details,
				Aliases:        v0.Aliases,
				ModPath:        m,
				URL:            href(v0),
				CurrentVersion: "",
				FixedVersion:   fixed(v0),
			})
		}
	}
	return vulns, nil
}
