Skip to content

Commit

Permalink
fix: Decompress request body when multi Content-Encoding sent on requ…
Browse files Browse the repository at this point in the history
…est headers (#2555)

* 🔧 feat: Decode body in order when sent a list on content-encoding

* 🚀 perf: Change `getSplicedStrList` to have 0 allocations

* 🍵 test: Add tests for the new features

* 🍵 test: Ensure session test will not raise an error unexpectedly

* 🐗 feat: Replace strings.TrimLeft by utils.TrimLeft

Add docs to functions to inform correctly what the change is

* 🌷 refactor: Apply linter rules

* 🍵 test: Add test cases to the new body method change

* 🔧 feat: Remove return problems to be able to reach original body

* 🌷 refactor: Split Body method into two to make it more maintainable

Also, with the previous fix to problems detected by tests, it becomes really hard to make the linter happy, so this change also helps in it

* 🚀 perf: Came back with Header.VisitAll, to improve speed

* 📃 docs: Update Context docs
  • Loading branch information
Jictyvoo authored Aug 6, 2023
1 parent e91b02b commit f29f39b
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 67 deletions.
89 changes: 75 additions & 14 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,31 +260,92 @@ func (c *Ctx) BaseURL() string {
return c.baseURI
}

// Body contains the raw body submitted in a POST request.
// BodyRaw contains the raw body submitted in a POST request.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (c *Ctx) BodyRaw() []byte {
return c.fasthttp.Request.Body()
}

func (c *Ctx) tryDecodeBodyInOrder(
originalBody *[]byte,
encodings []string,
) ([]byte, uint8, error) {
var (
err error
body []byte
decodesRealized uint8
)

for index, encoding := range encodings {
decodesRealized++
switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
decodesRealized--
if len(encodings) == 1 {
body = c.fasthttp.Request.Body()
}
return body, decodesRealized, nil
}

if err != nil {
return nil, decodesRealized, err
}

// Only execute body raw update if it has a next iteration to try to decode
if index < len(encodings)-1 && decodesRealized > 0 {
if index == 0 {
tempBody := c.fasthttp.Request.Body()
*originalBody = make([]byte, len(tempBody))
copy(*originalBody, tempBody)
}
c.fasthttp.Request.SetBodyRaw(body)
}
}

return body, decodesRealized, nil
}

// Body contains the raw body submitted in a POST request.
// This method will decompress the body if the 'Content-Encoding' header is provided.
// It returns the original (or decompressed) body data which is valid only within the handler.
// Don't store direct references to the returned data.
// If you need to keep the body's data later, make a copy or use the Immutable option.
func (c *Ctx) Body() []byte {
var err error
var encoding string
var body []byte
var (
err error
body, originalBody []byte
headerEncoding string
encodingOrder = []string{"", "", ""}
)

// faster than peek
c.Request().Header.VisitAll(func(key, value []byte) {
if c.app.getString(key) == HeaderContentEncoding {
encoding = c.app.getString(value)
headerEncoding = c.app.getString(value)
}
})

switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
body = c.fasthttp.Request.Body()
// Split and get the encodings list, in order to attend the
// rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5
encodingOrder = getSplicedStrList(headerEncoding, encodingOrder)
if len(encodingOrder) == 0 {
return c.fasthttp.Request.Body()
}

var decodesRealized uint8
body, decodesRealized, err = c.tryDecodeBodyInOrder(&originalBody, encodingOrder)

// Ensure that the body will be the original
if originalBody != nil && decodesRealized > 0 {
c.fasthttp.Request.SetBodyRaw(originalBody)
}
if err != nil {
return []byte(err.Error())
}
Expand Down
225 changes: 195 additions & 30 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bufio"
"bytes"
"compress/gzip"
"compress/zlib"
"context"
"crypto/tls"
"encoding/xml"
Expand Down Expand Up @@ -323,47 +324,211 @@ func Test_Ctx_Body(t *testing.T) {
utils.AssertEqual(t, []byte("john=doe"), c.Body())
}

// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
func Benchmark_Ctx_Body(b *testing.B) {
const input = "john=doe"

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var b bytes.Buffer
gz := gzip.NewWriter(&b)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(t, nil, err)
err = gz.Flush()
utils.AssertEqual(t, nil, err)
err = gz.Close()
utils.AssertEqual(t, nil, err)
c.Request().SetBody(b.Bytes())
utils.AssertEqual(t, []byte("john=doe"), c.Body())

c.Request().SetBody([]byte(input))
for i := 0; i < b.N; i++ {
_ = c.Body()
}

utils.AssertEqual(b, []byte(input), c.Body())
}

// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentEncoding string
body []byte
expectedBody []byte
}{
{
name: "gzip",
contentEncoding: "gzip",
body: []byte("john=doe"),
expectedBody: []byte("john=doe"),
},
{
name: "unsupported_encoding",
contentEncoding: "undefined",
body: []byte("keeps_ORIGINAL"),
expectedBody: []byte("keeps_ORIGINAL"),
},
{
name: "gzip then unsupported",
contentEncoding: "gzip, undefined",
body: []byte("Go, be gzipped"),
expectedBody: []byte("Go, be gzipped"),
},
{
name: "invalid_deflate",
contentEncoding: "gzip,deflate",
body: []byte("I'm not correctly compressed"),
expectedBody: []byte(zlib.ErrHeader.Error()),
},
}

for _, testObject := range tests {
tCase := testObject // Duplicate object to ensure it will be unique across all runs
t.Run(tCase.name, func(t *testing.T) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", tCase.contentEncoding)

if strings.Contains(tCase.contentEncoding, "gzip") {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
_, err := gz.Write(tCase.body)
if err != nil {
t.Fatal(err)
}
if err = gz.Flush(); err != nil {
t.Fatal(err)
}
if err = gz.Close(); err != nil {
t.Fatal(err)
}
tCase.body = b.Bytes()
}

c.Request().SetBody(tCase.body)
body := c.Body()
utils.AssertEqual(t, tCase.expectedBody, body)

// Check if body raw is the same as previous before decompression
utils.AssertEqual(
t, tCase.body, c.Request().Body(),
"Body raw must be the same as set before",
)
})
}
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(b, nil, err)
err = gz.Flush()
utils.AssertEqual(b, nil, err)
err = gz.Close()
utils.AssertEqual(b, nil, err)
encodingErr := errors.New("failed to encoding data")

var (
compressGzip = func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
}
compressDeflate = func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := zlib.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
}
)
compressionTests := []struct {
contentEncoding string
compressWriter func([]byte) ([]byte, error)
}{
{
contentEncoding: "gzip",
compressWriter: compressGzip,
},
{
contentEncoding: "gzip,invalid",
compressWriter: compressGzip,
},
{
contentEncoding: "deflate",
compressWriter: compressDeflate,
},
{
contentEncoding: "gzip,deflate",
compressWriter: func(data []byte) ([]byte, error) {
var (
buf bytes.Buffer
writer interface {
io.WriteCloser
Flush() error
}
err error
)

// deflate
{
writer = zlib.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}

c.Request().SetBody(buf.Bytes())
data = make([]byte, buf.Len())
copy(data, buf.Bytes())
buf.Reset()

// gzip
{
writer = gzip.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}

for i := 0; i < b.N; i++ {
_ = c.Body()
return buf.Bytes(), nil
},
},
}

utils.AssertEqual(b, []byte("john=doe"), c.Body())
for _, ct := range compressionTests {
b.Run(ct.contentEncoding, func(b *testing.B) {
app := New()
const input = "john=doe"
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set("Content-Encoding", ct.contentEncoding)
compressedBody, err := ct.compressWriter([]byte(input))
utils.AssertEqual(b, nil, err)

c.Request().SetBody(compressedBody)
for i := 0; i < b.N; i++ {
_ = c.Body()
}

utils.AssertEqual(b, []byte(input), c.Body())
})
}
}

// go test -run Test_Ctx_BodyParser
Expand Down
Loading

0 comments on commit f29f39b

Please sign in to comment.