From b5ecf9ad69530b136f7c988b2564804e42fd7288 Mon Sep 17 00:00:00 2001 From: Fernandez Ludovic Date: Thu, 18 Apr 2024 17:50:14 +0200 Subject: [PATCH] chore: refactor main function --- cmd/misspell/main.go | 354 ++++++++++++++++++++++++------------------- 1 file changed, 201 insertions(+), 153 deletions(-) diff --git a/cmd/misspell/main.go b/cmd/misspell/main.go index 0a47444..2d81170 100644 --- a/cmd/misspell/main.go +++ b/cmd/misspell/main.go @@ -18,16 +18,6 @@ import ( "github.com/golangci/misspell" ) -var ( - defaultWrite *template.Template - defaultRead *template.Template - - stdout *log.Logger - debug *log.Logger - - version = "dev" -) - const ( outputFormatCSV = "csv" outputFormatSQLite = "sqlite" @@ -50,58 +40,17 @@ CREATE TABLE misspell( sqliteFooter = "COMMIT;" ) -func worker(writeit bool, r *misspell.Replacer, mode string, files <-chan string, results chan<- int) { - count := 0 - for filename := range files { - orig, err := misspell.ReadTextFile(filename) - if err != nil { - log.Println(err) - continue - } - if orig == "" { - continue - } - - debug.Printf("Processing %s", filename) - - var updated string - var changes []misspell.Diff +var version = "dev" - if mode == "go" { - updated, changes = r.ReplaceGo(orig) - } else { - updated, changes = r.Replace(orig) - } - - if len(changes) == 0 { - continue - } - count += len(changes) - for _, diff := range changes { - // add in filename - diff.Filename = filename - - // output can be done by doing multiple goroutines - // and can clobber os.Stdout. - // - // the log package can be used simultaneously from multiple goroutines - var output bytes.Buffer - if writeit { - defaultWrite.Execute(&output, diff) - } else { - defaultRead.Execute(&output, diff) - } - - // goroutine-safe print to os.Stdout - stdout.Println(output.String()) - } +var ( + output *log.Logger + debug *log.Logger +) - if writeit { - os.WriteFile(filename, []byte(updated), 0) - } - } - results <- count -} +var ( + defaultWrite *template.Template + defaultRead *template.Template +) //nolint:funlen,nestif,gocognit,gocyclo,maintidx // TODO(ldez) must be fixed. func main() { @@ -129,16 +78,36 @@ func main() { fmt.Println(version) return } + if *showLegal { fmt.Println(misspell.Legal) return } + + // + // Number of Workers / CPU to use + // + if *workers < 0 { + log.Fatalf("-j must >= 0") + } + if *workers == 0 { + *workers = runtime.NumCPU() + } if *debugFlag { - debug = log.New(os.Stderr, "DEBUG ", 0) - } else { - debug = log.New(io.Discard, "", 0) + *workers = 1 } + // + // Source input mode + // + switch *mode { + case "auto", "go", "text": + default: + log.Fatalf("Mode must be one of auto=guess, go=golang source, text=plain or markdown-like text") + } + + debug = newDebugLogger(*debugFlag) + r := misspell.Replacer{ Replacements: misspell.DictMain, Debug: *debugFlag, @@ -164,25 +133,11 @@ func main() { // Load user defined words // if *userDictPath != "" { - file, err := os.Open(*userDictPath) - if err != nil { - log.Fatalf("Failed to load user defined corrections: %v, err: %v", *userDictPath, err) - } - defer file.Close() - - reader := csv.NewReader(file) - reader.FieldsPerRecord = 2 - - data, err := reader.ReadAll() + userDict, err := readUserDict(*userDictPath) if err != nil { log.Fatalf("reading user defined corrections: %v", err) } - var userDict []string - for _, row := range data { - userDict = append(userDict, row...) - } - r.AddRuleList(userDict) } @@ -194,75 +149,26 @@ func main() { } // - // Source input mode + // Output logger // - switch *mode { - case "auto": - case "go": - case "text": - default: - log.Fatalf("Mode must be one of auto=guess, go=golang source, text=plain or markdown-like text") - } - - // We can't just write to os.Stdout directly - // since we have multiple goroutine all writing at the same time causing broken output. - // Log is routine safe. - // We see it, so it doesn't use a prefix or include a time stamp. - switch { - case *quietFlag || *outFlag == os.DevNull: - stdout = log.New(io.Discard, "", 0) - case *outFlag == "/dev/stderr" || *outFlag == "stderr": - stdout = log.New(os.Stderr, "", 0) - case *outFlag == "/dev/stdout" || *outFlag == "stdout": - stdout = log.New(os.Stdout, "", 0) - case *outFlag == "" || *outFlag == "-": - stdout = log.New(os.Stdout, "", 0) - default: - fo, err := os.Create(*outFlag) - if err != nil { - log.Fatalf("unable to create outfile %q: %s", *outFlag, err) - } - defer fo.Close() - stdout = log.New(fo, "", 0) - } + var cleanup func() error + output, cleanup = newLogger(*quietFlag, *outFlag) + defer func() { _ = cleanup() }() // - // Custom output + // Custom output format // + var err error + defaultWrite, defaultRead, err = createTemplates(*format) + if err != nil { + log.Fatal(err) + } + switch { case *format == outputFormatCSV: - tmpl := template.Must(template.New(outputFormatCSV).Parse(csvTmpl)) - defaultWrite = tmpl - defaultRead = tmpl - stdout.Println(csvHeader) + output.Println(csvHeader) case *format == outputFormatSQLite || *format == outputFormatSQLite3: - tmpl := template.Must(template.New(outputFormatSQLite3).Parse(sqliteTmpl)) - defaultWrite = tmpl - defaultRead = tmpl - stdout.Println(sqliteHeader) - case *format != "": - t, err := template.New("custom").Parse(*format) - if err != nil { - log.Fatalf("Unable to compile log format: %s", err) - } - defaultWrite = t - defaultRead = t - default: // format == "" - defaultWrite = template.Must(template.New("defaultWrite").Parse(defaultWriteTmpl)) - defaultRead = template.Must(template.New("defaultRead").Parse(defaultReadTmpl)) - } - - // - // Number of Workers / CPU to use - // - if *workers < 0 { - log.Fatalf("-j must >= 0") - } - if *workers == 0 { - *workers = runtime.NumCPU() - } - if *debugFlag { - *workers = 1 + output.Println(sqliteHeader) } // Done with Flags. @@ -276,8 +182,8 @@ func main() { if len(args) == 0 { // If we are working with pipes/stdin/stdout there is no concurrency, // so we can directly send data to the writers. - var fileout io.Writer - var errout io.Writer + var fileOut io.Writer + var errOut io.Writer switch *writeit { case true: // If we are writing the corrected stream, @@ -285,13 +191,13 @@ func main() { // and the misspelling errors goes to stderr, // so we can do something like this: // curl something | misspell -w | gzip > afile.gz - fileout = os.Stdout - errout = os.Stderr + fileOut = os.Stdout + errOut = os.Stderr case false: // If we are not writing out the corrected stream then work just like files. // Misspelling errors are sent to stdout. - fileout = io.Discard - errout = os.Stdout + fileOut = io.Discard + errOut = os.Stdout } count := 0 @@ -302,27 +208,33 @@ func main() { if *quietFlag { return } + diff.Filename = "stdin" + if *writeit { - defaultWrite.Execute(errout, diff) + defaultWrite.Execute(errOut, diff) } else { - defaultRead.Execute(errout, diff) + defaultRead.Execute(errOut, diff) } - errout.Write([]byte{'\n'}) + + errOut.Write([]byte{'\n'}) } - err := r.ReplaceReader(os.Stdin, fileout, next) + err := r.ReplaceReader(os.Stdin, fileOut, next) if err != nil { - os.Exit(1) + log.Fatal(err) } + switch *format { case outputFormatSQLite, outputFormatSQLite3: - fileout.Write([]byte(sqliteFooter)) + fileOut.Write([]byte(sqliteFooter)) } + if count != 0 && *exitError { // error os.Exit(2) } + return } @@ -351,10 +263,146 @@ func main() { switch *format { case outputFormatSQLite, outputFormatSQLite3: - stdout.Println(sqliteFooter) + output.Println(sqliteFooter) } if count != 0 && *exitError { os.Exit(2) } } + +func worker(writeit bool, r *misspell.Replacer, mode string, files <-chan string, results chan<- int) { + count := 0 + for filename := range files { + orig, err := misspell.ReadTextFile(filename) + if err != nil { + log.Println(err) + continue + } + + if orig == "" { + continue + } + + debug.Printf("Processing %s", filename) + + var updated string + var changes []misspell.Diff + + if mode == "go" { + updated, changes = r.ReplaceGo(orig) + } else { + updated, changes = r.Replace(orig) + } + + if len(changes) == 0 { + continue + } + + count += len(changes) + + for _, diff := range changes { + // add in filename + diff.Filename = filename + + // Output can be done by doing multiple goroutines + // and can clobber os.Stdout. + // + // the log package can be used simultaneously from multiple goroutines + var buffer bytes.Buffer + if writeit { + defaultWrite.Execute(&buffer, diff) + } else { + defaultRead.Execute(&buffer, diff) + } + + // goroutine-safe print to os.Stdout + output.Println(buffer.String()) + } + + if writeit { + os.WriteFile(filename, []byte(updated), 0) + } + } + results <- count +} + +func readUserDict(userDictPath string) ([]string, error) { + file, err := os.Open(userDictPath) + if err != nil { + return nil, fmt.Errorf("failed to load user defined corrections %q: %w", userDictPath, err) + } + defer func() { _ = file.Close() }() + + reader := csv.NewReader(file) + reader.FieldsPerRecord = 2 + + data, err := reader.ReadAll() + if err != nil { + return nil, fmt.Errorf("reading user defined corrections: %w", err) + } + + var userDict []string + for _, row := range data { + userDict = append(userDict, row...) + } + + return userDict, nil +} + +func createTemplates(format string) (writeTmpl, readTmpl *template.Template, err error) { + switch { + case format == outputFormatCSV: + tmpl := template.Must(template.New(outputFormatCSV).Parse(csvTmpl)) + return tmpl, tmpl, nil + + case format == outputFormatSQLite || format == outputFormatSQLite3: + tmpl := template.Must(template.New(outputFormatSQLite3).Parse(sqliteTmpl)) + return tmpl, tmpl, nil + + case format != "": + tmpl, err := template.New("custom").Parse(format) + if err != nil { + return nil, nil, fmt.Errorf("unable to compile log format: %w", err) + } + return tmpl, tmpl, nil + + default: // format == "" + writeTmpl = template.Must(template.New("defaultWrite").Parse(defaultWriteTmpl)) + readTmpl = template.Must(template.New("defaultRead").Parse(defaultReadTmpl)) + return + } +} + +func newLogger(quiet bool, outputPath string) (logger *log.Logger, cleanup func() error) { + // We can't just write to os.Stdout directly + // since we have multiple goroutine all writing at the same time causing broken output. + // Log is routine safe. + // We see it, so it doesn't use a prefix or include a time stamp. + switch { + case quiet || outputPath == os.DevNull: + logger = log.New(io.Discard, "", 0) + case outputPath == "/dev/stderr" || outputPath == "stderr": + logger = log.New(os.Stderr, "", 0) + case outputPath == "/dev/stdout" || outputPath == "stdout": + logger = log.New(os.Stdout, "", 0) + case outputPath == "" || outputPath == "-": + logger = log.New(os.Stdout, "", 0) + default: + fo, err := os.Create(outputPath) + if err != nil { + log.Fatalf("unable to create outfile %q: %s", outputPath, err) + } + return log.New(fo, "", 0), fo.Close + } + + return logger, func() error { return nil } +} + +func newDebugLogger(enable bool) *log.Logger { + if enable { + return log.New(os.Stderr, "DEBUG ", 0) + } + + return log.New(io.Discard, "", 0) +}