Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor some things for the needs of github.com/go-chai/chai #1094

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ func New() *Gen {
}
}

type GenConfig struct {
// OutputDir represents the output directory for all the generated files
OutputDir string

// InstanceName is used to get distinct names for different swagger documents in the
// same project. The default value is "swagger".
InstanceName string

// GeneratedTime whether swag should generate the timestamp at the top of docs.go
GeneratedTime bool
}

// Config presents Gen configurations.
type Config struct {
// SearchDir the swag would be parse,comma separated if multiple
Expand Down Expand Up @@ -91,10 +103,6 @@ type Config struct {

// Build builds swagger json file for given searchDir and mainAPIFile. Returns json
func (g *Gen) Build(config *Config) error {
if config.InstanceName == "" {
config.InstanceName = swag.Name
}

searchDirs := strings.Split(config.SearchDir, ",")
for _, searchDir := range searchDirs {
if _, err := os.Stat(searchDir); os.IsNotExist(err) {
Expand Down Expand Up @@ -135,7 +143,23 @@ func (g *Gen) Build(config *Config) error {
if err := p.ParseAPIMultiSearchDir(searchDirs, config.MainAPIFile, config.ParseDepth); err != nil {
return err
}
swagger := p.GetSwagger()

return g.Generate(p.GetSwagger(), &GenConfig{
OutputDir: config.OutputDir,
InstanceName: config.InstanceName,
GeneratedTime: config.GeneratedTime,
})
}

// Generate outputs a swagger spec
func (g *Gen) Generate(swagger *spec.Swagger, config *GenConfig) error {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved out of the Build function the parts that deal with generating the files that store the generated swagger spec, because in my case I don't generate the specs from the swag binary, but I want to create the same files that swag does.

if config.InstanceName == "" {
config.InstanceName = swag.Name
}

if config.OutputDir == "" {
config.OutputDir = "docs/"
}

b, err := g.jsonIndent(swagger)
if err != nil {
Expand Down Expand Up @@ -251,7 +275,7 @@ func parseOverrides(r io.Reader) (map[string]string, error) {
return overrides, nil
}

func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swagger, config *Config) error {
func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swagger, config *GenConfig) error {
generator, err := template.New("swagger_info").Funcs(template.FuncMap{
"printDoc": func(v string) string {
// Add schemes
Expand Down
8 changes: 4 additions & 4 deletions gen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func TestGen_writeGoDoc(t *testing.T) {
swapTemplate := packageTemplate

packageTemplate = `{{{`
err := gen.writeGoDoc("docs", nil, nil, &Config{})
err := gen.writeGoDoc("docs", nil, nil, &GenConfig{})
assert.Error(t, err)

packageTemplate = `{{.Data}}`
Expand All @@ -371,7 +371,7 @@ func TestGen_writeGoDoc(t *testing.T) {
Info: &spec.Info{},
},
}
err = gen.writeGoDoc("docs", &mockWriter{}, swagger, &Config{})
err = gen.writeGoDoc("docs", &mockWriter{}, swagger, &GenConfig{})
assert.Error(t, err)

packageTemplate = `{{ if .GeneratedTime }}Fake Time{{ end }}`
Expand All @@ -380,14 +380,14 @@ func TestGen_writeGoDoc(t *testing.T) {
hook: func(data []byte) {
assert.Equal(t, "Fake Time", string(data))
},
}, swagger, &Config{GeneratedTime: true})
}, swagger, &GenConfig{GeneratedTime: true})
assert.NoError(t, err)
err = gen.writeGoDoc("docs",
&mockWriter{
hook: func(data []byte) {
assert.Equal(t, "", string(data))
},
}, swagger, &Config{GeneratedTime: false})
}, swagger, &GenConfig{GeneratedTime: false})
assert.NoError(t, err)

packageTemplate = swapTemplate
Expand Down
4 changes: 4 additions & 0 deletions operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,10 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a
}), nil
}

func (operation *Operation) ParseAPIObjectSchema(schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to use this function for creating a swagger schema for a type name

return operation.parseObjectSchema(refType, astFile)
}

func (operation *Operation) parseAPIObjectSchema(schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
switch schemaType {
case OBJECT:
Expand Down
60 changes: 47 additions & 13 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1281,21 +1281,55 @@ func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{
return nil, fmt.Errorf("%s is unsupported type in example value %s", schemaType, exampleValue)
}

// GetAllGoFileInfo gets all Go source files information for given searchDir.
func (parser *Parser) getAllGoFileInfo(packageDir, searchDir string) error {
return filepath.Walk(searchDir, func(path string, f os.FileInfo, _ error) error {
if err := parser.Skip(path, f); err != nil {
// GetAllGoFileInfoAndParseTypes gets all Go source files information for given searchDir and parses the types from them.
func (parser *Parser) GetAllGoFileInfoAndParseTypes(searchDir string) error {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need this in order to be able to go through packages one by one and incrementally parse the types from them.

I saw that ParseAPIMultiSearchDir() first collects all types from the astFiles and then parses them all at once via parser.packages.ParseTypes(), but in my case that would mean having to do a double pass through all routes, which I'd rather not have to do and instead have a function that parses the types of a single package.

return filepath.Walk(searchDir, func(path string, f os.FileInfo, e error) error {
astFile, err := parser.getGoFileInfo(searchDir, searchDir, path, f, e)
if err != nil {
return err
} else if f.IsDir() {
}

if astFile == nil {
return nil
}

relPath, err := filepath.Rel(searchDir, path)
parser.packages.parseTypesFromFile(astFile, searchDir, make(map[*TypeSpecDef]*Schema))

return nil
})
}

func (parser *Parser) getGoFileInfo(packageDir, searchDir string, path string, f os.FileInfo, _ error) (*ast.File, error) {
if err := parser.Skip(path, f); err != nil {
return nil, err
} else if f.IsDir() {
return nil, nil
}

relPath, err := filepath.Rel(searchDir, path)
if err != nil {
return nil, err
}

astFile, err := parser.parseFile(filepath.ToSlash(filepath.Dir(filepath.Clean(filepath.Join(packageDir, relPath)))), path, nil)

if err != nil {
return nil, err
}

return astFile, nil
}

// GetAllGoFileInfo gets all Go source files information for given searchDir.
func (parser *Parser) getAllGoFileInfo(packageDir, searchDir string) error {
return filepath.Walk(searchDir, func(path string, f os.FileInfo, e error) error {
_, err := parser.getGoFileInfo(packageDir, searchDir, path, f, e)

if err != nil {
return err
}

return parser.parseFile(filepath.ToSlash(filepath.Dir(filepath.Clean(filepath.Join(packageDir, relPath)))), path, nil)
return nil
})
}

Expand All @@ -1321,7 +1355,7 @@ func (parser *Parser) getAllGoFileInfoFromDeps(pkg *depth.Pkg) error {
}

path := filepath.Join(srcDir, f.Name())
if err := parser.parseFile(pkg.Name, path, nil); err != nil {
if _, err := parser.parseFile(pkg.Name, path, nil); err != nil {
return err
}
}
Expand All @@ -1335,23 +1369,23 @@ func (parser *Parser) getAllGoFileInfoFromDeps(pkg *depth.Pkg) error {
return nil
}

func (parser *Parser) parseFile(packageDir, path string, src interface{}) error {
func (parser *Parser) parseFile(packageDir, path string, src interface{}) (*ast.File, error) {
if strings.HasSuffix(strings.ToLower(path), "_test.go") || filepath.Ext(path) != ".go" {
return nil
return nil, nil
}

// positions are relative to FileSet
astFile, err := goparser.ParseFile(token.NewFileSet(), path, src, goparser.ParseComments)
if err != nil {
return fmt.Errorf("ParseFile error:%+v", err)
return nil, fmt.Errorf("ParseFile error:%+v", err)
}

err = parser.packages.CollectAstFile(packageDir, path, astFile)
if err != nil {
return err
return nil, err
}

return nil
return astFile, nil
}

func (parser *Parser) checkOperationIDUniqueness() error {
Expand Down