diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 1b4a693ea8c..981924e4532 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -6,7 +6,7 @@ services: build: context: .. dockerfile: .devcontainer/Dockerfile - command: /bin/sh -c "while sleep 1000; do :; done" + command: sleep infinity env_file: - common.env - backend.env diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9d6819dcafe..2df08998562 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: exclude_types: [svg] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.3.4 hooks: - id: ruff - id: ruff-format @@ -24,12 +24,12 @@ repos: exclude: tests/ - repo: https://github.com/fpgmaas/deptry.git - rev: "0.14.0" + rev: "0.15.0" hooks: - id: deptry - repo: https://github.com/returntocorp/semgrep - rev: v1.64.0 + rev: v1.66.2 hooks: - id: semgrep language_version: python3.9 diff --git a/CHANGELOG.md b/CHANGELOG.md index 69054996092..0817889b10c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,9 @@ All notable, unreleased changes to this project will be documented in this file. ### Breaking changes ### GraphQL API - +- Add `translatableContent` to all translation types; add translated object id to all translatable content types - #15617 by @zedzior - Add a `taxConfiguration` to a `Channel` - #15610 by @Air-t +- Deprecate the `taxTypes` query - #15802 by @maarcingebala ### Saleor Apps @@ -19,6 +20,8 @@ All notable, unreleased changes to this project will be documented in this file. - Remove `prefetched_for_webhook` to legacy payload generators - #15369 by @AjmalPonneth - Don't raise InsufficientStock for track_inventory=False variants #15475 by @carlosa54 - DB performance improvements in attribute dataloaders - #15474 by @AjmalPonneth +- Calculate order promotions in draft orders - #15459 by @zedzior +- Prevent name overwriting of Product Variants when Updating Product Types - #15670 by @teddyondieki # 3.19.0 @@ -57,6 +60,7 @@ All notable, unreleased changes to this project will be documented in this file. - Added new input `AppInput.identifier`. - Added new parameter `identifier` for `create_app` command. - When `taxAppId` is provided for `TaxConfiguration` do not allow to finalize `checkoutComplete` or `draftOrderComplete` mutations if Tax App or Avatax plugin didn't respond. +- Add `unique_type` to `OrderLineDiscount` and `CheckoutLineDiscount` models - #15774 by @zedzior # 3.18.0 diff --git a/README.md b/README.md index bc282fb4be6..a308ae51394 100644 --- a/README.md +++ b/README.md @@ -8,17 +8,17 @@
- Customer-centric e-commerce on a modern stack + Commerce that works with your language and stack
- A headless, GraphQL commerce platform delivering ultra-fast, dynamic, personalized shopping experiences.
Beautiful online stores, anywhere, on any device. + GraphQL native, API-only platform for scalable composable commerce.

- Join our active, engaged community:
+ Join our community:
Website | Twitter @@ -29,7 +29,7 @@
- Blog + Blog | Subscribe to newsletter
@@ -60,29 +60,55 @@ ## What makes Saleor special? -Saleor is a rapidly-growing open-source e-commerce platform that serves high-volume companies. Designed from the ground up to be extensible, headless, and composable. +- **Technology-agnostic** - no monolithic plugin architecture or technology lock-in. -Learn more about [architecture](https://docs.saleor.io/docs/3.x/overview/architecture). +- **GraphQL only** - Not afterthought API design or fragmentation across different styles of API. -## Features +- **Headless and API only** - APIs are the only way to interact, configure, or extend the backend. + +- **Open source** - a single version of Saleor without feature fragmentation or commercial limitations. + +- **Cloud native** - battle tested on global brands. + +- **Native-multichannel** - Per [channel](https://docs.saleor.io/docs/3.x/developer/channels) control of pricing, currencies, stock, product, and more. + + +## Why API-only Architecture? + +Saleor's API-first extensibility provides powerful tools for developers to extend backend using [webhooks](https://docs.saleor.io/docs/3.x/developer/extending/webhooks/overview), attributes, [metadata](https://docs.saleor.io/docs/3.x/api-usage/metadata), [apps](https://docs.saleor.io/docs/3.x/developer/extending/apps/overview), [subscription queries](https://docs.saleor.io/docs/3.x/developer/extending/webhooks/subscription-webhook-payloads), [API extensions](https://docs.saleor.io/docs/3.x/developer/extending/webhooks/synchronous-events/overview), [dashboard iframes](https://docs.saleor.io/docs/3.x/developer/extending/apps/overview). + +Compared to traditional plugin architectures (monoliths) it provides the following benefits: -- **Headless / API first**: Build mobile apps, custom storefronts, POS, automation, etc -- **Extensible**: Build anything with webhooks, apps, metadata, and attributes -- [**App Store**](https://github.com/saleor/apps): Leverage a collection of built-in integrations -- **GraphQL API**: Get many resources in a single request and [more](https://graphql.org/) -- **Multichannel**: Per channel control of pricing, currencies, stock, product, and more +* There is less downtime as apps are deployed independently. +* Reliability and performance - custom logic is separated from the core. +* Simplified upgrade paths - eliminates incompatibility conflicts between extensions. +* Technology-agnostic - works with any technology, stack, or language. +* Parallel development - easier to collaborate than with a monolithic core. +* Simplified debugging - easier to narrow down bugs in independent services. +* Scalability - extensions and apps can be scaled independently. + +### What are the tradeoffs? +If you are a single developer working with a small business that doesn't have high traffic or a critical need for 24/7 availability, using a service-oriented approach might feel more complex compared to the traditional WordPress or Magento approach that provides a language-specific framework, runtime, database schema, aspect-oriented programming, and other tools to a quick start. + +However, if you deploy on a daily basis, reliability and uptime is critical, +you need to collaborate with other developers, or you have non-trivial requirements you might be in the right place. + +## Features - **Enterprise ready**: Secure, scalable, and stable. Battle-tested by big brands -- **CMS**: Content is king, that's why we have a kingdom built-in - **Dashboard**: User-friendly, fast, and productive. (Decoupled project [repo](https://github.com/saleor/saleor-dashboard) ) - **Global by design** Multi-currency, multi-language, multi-warehouse, tutti multi! -- **Orders**: A comprehensive system for orders, dispatch, and refunds -- **Cart**: Advanced payment and tax options, with full control over discounts and promotions -- **Payments**: Flexible API architecture allows integration of any payment method -- **SEO**: Packed with features that get stores to a wider audience -- **Cloud**: Optimized for deployments using Docker +- **CMS**: Manage product or marketing content. +- **Product management**: A rich content model for large and complex catalogs. +- **Orders**: Flexible order model, split payments, multi-warehouse, returns, and more. +- **Customers**: Order history and preferences. +- **Promotion engine**: Sales, vouchers, cart rules, giftcards. +- **Payment orchestration**: multi-gateway, extensible payment API, flexible flows. +- **Cart**: Advanced payment and tax options, with full control over discounts and promotions. +- **Payments**: Flexible API architecture allows integration of any payment method. +- **Translations**: Fully translatable catalog. +- **SEO**: Unlimited SEO freedom with headless architecture. +- **Apps**: Extend dashboard via iframe with any web stack. -Saleor is free and always will be. -Help us out… If you love free stuff and great software, give us a star! 🌟 ![Saleor Dashboard - Modern UI for managing your e-commerce](https://user-images.githubusercontent.com/9268745/224249510-d3c7658e-6d5c-42c5-b4fb-93eaf65a5335.png) @@ -146,13 +172,6 @@ If nothing grabs your attention, check [our roadmap](https://github.com/orgs/sal Get more details in our [Contributing Guide](https://docs.saleor.io/docs/developer/community/contributing). -## Your feedback - -Do you use Saleor as an e-commerce platform? -Fill out this short survey and help us grow. It will take just a minute, but means a lot! - -[Take a survey](https://mirumee.typeform.com/to/sOIJbJ) - ## License Disclaimer: Everything you see here is open and free to use as long as you comply with the [license](https://github.com/saleor/saleor/blob/master/LICENSE). There are no hidden charges. We promise to do our best to fix bugs and improve the code. diff --git a/poetry.lock b/poetry.lock index 681a059ab4e..a604b4b2c4a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -59,13 +59,13 @@ trio = ["trio (>=0.23)"] [[package]] name = "asgiref" -version = "3.7.2" +version = "3.8.1" description = "ASGI specs, helper code, and adapters" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "asgiref-3.7.2-py3-none-any.whl", hash = "sha256:89b2ef2247e3b562a16eef663bc0e2e703ec6468e2fa8a5cd61cd449786d4f6e"}, - {file = "asgiref-3.7.2.tar.gz", hash = "sha256:9e0ce3aa93a819ba5b45120216b23878cf6e8525eb3848653452b4192b92afed"}, + {file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"}, + {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, ] [package.dependencies] @@ -321,17 +321,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.64" +version = "1.34.74" description = "The AWS SDK for Python" optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "boto3-1.34.64-py3-none-any.whl", hash = "sha256:8c6fbd3d45399a4e4685010117fb2dc52fc6afdab5a9460957d463ae0c2cc55d"}, - {file = "boto3-1.34.64.tar.gz", hash = "sha256:e5d681f443645e6953ed0727bf756bf16d85efefcb69cf051d04a070ce65e545"}, + {file = "boto3-1.34.74-py3-none-any.whl", hash = "sha256:71f551491fb12fe07727d371d5561c5919fdf33dbc1d4251c57940d267a53a9e"}, + {file = "boto3-1.34.74.tar.gz", hash = "sha256:b703e22775561a748adc4576c30424b81abd2a00d3c6fb28eec2e5cde92c1eed"}, ] [package.dependencies] -botocore = ">=1.34.64,<1.35.0" +botocore = ">=1.34.74,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -340,13 +340,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.64" +version = "1.34.74" description = "Low-level, data-driven core of boto 3." optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "botocore-1.34.64-py3-none-any.whl", hash = "sha256:0ab760908749fe82325698591c49755a5bb20307d85a419aca9cc74e783b9407"}, - {file = "botocore-1.34.64.tar.gz", hash = "sha256:084f8c45216d62dc1add2350e236a2d5283526aacd0681e9818b37a6a5e5438b"}, + {file = "botocore-1.34.74-py3-none-any.whl", hash = "sha256:5d2015b5d91d6c402c122783729ce995ed7283a746b0380957026dc2b3b75969"}, + {file = "botocore-1.34.74.tar.gz", hash = "sha256:32bb519bae62483893330c18a0ea4fd09d1ffa32bc573cd8559c2d9a08fb8c5c"}, ] [package.dependencies] @@ -792,13 +792,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "click-didyoumean" -version = "0.3.0" +version = "0.3.1" description = "Enables git-like *did-you-mean* feature in click" optional = false -python-versions = ">=3.6.2,<4.0.0" +python-versions = ">=3.6.2" files = [ - {file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"}, - {file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"}, + {file = "click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c"}, + {file = "click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463"}, ] [package.dependencies] @@ -1155,13 +1155,13 @@ tzdata = "*" [[package]] name = "django-countries" -version = "7.5.1" +version = "7.6" description = "Provides a country field for Django models." optional = false python-versions = "*" files = [ - {file = "django-countries-7.5.1.tar.gz", hash = "sha256:22915d9b9403932b731622619940a54894a3eb0da9a374e7249c8fc453c122d7"}, - {file = "django_countries-7.5.1-py3-none-any.whl", hash = "sha256:2df707aca7a5e677254bed116cf6021a136ebaccd5c2f46860abd6452bb45521"}, + {file = "django-countries-7.6.tar.gz", hash = "sha256:aba80a9ce7c293671bb36507c98c4f8886b768868160a1b7afd2fa2aee4c5a0b"}, + {file = "django_countries-7.6-py3-none-any.whl", hash = "sha256:1939b6b28fc341615f1b62ba4a681b54e67414df40cdea5e99ce3897546af92c"}, ] [package.dependencies] @@ -1169,8 +1169,8 @@ asgiref = "*" typing-extensions = "*" [package.extras] -dev = ["black", "django", "djangorestframework", "graphene-django", "pytest", "pytest-django", "tox"] -maintainer = ["django", "transifex-client", "zest.releaser[recommended]"] +dev = ["black", "django", "djangorestframework", "graphene-django", "pytest", "pytest-django", "tox (==4.*)"] +maintainer = ["django", "zest.releaser[recommended]"] pyuca = ["pyuca"] test = ["djangorestframework", "graphene-django", "pytest", "pytest-cov", "pytest-django"] @@ -1546,18 +1546,18 @@ probabilistic = ["pyprobables (>=0.6,<0.7)"] [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -1665,13 +1665,13 @@ yaml = ["PyYAML"] [[package]] name = "google-api-core" -version = "2.17.1" +version = "2.18.0" description = "Google API client core library" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-core-2.17.1.tar.gz", hash = "sha256:9df18a1f87ee0df0bc4eea2770ebc4228392d8cc4066655b320e2cfccb15db95"}, - {file = "google_api_core-2.17.1-py3-none-any.whl", hash = "sha256:610c5b90092c360736baccf17bd3efbcb30dd380e7a6dc28a71059edb8bd0d8e"}, + {file = "google-api-core-2.18.0.tar.gz", hash = "sha256:62d97417bfc674d6cef251e5c4d639a9655e00c45528c4364fbfebb478ce72a9"}, + {file = "google_api_core-2.18.0-py3-none-any.whl", hash = "sha256:5a63aa102e0049abe85b5b88cb9409234c1f70afcda21ce1e40b285b9629c1d6"}, ] [package.dependencies] @@ -1679,6 +1679,7 @@ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""} grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""} +proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -1689,13 +1690,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-auth" -version = "2.28.2" +version = "2.29.0" description = "Google Authentication Library" optional = false python-versions = ">=3.7" files = [ - {file = "google-auth-2.28.2.tar.gz", hash = "sha256:80b8b4969aa9ed5938c7828308f20f035bc79f9d8fb8120bf9dc8db20b41ba30"}, - {file = "google_auth-2.28.2-py2.py3-none-any.whl", hash = "sha256:9fd67bbcd40f16d9d42f950228e9cf02a2ded4ae49198b27432d0cded5a74c38"}, + {file = "google-auth-2.29.0.tar.gz", hash = "sha256:672dff332d073227550ffc7457868ac4218d6c500b155fe6cc17d2b13602c360"}, + {file = "google_auth-2.29.0-py2.py3-none-any.whl", hash = "sha256:d452ad095688cd52bae0ad6fafe027f6a6d6f560e810fec20914e17a09526415"}, ] [package.dependencies] @@ -1730,13 +1731,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"] [[package]] name = "google-cloud-pubsub" -version = "2.20.2" +version = "2.21.0" description = "Google Cloud Pub/Sub API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-pubsub-2.20.2.tar.gz", hash = "sha256:236046ea860230c788e4d4ea2d0f12299cdf1d94ac71ec42ed1a0ce1ba28d66f"}, - {file = "google_cloud_pubsub-2.20.2-py2.py3-none-any.whl", hash = "sha256:9607bb8f973cbd123b5fa2db9c0aa38501a1a42f18593739067cd307263d090f"}, + {file = "google-cloud-pubsub-2.21.0.tar.gz", hash = "sha256:94017f0bc9a85fa3f4d913f312e930a0fe21775bd68dde5c666e2f1b1addf811"}, + {file = "google_cloud_pubsub-2.21.0-py2.py3-none-any.whl", hash = "sha256:fabd19e08faa1f70081b0e5ea003a3c031a1d2c1b798cf1be5ded2dbf1d1dbef"}, ] [package.dependencies] @@ -1753,13 +1754,13 @@ libcst = ["libcst (>=0.3.10)"] [[package]] name = "google-cloud-storage" -version = "2.15.0" +version = "2.16.0" description = "Google Cloud Storage API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-storage-2.15.0.tar.gz", hash = "sha256:7560a3c48a03d66c553dc55215d35883c680fe0ab44c23aa4832800ccc855c74"}, - {file = "google_cloud_storage-2.15.0-py2.py3-none-any.whl", hash = "sha256:5d9237f88b648e1d724a0f20b5cde65996a37fe51d75d17660b1404097327dd2"}, + {file = "google-cloud-storage-2.16.0.tar.gz", hash = "sha256:dda485fa503710a828d01246bd16ce9db0823dc51bbca742ce96a6817d58669f"}, + {file = "google_cloud_storage-2.16.0-py2.py3-none-any.whl", hash = "sha256:91a06b96fb79cf9cdfb4e759f178ce11ea885c79938f89590344d079305f5852"}, ] [package.dependencies] @@ -2089,22 +2090,23 @@ protobuf = ">=4.21.6" [[package]] name = "gunicorn" -version = "21.2.0" +version = "22.0.0" description = "WSGI HTTP Server for UNIX" optional = false -python-versions = ">=3.5" +python-versions = ">=3.7" files = [ - {file = "gunicorn-21.2.0-py3-none-any.whl", hash = "sha256:3213aa5e8c24949e792bcacfc176fef362e7aac80b76c56f6b5122bf350722f0"}, - {file = "gunicorn-21.2.0.tar.gz", hash = "sha256:88ec8bff1d634f98e61b9f65bc4bf3cd918a90806c6f5c48bc5603849ec81033"}, + {file = "gunicorn-22.0.0-py3-none-any.whl", hash = "sha256:350679f91b24062c86e386e198a15438d53a7a8207235a78ba1b53df4c4378d9"}, + {file = "gunicorn-22.0.0.tar.gz", hash = "sha256:4a0b436239ff76fb33f11c07a16482c521a7e09c1ce3cc293c2330afe01bec63"}, ] [package.dependencies] packaging = "*" [package.extras] -eventlet = ["eventlet (>=0.24.1)"] +eventlet = ["eventlet (>=0.24.1,!=0.36.0)"] gevent = ["gevent (>=1.4.0)"] setproctitle = ["setproctitle"] +testing = ["coverage", "eventlet", "gevent", "pytest", "pytest-cov"] tornado = ["tornado (>=0.2)"] [[package]] @@ -2242,13 +2244,13 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -2257,7 +2259,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "iniconfig" @@ -2368,13 +2370,13 @@ referencing = ">=0.31.0" [[package]] name = "kombu" -version = "5.3.5" +version = "5.3.6" description = "Messaging library for Python." optional = false python-versions = ">=3.8" files = [ - {file = "kombu-5.3.5-py3-none-any.whl", hash = "sha256:0eac1bbb464afe6fb0924b21bf79460416d25d8abc52546d4f16cad94f789488"}, - {file = "kombu-5.3.5.tar.gz", hash = "sha256:30e470f1a6b49c70dc6f6d13c3e4cc4e178aa6c469ceb6bcd55645385fc84b93"}, + {file = "kombu-5.3.6-py3-none-any.whl", hash = "sha256:49f1e62b12369045de2662f62cc584e7df83481a513db83b01f87b5b9785e378"}, + {file = "kombu-5.3.6.tar.gz", hash = "sha256:f3da5b570a147a5da8280180aa80b03807283d63ea5081fcdb510d18242431d9"}, ] [package.dependencies] @@ -2392,7 +2394,7 @@ mongodb = ["pymongo (>=4.1.1)"] msgpack = ["msgpack"] pyro = ["pyro4"] qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] -redis = ["redis (>=4.5.2,!=4.5.5,<6.0.0)"] +redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] slmq = ["softlayer-messaging (>=1.0.3)"] sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] @@ -3057,13 +3059,13 @@ xpath = ["lxml (>=4.4.0)"] [[package]] name = "phonenumberslite" -version = "8.13.32" +version = "8.13.33" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." optional = false python-versions = "*" files = [ - {file = "phonenumberslite-8.13.32-py2.py3-none-any.whl", hash = "sha256:7b6a539c2dd483a10385528a9cb31e217c6deecfa1fcce6fd0386cadadae2490"}, - {file = "phonenumberslite-8.13.32.tar.gz", hash = "sha256:e1368eb5a2622c2a1d8330fbddc8e7ac871e1b3776f82c338423823c27739159"}, + {file = "phonenumberslite-8.13.33-py2.py3-none-any.whl", hash = "sha256:4d92f4f9079bb83588dde45fd8a414bc13e4962886aa4d23576984196f4d83c2"}, + {file = "phonenumberslite-8.13.33.tar.gz", hash = "sha256:7426bc46af3de5a800a4c8f33ab13e33225d2c8ed4fc52aa3c0380dadd8d7381"}, ] [[package]] @@ -3240,13 +3242,13 @@ files = [ [[package]] name = "pre-commit" -version = "3.6.2" +version = "3.7.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." optional = false python-versions = ">=3.9" files = [ - {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"}, - {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"}, + {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, + {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"}, ] [package.dependencies] @@ -3361,28 +3363,28 @@ files = [ [[package]] name = "pyasn1" -version = "0.5.1" +version = "0.6.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" files = [ - {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"}, - {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"}, + {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, + {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, ] [[package]] name = "pyasn1-modules" -version = "0.3.0" +version = "0.4.0" description = "A collection of ASN.1-based protocols modules" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" files = [ - {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, - {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, + {file = "pyasn1_modules-0.4.0-py3-none-any.whl", hash = "sha256:be04f15b66c206eed667e0bb5ab27e2b1855ea54a842e5037738099e8ca4ae0b"}, + {file = "pyasn1_modules-0.4.0.tar.gz", hash = "sha256:831dbcea1b177b28c9baddf4c6d1013c24c3accd14a1873fffaa6a2e905f17b6"}, ] [package.dependencies] -pyasn1 = ">=0.4.6,<0.6.0" +pyasn1 = ">=0.4.6,<0.7.0" [[package]] name = "pybars3" @@ -3399,13 +3401,13 @@ PyMeta3 = ">=0.5.1" [[package]] name = "pycparser" -version = "2.21" +version = "2.22" description = "C parser in Python" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.8" files = [ - {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, - {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] [[package]] @@ -3614,17 +3616,17 @@ test = ["covdefaults (>=2.2.2)", "coverage (>=7.0.5)", "flaky (>=3.7)", "pytest [[package]] name = "pytest-mock" -version = "3.12.0" +version = "3.14.0" description = "Thin-wrapper around the mock package for easier use with pytest" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, - {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, ] [package.dependencies] -pytest = ">=5.0" +pytest = ">=6.2.5" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] @@ -3882,13 +3884,13 @@ files = [ [[package]] name = "razorpay" -version = "1.4.1" +version = "1.4.2" description = "Razorpay Python Client" optional = false python-versions = "*" files = [ - {file = "razorpay-1.4.1-py3-none-any.whl", hash = "sha256:89b01c386cc93cc5af07d613c99c24fa3ded95d7def94b89e0ac294fb49cbe7c"}, - {file = "razorpay-1.4.1.tar.gz", hash = "sha256:236af5a6ae6512345907f4b7f42297980bb8314dbbbc684313a476b7f45179bd"}, + {file = "razorpay-1.4.2-py3-none-any.whl", hash = "sha256:aaf525baebc5001e5c54ae317a372c3c17c81314cedcbbf71aa0a2d3419a9757"}, + {file = "razorpay-1.4.2.tar.gz", hash = "sha256:0d812030ba29e776e66d1ceaacc31635810fde18b66ff2591e582eb651bb5131"}, ] [package.dependencies] @@ -3950,12 +3952,12 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "requests-hardened" -version = "1.0.0b2" +version = "1.0.0b3" description = "A library that overrides the default behaviors of the requests library, and adds new security features." optional = false python-versions = ">=3.8" files = [ - {file = "requests-hardened-1.0.0b2.tar.gz", hash = "sha256:28c5591f1f346f4b1a014d97cb0e1addb852bcdcfb5c129b6ca5443587c32c73"}, + {file = "requests-hardened-1.0.0b3.tar.gz", hash = "sha256:125057fb864e4283c926021f594c9e4695432036f13fd76fee3ef738510231e2"}, ] [package.dependencies] @@ -4183,28 +4185,28 @@ files = [ [[package]] name = "ruff" -version = "0.3.3" +version = "0.3.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:973a0e388b7bc2e9148c7f9be8b8c6ae7471b9be37e1cc732f8f44a6f6d7720d"}, - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfa60d23269d6e2031129b053fdb4e5a7b0637fc6c9c0586737b962b2f834493"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eca7ff7a47043cf6ce5c7f45f603b09121a7cc047447744b029d1b719278eb5"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7d3f6762217c1da954de24b4a1a70515630d29f71e268ec5000afe81377642d"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b24c19e8598916d9c6f5a5437671f55ee93c212a2c4c569605dc3842b6820386"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5a6cbf216b69c7090f0fe4669501a27326c34e119068c1494f35aaf4cc683778"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352e95ead6964974b234e16ba8a66dad102ec7bf8ac064a23f95371d8b198aab"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d6ab88c81c4040a817aa432484e838aaddf8bfd7ca70e4e615482757acb64f8"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79bca3a03a759cc773fca69e0bdeac8abd1c13c31b798d5bb3c9da4a03144a9f"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2700a804d5336bcffe063fd789ca2c7b02b552d2e323a336700abb8ae9e6a3f8"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd66469f1a18fdb9d32e22b79f486223052ddf057dc56dea0caaf1a47bdfaf4e"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45817af234605525cdf6317005923bf532514e1ea3d9270acf61ca2440691376"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0da458989ce0159555ef224d5b7c24d3d2e4bf4c300b85467b08c3261c6bc6a8"}, - {file = "ruff-0.3.3-py3-none-win32.whl", hash = "sha256:f2831ec6a580a97f1ea82ea1eda0401c3cdf512cf2045fa3c85e8ef109e87de0"}, - {file = "ruff-0.3.3-py3-none-win_amd64.whl", hash = "sha256:be90bcae57c24d9f9d023b12d627e958eb55f595428bafcb7fec0791ad25ddfc"}, - {file = "ruff-0.3.3-py3-none-win_arm64.whl", hash = "sha256:0171aab5fecdc54383993389710a3d1227f2da124d76a2784a7098e818f92d61"}, - {file = "ruff-0.3.3.tar.gz", hash = "sha256:38671be06f57a2f8aba957d9f701ea889aa5736be806f18c0cd03d6ff0cbca8d"}, + {file = "ruff-0.3.4-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:60c870a7d46efcbc8385d27ec07fe534ac32f3b251e4fc44b3cbfd9e09609ef4"}, + {file = "ruff-0.3.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6fc14fa742e1d8f24910e1fff0bd5e26d395b0e0e04cc1b15c7c5e5fe5b4af91"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3ee7880f653cc03749a3bfea720cf2a192e4f884925b0cf7eecce82f0ce5854"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf133dd744f2470b347f602452a88e70dadfbe0fcfb5fd46e093d55da65f82f7"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f3860057590e810c7ffea75669bdc6927bfd91e29b4baa9258fd48b540a4365"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:986f2377f7cf12efac1f515fc1a5b753c000ed1e0a6de96747cdf2da20a1b369"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fd98e85869603e65f554fdc5cddf0712e352fe6e61d29d5a6fe087ec82b76c"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64abeed785dad51801b423fa51840b1764b35d6c461ea8caef9cf9e5e5ab34d9"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df52972138318bc7546d92348a1ee58449bc3f9eaf0db278906eb511889c4b50"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:98e98300056445ba2cc27d0b325fd044dc17fcc38e4e4d2c7711585bd0a958ed"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:519cf6a0ebed244dce1dc8aecd3dc99add7a2ee15bb68cf19588bb5bf58e0488"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bb0acfb921030d00070539c038cd24bb1df73a2981e9f55942514af8b17be94e"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cf187a7e7098233d0d0c71175375c5162f880126c4c716fa28a8ac418dcf3378"}, + {file = "ruff-0.3.4-py3-none-win32.whl", hash = "sha256:af27ac187c0a331e8ef91d84bf1c3c6a5dea97e912a7560ac0cef25c526a4102"}, + {file = "ruff-0.3.4-py3-none-win_amd64.whl", hash = "sha256:de0d5069b165e5a32b3c6ffbb81c350b1e3d3483347196ffdf86dc0ef9e37dd6"}, + {file = "ruff-0.3.4-py3-none-win_arm64.whl", hash = "sha256:6810563cc08ad0096b57c717bd78aeac888a1bfd38654d9113cb3dc4d3f74232"}, + {file = "ruff-0.3.4.tar.gz", hash = "sha256:f0f4484c6541a99862b693e13a151435a279b271cff20e37101116a21e2a1ad1"}, ] [[package]] @@ -4487,13 +4489,13 @@ files = [ [[package]] name = "textual" -version = "0.52.1" +version = "0.54.0" description = "Modern Text User Interface framework" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ - {file = "textual-0.52.1-py3-none-any.whl", hash = "sha256:960a19df2319482918b4a58736d9552cdc1ab65d170ba0bc15273ce0e1922b7a"}, - {file = "textual-0.52.1.tar.gz", hash = "sha256:4232e5c2b423ed7c63baaeb6030355e14e1de1b9df096c9655b68a1e60e4de5f"}, + {file = "textual-0.54.0-py3-none-any.whl", hash = "sha256:94aacf28dece20a44f0b94b087e17ff4ac961acd92e12e648f060fe2555b3adc"}, + {file = "textual-0.54.0.tar.gz", hash = "sha256:0cfd134dde5ae49d64dd73bb32a2fb5a86d878d9caeacecaa1d640082f31124e"}, ] [package.dependencies] @@ -4520,12 +4522,12 @@ tornado = "*" [[package]] name = "thrift" -version = "0.16.0" +version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" files = [ - {file = "thrift-0.16.0.tar.gz", hash = "sha256:2b5b6488fcded21f9d312aa23c9ff6a0195d0f6ae26ddbd5ad9e3e25dfc14408"}, + {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] [package.dependencies] @@ -5400,4 +5402,4 @@ test = ["pytest"] [metadata] lock-version = "2.0" python-versions = "~3.9" -content-hash = "16a2c854e9a7d0795c7e1222c51c247373015023e4f738575472d075bbc3b014" +content-hash = "1363addca1d09bfe3f1716c221ebfc88014c9584ad943b94a9f866ac9a8f3ff6" diff --git a/pyproject.toml b/pyproject.toml index 474d144c676..c5cc05d0985 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ documentation = "https://docs.saleor.io/" graphene = "<3.0" graphql-core = "^2.3.2" graphql-relay = "^2.0.1" - gunicorn = "^21.2.0" + gunicorn = "^22.0.0" html2text = "^2020.1.16" html-to-draftjs = "^1.0.1" jaeger-client = "^4.5.0" @@ -77,7 +77,7 @@ documentation = "https://docs.saleor.io/" razorpay = "^1.2" redis = "^5.0.1" requests = "^2.22" - requests-hardened = "1.0.0b2" + requests-hardened = "1.0.0b3" Rx = "^1.6.3" semantic-version = "^2.10.0" sendgrid = "^6.7.1" diff --git a/saleor/attribute/migrations/0040_clear_assignedattributes.py b/saleor/attribute/migrations/0040_clear_assignedattributes.py index 5684e04fd14..7e336f5bf05 100644 --- a/saleor/attribute/migrations/0040_clear_assignedattributes.py +++ b/saleor/attribute/migrations/0040_clear_assignedattributes.py @@ -19,10 +19,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedproductattributevalue ALTER COLUMN product_id SET NOT NULL; - ALTER TABLE attribute_assignedproductattributevalue - ADD CONSTRAINT attribute_assignedproduc_value_id_product_id_6f6deb31_uniq - UNIQUE (value_id, product_id); - DROP TABLE attribute_assignedproductattribute; """, reverse_sql=""" @@ -32,9 +28,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedproductattributevalue ALTER COLUMN product_id DROP NOT NULL; - ALTER TABLE attribute_assignedproductattributevalue - DROP CONSTRAINT IF EXISTS attribute_assignedproduc_value_id_product_id_6f6deb31_uniq; - CREATE TABLE attribute_assignedproductattribute ( id serial NOT NULL PRIMARY KEY, assignment_id integer, @@ -50,10 +43,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedpageattributevalue ALTER COLUMN page_id SET NOT NULL; - ALTER TABLE attribute_assignedpageattributevalue - ADD CONSTRAINT attribute_assignedpageat_value_id_page_id_851cd501_uniq - UNIQUE (value_id, page_id); - DROP TABLE attribute_assignedpageattribute; """, reverse_sql=""" @@ -63,9 +52,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedpageattributevalue ALTER COLUMN page_id DROP NOT NULL; - ALTER TABLE attribute_assignedpageattributevalue - DROP CONSTRAINT IF EXISTS attribute_assignedpageat_value_id_page_id_851cd501_uniq; - CREATE TABLE attribute_assignedpageattribute ( id serial NOT NULL PRIMARY KEY, assignment_id integer, diff --git a/saleor/checkout/calculations.py b/saleor/checkout/calculations.py index 3493f023aff..6e0e9302a90 100644 --- a/saleor/checkout/calculations.py +++ b/saleor/checkout/calculations.py @@ -7,6 +7,7 @@ from prices import Money, TaxedMoney from ..checkout import base_calculations +from ..core.db.connection import allow_writer from ..core.prices import quantize_price from ..core.taxes import TaxData, TaxEmptyData, zero_money, zero_taxed_money from ..discount.utils import ( @@ -275,6 +276,7 @@ def _fetch_checkout_prices_if_expired( database_connection_name=database_connection_name, ) except TaxEmptyData as e: + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.tax_error = str(e) if not should_charge_tax: @@ -300,10 +302,11 @@ def _fetch_checkout_prices_if_expired( database_connection_name=database_connection_name, ) except TaxEmptyData as e: + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.tax_error = str(e) else: # Calculate net prices without taxes. - _get_checkout_base_prices(checkout, checkout_info, lines) + _set_checkout_base_prices(checkout, checkout_info, lines) checkout_update_fields = [ "voucher_code", @@ -325,18 +328,19 @@ def _fetch_checkout_prices_if_expired( checkout.price_expiration = timezone.now() + settings.CHECKOUT_PRICES_TTL - checkout.save( - update_fields=checkout_update_fields, - using=settings.DATABASE_CONNECTION_DEFAULT_NAME, - ) - checkout.lines.bulk_update( - [line_info.line for line_info in lines], - [ - "total_price_net_amount", - "total_price_gross_amount", - "tax_rate", - ], - ) + with allow_writer(): + checkout.save( + update_fields=checkout_update_fields, + using=settings.DATABASE_CONNECTION_DEFAULT_NAME, + ) + checkout.lines.bulk_update( + [line_info.line for line_info in lines], + [ + "total_price_net_amount", + "total_price_gross_amount", + "tax_rate", + ], + ) return checkout_info, lines @@ -530,7 +534,7 @@ def _apply_tax_data_from_plugins( ) -def _get_checkout_base_prices( +def _set_checkout_base_prices( checkout: "Checkout", checkout_info: "CheckoutInfo", lines: Iterable["CheckoutLineInfo"], diff --git a/saleor/checkout/complete_checkout.py b/saleor/checkout/complete_checkout.py index 0cdd8795293..4aaad9f2067 100644 --- a/saleor/checkout/complete_checkout.py +++ b/saleor/checkout/complete_checkout.py @@ -82,6 +82,7 @@ ) from .models import Checkout from .utils import ( + calculate_checkout_weight, get_checkout_metadata, get_or_create_checkout_metadata, get_voucher_for_checkout_info, @@ -145,6 +146,9 @@ def _process_shipping_data_for_order( if checkout_info.user.addresses.filter(pk=shipping_address.pk).exists(): shipping_address = shipping_address.get_copy() + if shipping_address and delivery_method_info.warehouse_pk: + shipping_address = shipping_address.get_copy() + shipping_method = delivery_method_info.delivery_method tax_class = getattr(shipping_method, "tax_class", None) @@ -152,7 +156,7 @@ def _process_shipping_data_for_order( "shipping_address": shipping_address, "base_shipping_price": base_shipping_price, "shipping_price": shipping_price, - "weight": checkout_info.checkout.get_total_weight(lines), + "weight": calculate_checkout_weight(lines), **get_shipping_tax_class_kwargs_for_order(tax_class), } result.update(delivery_method_info.delivery_method_order_field) diff --git a/saleor/checkout/fetch.py b/saleor/checkout/fetch.py index 4965f0c7ba1..c4861ea4333 100644 --- a/saleor/checkout/fetch.py +++ b/saleor/checkout/fetch.py @@ -671,7 +671,7 @@ def _resolve_all_shipping_methods(): ) # Filter shipping methods using sync webhooks excluded_methods = manager.excluded_shipping_methods_for_checkout( - checkout_info.checkout, all_methods + checkout_info.checkout, checkout_info.channel, all_methods ) initialize_shipping_method_active_status(all_methods, excluded_methods) return all_methods diff --git a/saleor/checkout/migrations/0063_auto_20240402_1114.py b/saleor/checkout/migrations/0063_auto_20240402_1114.py new file mode 100644 index 00000000000..2267d5dd53c --- /dev/null +++ b/saleor/checkout/migrations/0063_auto_20240402_1114.py @@ -0,0 +1,17 @@ +# Generated by Django 3.2.22 on 2024-04-02 11:14 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ( + "checkout", + "0062_update_checkout_last_transaction_modified_at_and_refundable", + ) + ] + + # Execution of a task moved to 0067_auto_20240405_0756.py. + # The migration task was not triggered due to missing proper celeryconf setup. + # Migration file left for history consistency. + operations = [] diff --git a/saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py b/saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py new file mode 100644 index 00000000000..3737c768b46 --- /dev/null +++ b/saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py @@ -0,0 +1,11 @@ +# Generated by Django 3.2.25 on 2024-04-03 08:09 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("checkout", "0063_auto_20240402_1114"), + ("checkout", "0065_checkout_tax_error"), + ] + operations = [] diff --git a/saleor/checkout/migrations/0067_auto_20240405_0756.py b/saleor/checkout/migrations/0067_auto_20240405_0756.py new file mode 100644 index 00000000000..d8fd5ccb813 --- /dev/null +++ b/saleor/checkout/migrations/0067_auto_20240405_0756.py @@ -0,0 +1,55 @@ +# Generated by Django 3.2.25 on 2024-04-05 07:56 + +from django.db import migrations +from django.db.models.expressions import Exists, OuterRef +from django.forms import model_to_dict + +# The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak +ADDRESS_UPDATE_BATCH_SIZE = 250 + + +def queryset_in_batches(queryset): + """Slice a queryset into batches. + + Input queryset should be sorted be pk. + """ + start_pk = 0 + + while True: + qs = queryset.filter(pk__gt=start_pk)[:ADDRESS_UPDATE_BATCH_SIZE] + pks = list(qs.values_list("pk", flat=True)) + if not pks: + break + yield pks + start_pk = pks[-1] + + +def update_checkout_addresses(apps, schema_editor): + Checkout = apps.get_model("checkout", "Checkout") + Warehouse = apps.get_model("warehouse", "Warehouse") + Address = apps.get_model("account", "Address") + + queryset = Checkout.objects.filter( + Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), + ).order_by("pk") + + for batch_pks in queryset_in_batches(queryset): + checkouts = Checkout.objects.filter(pk__in=batch_pks) + addresses = [] + for checkout in checkouts: + if cc_address := checkout.shipping_address: + checkout_address = Address(**model_to_dict(cc_address, exclude=["id"])) + checkout.shipping_address = checkout_address + addresses.append(checkout_address) + Address.objects.bulk_create(addresses, ignore_conflicts=True) + Checkout.objects.bulk_update(checkouts, ["shipping_address"]) + + +class Migration(migrations.Migration): + dependencies = [ + ("checkout", "0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error"), + ] + + operations = [ + migrations.RunPython(update_checkout_addresses, migrations.RunPython.noop), + ] diff --git a/saleor/checkout/models.py b/saleor/checkout/models.py index 04011e966e9..8f4ae38288c 100644 --- a/saleor/checkout/models.py +++ b/saleor/checkout/models.py @@ -1,10 +1,9 @@ """Checkout-related ORM models.""" -from collections.abc import Iterable from datetime import date from decimal import Decimal from operator import attrgetter -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional from uuid import uuid4 from django.conf import settings @@ -19,19 +18,14 @@ from ..channel.models import Channel from ..core.models import ModelWithMetadata from ..core.taxes import zero_money -from ..core.weight import zero_weight from ..giftcard.models import GiftCard from ..permission.enums import CheckoutPermissions from ..shipping.models import ShippingMethod from . import CheckoutAuthorizeStatus, CheckoutChargeStatus if TYPE_CHECKING: - from django_measurement import Weight - - from ..order.fetch import OrderLineInfo from ..payment.models import Payment from ..product.models import ProductVariant - from .fetch import CheckoutLineInfo def get_default_country(): @@ -234,19 +228,6 @@ def get_total_gift_cards_balance( return zero_money(currency=self.currency) return Money(balance, self.currency) - def get_total_weight( - self, lines: Union[Iterable["CheckoutLineInfo"], Iterable["OrderLineInfo"]] - ) -> "Weight": - # FIXME: it does not make sense for this method to live in the Checkout model - # since it's used in the Order model as well. We should move it to a separate - # helper. - weights = zero_weight() - for checkout_line_info in lines: - line = checkout_line_info.line - if line.variant: - weights += line.variant.get_weight() * line.quantity - return weights - def get_line(self, variant: "ProductVariant") -> Optional["CheckoutLine"]: """Return a line matching the given variant and data if any.""" matching_lines = (line for line in self if line.variant.pk == variant.pk) diff --git a/saleor/checkout/payment_utils.py b/saleor/checkout/payment_utils.py index 0b2a6a96944..48d0043b9dd 100644 --- a/saleor/checkout/payment_utils.py +++ b/saleor/checkout/payment_utils.py @@ -7,6 +7,7 @@ from django.db.models import Exists, Q from prices import Money +from ..core.db.connection import allow_writer from ..core.taxes import zero_money from ..payment.models import TransactionItem from . import CheckoutAuthorizeStatus, CheckoutChargeStatus @@ -115,7 +116,8 @@ def update_checkout_payment_statuses( fields_to_update.append("charge_status") if fields_to_update: fields_to_update.append("last_change") - checkout.save(update_fields=fields_to_update) + with allow_writer(): + checkout.save(update_fields=fields_to_update) def update_refundable_for_checkout(checkout_pk): diff --git a/saleor/checkout/tests/test_calculations.py b/saleor/checkout/tests/test_calculations.py index a4c1d7dc24a..82ab2e8327b 100644 --- a/saleor/checkout/tests/test_calculations.py +++ b/saleor/checkout/tests/test_calculations.py @@ -24,7 +24,7 @@ from ..calculations import ( _apply_tax_data, _calculate_and_add_tax, - _get_checkout_base_prices, + _set_checkout_base_prices, fetch_checkout_data, ) from ..fetch import CheckoutLineInfo, fetch_checkout_info, fetch_checkout_lines @@ -269,7 +269,7 @@ def test_fetch_checkout_data_flat_rates_and_no_tax_calc_strategy( assert checkout.shipping_tax_rate == Decimal("0.2300") -def test_get_checkout_base_prices_no_charge_taxes_with_voucher( +def test_set_checkout_base_prices_no_charge_taxes_with_voucher( checkout_with_item, voucher_percentage ): # given @@ -312,7 +312,7 @@ def test_get_checkout_base_prices_no_charge_taxes_with_voucher( checkout_info = fetch_checkout_info(checkout, lines, manager) # when - _get_checkout_base_prices(checkout, checkout_info, lines) + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.save() checkout.lines.bulk_update( [line_info.line for line_info in lines], @@ -339,7 +339,7 @@ def test_get_checkout_base_prices_no_charge_taxes_with_voucher( assert line.total_price == checkout.total -def test_get_checkout_base_prices_no_charge_taxes_with_order_promotion( +def test_set_checkout_base_prices_no_charge_taxes_with_order_promotion( checkout_with_item_and_order_discount, ): # given @@ -359,7 +359,7 @@ def test_get_checkout_base_prices_no_charge_taxes_with_order_promotion( checkout_info = fetch_checkout_info(checkout, lines, manager) # when - _get_checkout_base_prices(checkout, checkout_info, lines) + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.save() checkout.lines.bulk_update( [line_info.line for line_info in lines], @@ -634,6 +634,7 @@ def test_fetch_checkout_data_calls_inactive_plugin( fetch_checkout_data(**fetch_kwargs) # then + assert checkout.total.gross.amount > 0 assert checkout_with_items.tax_error == "Empty tax data." diff --git a/saleor/checkout/tests/test_cart.py b/saleor/checkout/tests/test_cart.py index 48747849fff..55011b6f7a6 100644 --- a/saleor/checkout/tests/test_cart.py +++ b/saleor/checkout/tests/test_cart.py @@ -10,7 +10,11 @@ from ...product.models import Category from .. import calculations, utils from ..models import Checkout -from ..utils import add_variant_to_checkout, calculate_checkout_quantity +from ..utils import ( + add_variant_to_checkout, + calculate_checkout_quantity, + calculate_checkout_weight, +) @pytest.fixture @@ -283,4 +287,4 @@ def test_get_total_weight(checkout_with_item): line.quantity = 6 line.save() lines, _ = fetch_checkout_lines(checkout_with_item) - assert checkout_with_item.get_total_weight(lines) == Weight(kg=60) + assert calculate_checkout_weight(lines) == Weight(kg=60) diff --git a/saleor/checkout/utils.py b/saleor/checkout/utils.py index be07aab2b73..94084f66b21 100644 --- a/saleor/checkout/utils.py +++ b/saleor/checkout/utils.py @@ -13,6 +13,7 @@ from prices import Money from ..account.models import User +from ..core.db.connection import allow_writer from ..core.exceptions import ProductNotPublished from ..core.taxes import zero_taxed_money from ..core.utils.promo_code import ( @@ -21,6 +22,7 @@ promo_code_is_voucher, ) from ..core.utils.translations import get_translation +from ..core.weight import zero_weight from ..discount import DiscountType, VoucherType from ..discount.interface import VoucherInfo, fetch_voucher_info from ..discount.models import ( @@ -30,8 +32,8 @@ VoucherCode, ) from ..discount.utils import ( - create_discount_objects_for_catalogue_promotions, - create_discount_objects_for_order_promotions, + create_checkout_discount_objects_for_order_promotions, + create_checkout_line_discount_objects_for_catalogue_promotions, delete_gift_line, get_products_voucher_discount, get_voucher_code_instance, @@ -58,6 +60,8 @@ from .models import Checkout, CheckoutLine, CheckoutMetadata if TYPE_CHECKING: + from measurement.measures import Weight + from ..account.models import Address from ..order.models import Order from .fetch import CheckoutInfo, CheckoutLineInfo @@ -92,7 +96,7 @@ def recalculate_checkout_discounts( Update line and checkout discounts from vouchers and promotions. Create or remove gift line if needed. """ - create_discount_objects_for_catalogue_promotions(lines) + create_checkout_line_discount_objects_for_catalogue_promotions(lines) recalculate_checkout_discount(manager, checkout_info, lines) @@ -688,7 +692,9 @@ def recalculate_checkout_discount( else: remove_voucher_from_checkout(checkout) - create_discount_objects_for_order_promotions(checkout_info, lines, save=True) + create_checkout_discount_objects_for_order_promotions( + checkout_info, lines, save=True + ) def add_promo_code_to_checkout( @@ -864,6 +870,7 @@ def get_valid_internal_shipping_methods_for_checkout( checkout_info.checkout, channel_id=checkout_info.checkout.channel_id, price=subtotal, + shipping_address=checkout_info.shipping_address, country_code=country_code, lines=lines, ) @@ -960,9 +967,7 @@ def cancel_active_payments(checkout: Checkout): def is_shipping_required(lines: Iterable["CheckoutLineInfo"]): """Check if shipping is required for given checkout lines.""" - return any( - line_info.product.product_type.is_shipping_required for line_info in lines - ) + return any(line_info.product_type.is_shipping_required for line_info in lines) def validate_variants_in_checkout_lines(lines: Iterable["CheckoutLineInfo"]): @@ -1012,6 +1017,7 @@ def delete_external_shipping_id(checkout: Checkout, save: bool = False): metadata.save(update_fields=["private_metadata"]) +@allow_writer() def get_or_create_checkout_metadata(checkout: "Checkout") -> CheckoutMetadata: if hasattr(checkout, "metadata_storage"): return checkout.metadata_storage @@ -1024,3 +1030,21 @@ def get_checkout_metadata(checkout: "Checkout"): return checkout.metadata_storage else: return CheckoutMetadata(checkout=checkout) + + +def calculate_checkout_weight(lines: Iterable["CheckoutLineInfo"]) -> "Weight": + weights = zero_weight() + for checkout_line_info in lines: + variant = checkout_line_info.variant + if variant: + line_weight = get_checkout_line_weight(checkout_line_info) + weights += line_weight * checkout_line_info.line.quantity + return weights + + +def get_checkout_line_weight(line_info: "CheckoutLineInfo"): + return ( + line_info.variant.weight + or line_info.product.weight + or line_info.product_type.weight + ) diff --git a/saleor/core/db/connection.py b/saleor/core/db/connection.py new file mode 100644 index 00000000000..e359caa6878 --- /dev/null +++ b/saleor/core/db/connection.py @@ -0,0 +1,131 @@ +import logging +import traceback +from contextlib import contextmanager + +from django.conf import settings +from django.core.management.color import color_style +from django.db import connections +from django.db.backends.base.base import BaseDatabaseWrapper + +from ...graphql.core.context import SaleorContext, get_database_connection_name + +logger = logging.getLogger(__name__) + +writer = settings.DATABASE_CONNECTION_DEFAULT_NAME +replica = settings.DATABASE_CONNECTION_REPLICA_NAME + +# Limit the number of frames in the traceback in `log_writer_usage_middleware` to avoid +# excessive log size. +TRACEBACK_LIMIT = 20 + +UNSAFE_WRITER_ACCESS_MSG = ( + "Unsafe access to the writer DB detected. Call `using()` on the `QuerySet` " + "to utilize a replica DB, or employ the `allow_writer` context manager to " + "explicitly permit access to the writer." +) + + +class UnsafeDBUsageError(Exception): + pass + + +class UnsafeWriterAccessError(UnsafeDBUsageError): + pass + + +class UnsafeReplicaUsageError(UnsafeDBUsageError): + pass + + +@contextmanager +def allow_writer(): + """Context manager that allows write access to the default database connection. + + This context manager works in conjunction with the `restrict_writer_middleware` and + `log_writer_usage_middleware` middlewares. If any of these middlewares are enabled, + use the `allow_writer` context manager to allow write access to the default + database. Otherwise an error will be raised or a log message will be emitted. + """ + + default_connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + # Check if we are already in an allow_writer block. If so we don't need to do + # anything and we don't have to close access to the writer at the end. + in_allow_writer_block = getattr(default_connection, "_allow_writer", False) + if not in_allow_writer_block: + setattr(default_connection, "_allow_writer", True) + try: + yield + finally: + if not in_allow_writer_block: + # Close writer access when exiting the outermost allow_writer block. + setattr(default_connection, "_allow_writer", False) + + +@contextmanager +def allow_writer_in_context(context: SaleorContext): + """Context manager that allows write access to the default database connection in a context (SaleorContext). + + This is a helper context manager that conditionally allows write access based on the + database connection name in the given context. + """ + conn_name = get_database_connection_name(context) + if conn_name == settings.DATABASE_CONNECTION_DEFAULT_NAME: + with allow_writer(): + yield + else: + yield + + +def restrict_writer_middleware(get_response): + """Middleware that restricts write access to the default database connection. + + This middleware will raise an error if a write operation is attempted on the default + database connection. To allow writes, use the `allow_writer` context manager. Make + sure that writer is not used accidentally and always use the `using` queryset method + with proper database connection name. + """ + + def middleware(request): + with connections[writer].execute_wrapper(_restrict_writer): + with connections[replica].execute_wrapper(_restrict_writer): + return get_response(request) + + return middleware + + +def _restrict_writer(execute, sql, params, many, context): + conn: BaseDatabaseWrapper = context["connection"] + allow_writer = getattr(conn, "_allow_writer", False) + if conn.alias == writer and not allow_writer: + raise UnsafeWriterAccessError(f"{UNSAFE_WRITER_ACCESS_MSG} SQL: {sql}") + return execute(sql, params, many, context) + + +def log_writer_usage_middleware(get_response): + """Middleware that logs write access to the default database connection. + + This is similar to the `restrict_writer_middleware` middleware, but instead of + raising an error, it logs a message when a write operation is attempted on the + default database connection. + """ + + def middleware(request): + with connections[writer].execute_wrapper(_log_writer_usage): + return get_response(request) + + return middleware + + +def _log_writer_usage(execute, sql, params, many, context): + conn: BaseDatabaseWrapper = context["connection"] + allow_writer = getattr(conn, "_allow_writer", False) + if conn.alias == writer and not allow_writer: + stack_trace = traceback.extract_stack(limit=TRACEBACK_LIMIT) + error_msg = color_style().NOTICE(UNSAFE_WRITER_ACCESS_MSG) + log_msg = ( + f"{error_msg} SQL: {sql} \n" + f"Traceback: \n{''.join(traceback.format_list(stack_trace))}" + ) + logger.error(log_msg) + return execute(sql, params, many, context) diff --git a/saleor/order/migrations/tasks/__init__.py b/saleor/core/db/tests/__init__.py similarity index 100% rename from saleor/order/migrations/tasks/__init__.py rename to saleor/core/db/tests/__init__.py diff --git a/saleor/core/db/tests/test_connection.py b/saleor/core/db/tests/test_connection.py new file mode 100644 index 00000000000..6137d3af642 --- /dev/null +++ b/saleor/core/db/tests/test_connection.py @@ -0,0 +1,88 @@ +from unittest.mock import patch + +import pytest +from django.db import connections + +from ....graphql.context import SaleorContext +from ....tests.models import Book +from ..connection import ( + UnsafeWriterAccessError, + _log_writer_usage, + _restrict_writer, + allow_writer, + allow_writer_in_context, +) + + +def test_allow_writer(settings): + default_connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + assert not getattr(default_connection, "_allow_writer", False) + + with allow_writer(): + assert hasattr(default_connection, "_allow_writer") + assert default_connection._allow_writer + + +def test_allow_writer_yield_exception(settings): + default_connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + def example_function(): + raise Exception() + + try: + with allow_writer(): + example_function() + except Exception: + pass + + assert hasattr(default_connection, "_allow_writer") + assert not default_connection._allow_writer + + +def test_allow_writer_in_context_writer(settings): + context = SaleorContext() + context.allow_replica = False + + with allow_writer_in_context(context): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + assert hasattr(connection, "_allow_writer") + assert connection._allow_writer + + +@patch("saleor.core.db.connection.get_database_connection_name") +def test_allow_writer_in_context_replica(mocked_get_database_connection_name, settings): + mocked_get_database_connection_name.return_value = "replica" + + context = SaleorContext() + context.allow_replica = True + + with allow_writer_in_context(context): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + assert not getattr(connection, "_allow_writer") + + +def test_restrict_writer_raises_error(settings): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + with pytest.raises(UnsafeWriterAccessError): + with connection.execute_wrapper(_restrict_writer): + Book.objects.first() + + +def test_restrict_writer_in_allow_writer(settings): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + with connection.execute_wrapper(_restrict_writer): + with allow_writer(): + Book.objects.first() + + +def test_log_writer_usage(settings, caplog): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + with connection.execute_wrapper(_log_writer_usage): + Book.objects.first() + + assert caplog.records + msg = caplog.records[0].getMessage() + assert "Unsafe access to the writer DB detected" in msg diff --git a/saleor/core/tests/test_dataloaders.py b/saleor/core/tests/test_dataloaders.py index 943cdda80bc..5dd142b08fd 100644 --- a/saleor/core/tests/test_dataloaders.py +++ b/saleor/core/tests/test_dataloaders.py @@ -13,6 +13,7 @@ def test_plugins_manager_loader_loads_requestor_in_plugin(rf, customer_user, set handler.load_middleware() handler.get_response(request) manager = get_plugin_manager_promise(request).get() + manager.get_all_plugins() plugin = manager.all_plugins.pop() assert isinstance(plugin.requestor, type(customer_user)) @@ -31,6 +32,7 @@ def test_plugins_manager_loader_requestor_in_plugin_when_no_app_and_user_in_req_ handler.load_middleware() handler.get_response(request) manager = get_plugin_manager_promise(request).get() + manager.get_all_plugins() plugin = manager.all_plugins.pop() assert not plugin.requestor diff --git a/saleor/discount/interface.py b/saleor/discount/interface.py index 1919f332698..d835c58b2fa 100644 --- a/saleor/discount/interface.py +++ b/saleor/discount/interface.py @@ -53,14 +53,22 @@ class VariantPromotionRuleInfo(NamedTuple): def fetch_variant_rules_info( variant_channel_listing: "ProductVariantChannelListing", translation_language_code: str, -): +) -> list[VariantPromotionRuleInfo]: listings_rules = ( variant_channel_listing.variantlistingpromotionrule.all() if variant_channel_listing else [] ) + rules_info = [] - for listing_promotion_rule in listings_rules: + if listings_rules: + # Before introducing unique_type on discount models, there was possibility + # to have multiple catalogue discount associated with single line. In such a + # case, we should pick the best discount (with the highest discount amount) + listing_promotion_rule = max( + list(listings_rules), + key=lambda x: x.discount_amount, + ) promotion = listing_promotion_rule.promotion_rule.promotion promotion_translation, rule_translation = get_rule_translations( diff --git a/saleor/discount/migrations/0078_add_unique_type.py b/saleor/discount/migrations/0078_add_unique_type.py new file mode 100644 index 00000000000..85dd280310e --- /dev/null +++ b/saleor/discount/migrations/0078_add_unique_type.py @@ -0,0 +1,44 @@ +# Generated by Django 3.2.24 on 2024-04-08 11:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("discount", "0077_merge_20240307_1217"), + ] + + operations = [ + migrations.AddField( + model_name="checkoutlinediscount", + name="unique_type", + field=models.CharField( + blank=True, + choices=[ + ("sale", "Sale"), + ("voucher", "Voucher"), + ("manual", "Manual"), + ("promotion", "Promotion"), + ("order_promotion", "Order promotion"), + ], + max_length=64, + null=True, + ), + ), + migrations.AddField( + model_name="orderlinediscount", + name="unique_type", + field=models.CharField( + blank=True, + choices=[ + ("sale", "Sale"), + ("voucher", "Voucher"), + ("manual", "Manual"), + ("promotion", "Promotion"), + ("order_promotion", "Order promotion"), + ], + max_length=64, + null=True, + ), + ), + ] diff --git a/saleor/discount/migrations/0079_add_index_for_unique_type.py b/saleor/discount/migrations/0079_add_index_for_unique_type.py new file mode 100644 index 00000000000..b97205ae32c --- /dev/null +++ b/saleor/discount/migrations/0079_add_index_for_unique_type.py @@ -0,0 +1,33 @@ +# Generated by Django 3.2.24 on 2024-04-08 11:18 + +from django.db import migrations + + +class Migration(migrations.Migration): + atomic = False + dependencies = [ + ("discount", "0078_add_unique_type"), + ] + + operations = [ + migrations.RunSQL( + sql=""" + CREATE UNIQUE INDEX CONCURRENTLY checkoutlinediscount_line_id_unique_type_idx + ON discount_checkoutlinediscount USING btree (line_id, unique_type); + """, + reverse_sql=""" + DROP INDEX CONCURRENTLY IF EXISTS + checkoutlinediscount_line_id_unique_type_idx; + """, + ), + migrations.RunSQL( + sql=""" + CREATE UNIQUE INDEX CONCURRENTLY orderlinediscount_line_id_unique_type_idx + ON discount_orderlinediscount USING btree (line_id, unique_type); + """, + reverse_sql=""" + DROP INDEX CONCURRENTLY IF EXISTS + orderlinediscount_line_id_unique_type_idx; + """, + ), + ] diff --git a/saleor/discount/migrations/0080_add_unique_type_constraint.py b/saleor/discount/migrations/0080_add_unique_type_constraint.py new file mode 100644 index 00000000000..ce3c8e8e000 --- /dev/null +++ b/saleor/discount/migrations/0080_add_unique_type_constraint.py @@ -0,0 +1,54 @@ +# Generated by Django 3.2.24 on 2024-04-08 11:20 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("discount", "0079_add_index_for_unique_type"), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + database_operations=[ + migrations.RunSQL( + sql=""" + ALTER TABLE discount_checkoutlinediscount + ADD CONSTRAINT unique_checkoutline_discount_type + UNIQUE USING INDEX checkoutlinediscount_line_id_unique_type_idx; + """, + reverse_sql=""" + ALTER TABLE discount_checkoutlinediscount DROP CONSTRAINT IF EXISTS + unique_checkoutline_discount_type; + """, + ), + migrations.RunSQL( + sql=""" + ALTER TABLE discount_orderlinediscount + ADD CONSTRAINT unique_orderline_discount_type + UNIQUE USING INDEX orderlinediscount_line_id_unique_type_idx; + """, + reverse_sql=""" + ALTER TABLE discount_orderlinediscount DROP CONSTRAINT IF EXISTS + unique_orderline_discount_type; + """, + ), + ], + state_operations=[ + migrations.AddConstraint( + model_name="checkoutlinediscount", + constraint=models.UniqueConstraint( + fields=("line_id", "unique_type"), + name="unique_checkoutline_discount_type", + ), + ), + migrations.AddConstraint( + model_name="orderlinediscount", + constraint=models.UniqueConstraint( + fields=("line_id", "unique_type"), + name="unique_orderline_discount_type", + ), + ), + ], + ) + ] diff --git a/saleor/discount/models.py b/saleor/discount/models.py index 61461e6d9e6..2027e4b7058 100644 --- a/saleor/discount/models.py +++ b/saleor/discount/models.py @@ -539,6 +539,13 @@ class OrderLineDiscount(BaseDiscount): null=True, on_delete=models.CASCADE, ) + # Saleor in version 3.19 and below, doesn't have any unique constraint applied on + # discounts for checkout/order. To not have an impact on existing DB objects, + # the new field `unique_type` will be used for new discount records. + # This will ensure that we always apply a single specific discount type. + unique_type = models.CharField( + max_length=64, null=True, blank=True, choices=DiscountType.CHOICES + ) class Meta: indexes = [ @@ -548,6 +555,12 @@ class Meta: GinIndex(fields=["voucher_code"], name="orderlinedisc_voucher_code_idx"), ] ordering = ("created_at", "id") + constraints = [ + models.UniqueConstraint( + fields=["line_id", "unique_type"], + name="unique_orderline_discount_type", + ), + ] class CheckoutDiscount(BaseDiscount): @@ -578,6 +591,13 @@ class CheckoutLineDiscount(BaseDiscount): null=True, on_delete=models.CASCADE, ) + # Saleor in version 3.19 and below, doesn't have any unique constraint applied on + # discounts for checkout/order. To not have an impact on existing DB objects, + # the new field `unique_type` will be used for new discount records. + # This will ensure that we always apply a single specific discount type. + unique_type = models.CharField( + max_length=64, null=True, blank=True, choices=DiscountType.CHOICES + ) class Meta: indexes = [ @@ -587,6 +607,12 @@ class Meta: GinIndex(fields=["voucher_code"], name="checklinedisc_voucher_code_idx"), ] ordering = ("created_at", "id") + constraints = [ + models.UniqueConstraint( + fields=["line_id", "unique_type"], + name="unique_checkoutline_discount_type", + ), + ] class PromotionEvent(models.Model): diff --git a/saleor/discount/tests/test_utils/fixtures.py b/saleor/discount/tests/test_utils/fixtures.py index 7cc5a0d4d7d..b3244fb7b1a 100644 --- a/saleor/discount/tests/test_utils/fixtures.py +++ b/saleor/discount/tests/test_utils/fixtures.py @@ -1,7 +1,16 @@ +from decimal import Decimal + +import graphene import pytest +from prices import TaxedMoney from ....checkout.fetch import fetch_checkout_info, fetch_checkout_lines +from ....core.taxes import zero_money +from ....discount import RewardType, RewardValueType +from ....order import OrderStatus from ....plugins.manager import get_plugins_manager +from ....product.models import VariantChannelListingPromotionRule +from ....warehouse.models import Stock @pytest.fixture @@ -51,3 +60,80 @@ def checkout_lines_with_multiple_quantity_info( lines_info, _ = fetch_checkout_lines(checkout_with_items) return lines_info + + +@pytest.fixture +def draft_order_and_promotions( + order_with_lines, + order_promotion_without_rules, + catalogue_promotion_without_rules, + channel_USD, +): + # given + order = order_with_lines + line_1 = order.lines.get(quantity=3) + line_2 = order.lines.get(quantity=2) + + # prepare catalogue promotions + catalogue_promotion = catalogue_promotion_without_rules + variant_1 = line_1.variant + variant_2 = line_2.variant + rule_catalogue = catalogue_promotion.rules.create( + name="Catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant_2.id)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(3), + ) + rule_catalogue.channels.add(channel_USD) + + listing = variant_2.channel_listings.first() + listing.discounted_price_amount = Decimal(17) + listing.save(update_fields=["discounted_price_amount"]) + + currency = order.currency + VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule_catalogue, + discount_amount=Decimal(3), + currency=currency, + ) + + # prepare order promotion - subtotal + order_promotion = order_promotion_without_rules + rule_total = order_promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(25), + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule_total.channels.add(channel_USD) + + # prepare order promotion - gift + rule_gift = order_promotion.rules.create( + name="Gift subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_type=RewardType.GIFT, + ) + rule_gift.channels.add(channel_USD) + rule_gift.gifts.set([variant_1, variant_2]) + Stock.objects.update(quantity=100) + + # reset prices + order.total = TaxedMoney(net=zero_money(currency), gross=zero_money(currency)) + order.subtotal = TaxedMoney(net=zero_money(currency), gross=zero_money(currency)) + order.undiscounted_total = TaxedMoney( + net=zero_money(currency), gross=zero_money(currency) + ) + order.status = OrderStatus.DRAFT + order.save() + + return order, rule_catalogue, rule_total, rule_gift diff --git a/saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py b/saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py new file mode 100644 index 00000000000..cd832bdb518 --- /dev/null +++ b/saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py @@ -0,0 +1,96 @@ +from decimal import Decimal + +import graphene + +from ....order.fetch import fetch_draft_order_lines_info +from ... import DiscountType, DiscountValueType +from ...models import PromotionRule +from ...utils import _copy_unit_discount_data_to_order_line + + +def test_copy_unit_discount_data_to_order_line_multiple_discounts( + order_with_lines_and_catalogue_promotion, +): + # given + order = order_with_lines_and_catalogue_promotion + rule = PromotionRule.objects.get() + rule_reward_value = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + rule_discount_reason = f"Promotion: {promotion_id}" + + line = order.lines.first() + rule_discount = line.discounts.get() + rule_discount.reason = rule_discount_reason + rule_discount.save(update_fields=["reason"]) + + manual_reward_value = Decimal("2") + manual_discount_reason = "Manual discount" + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=manual_reward_value, + amount_value=manual_reward_value * line.quantity, + currency=order.currency, + reason=manual_discount_reason, + ) + + assert line.discounts.count() == 2 + lines_info = fetch_draft_order_lines_info(order) + + # when + _copy_unit_discount_data_to_order_line(lines_info) + + # then + line = lines_info[0].line + assert line.unit_discount_amount == rule_reward_value + manual_reward_value + assert rule_discount_reason in line.unit_discount_reason + assert manual_discount_reason in line.unit_discount_reason + assert line.unit_discount_type == DiscountValueType.FIXED + assert line.unit_discount_value == line.unit_discount_amount + + +def test_copy_unit_discount_data_to_order_line_single_discount( + order_with_lines_and_catalogue_promotion, +): + # given + order = order_with_lines_and_catalogue_promotion + rule = PromotionRule.objects.get() + rule_reward_value = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + rule_discount_reason = f"Promotion: {promotion_id}" + + line = order.lines.first() + rule_discount = line.discounts.get() + rule_discount.reason = rule_discount_reason + rule_discount.save(update_fields=["reason"]) + + assert line.discounts.count() == 1 + lines_info = fetch_draft_order_lines_info(order) + + # when + _copy_unit_discount_data_to_order_line(lines_info) + + # then + line = lines_info[0].line + assert line.unit_discount_amount == rule_reward_value + assert line.unit_discount_reason == rule_discount_reason + assert line.unit_discount_type == rule.reward_value_type + assert line.unit_discount_value == rule_reward_value + + +def test_copy_unit_discount_data_to_order_line_no_discount(order_with_lines): + # given + order = order_with_lines + line = order.lines.first() + assert not line.discounts.exists() + lines_info = fetch_draft_order_lines_info(order) + + # when + _copy_unit_discount_data_to_order_line(lines_info) + + # then + line = lines_info[0].line + assert line.unit_discount_amount == Decimal(0) + assert not line.unit_discount_reason + assert line.unit_discount_type == DiscountValueType.FIXED + assert line.unit_discount_value == Decimal(0) diff --git a/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py index 92d59f356db..2e5b940ab30 100644 --- a/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py +++ b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py @@ -21,9 +21,9 @@ from ... import DiscountType, RewardType, RewardValueType from ...models import CheckoutDiscount, CheckoutLineDiscount, PromotionRule from ...utils import ( - _create_or_update_checkout_discount, _get_best_gift_reward, - create_discount_objects_for_order_promotions, + create_checkout_discount_objects_for_order_promotions, + create_checkout_line_discount_objects_for_catalogue_promotions, create_or_update_discount_objects_from_promotion_for_checkout, ) @@ -115,7 +115,14 @@ def test_create_fixed_discount( == discount_from_db.name == f"{catalogue_promotion_without_rules.name}: {rule.name}" ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None assert ( @@ -123,11 +130,90 @@ def test_create_fixed_discount( == discount_from_db.translated_name == promotion_translation_fr.name ) + assert ( + discount_from_info.unique_type + == discount_from_db.unique_type + == DiscountType.PROMOTION + ) for checkout_line_info in checkout_lines_info[1:]: assert not checkout_line_info.discounts +@freeze_time("2020-12-12 12:00:00") +def test_update_catalogue_discount( + checkout_info, + checkout_lines_info, + catalogue_promotion_without_rules, + promotion_translation_fr, +): + # given + line_info1 = checkout_lines_info[0] + product_line1 = line_info1.product + + actual_reward_value = Decimal("5") + discount_to_update = line_info1.line.discounts.create( + type=DiscountType.PROMOTION, + value_type=RewardValueType.FIXED, + value=actual_reward_value, + name="Fixed 5 catalogue discount", + currency=line_info1.channel.currency_code, + amount_value=actual_reward_value * line_info1.line.quantity, + ) + checkout_lines_info[0].discounts.append(discount_to_update) + + reward_value = Decimal("7") + assert reward_value > actual_reward_value + rule = catalogue_promotion_without_rules.rules.create( + name="Percentage promotion rule", + catalogue_predicate={ + "productPredicate": { + "ids": [graphene.Node.to_global_id("Product", product_line1.id)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(line_info1.channel) + + listing = line_info1.channel_listing + discounted_price = listing.price.amount - reward_value + listing.discounted_price_amount = discounted_price + listing.save(update_fields=["discounted_price_amount"]) + + listing_promotion_rule = VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule, + discount_amount=reward_value, + currency=line_info1.channel.currency_code, + ) + line_info1.rules_info = [ + VariantPromotionRuleInfo( + rule=rule, + variant_listing_promotion_rule=listing_promotion_rule, + promotion=catalogue_promotion_without_rules, + promotion_translation=promotion_translation_fr, + rule_translation=None, + ) + ] + + # when + create_or_update_discount_objects_from_promotion_for_checkout( + checkout_info, checkout_lines_info + ) + + # then + assert len(line_info1.discounts) == 1 + assert CheckoutLineDiscount.objects.count() == 1 + + discount = line_info1.discounts[0] + assert discount.id == discount_to_update.id + assert discount.value == reward_value + assert discount.promotion_rule_id == rule.id + assert discount.amount_value == reward_value * line_info1.line.quantity + assert discount.unique_type == DiscountType.PROMOTION + + @freeze_time("2020-12-12 12:00:00") def test_create_fixed_discount_multiple_quantity_in_lines( checkout_info, @@ -203,7 +289,14 @@ def test_create_fixed_discount_multiple_quantity_in_lines( == discount_from_db.name == catalogue_promotion_without_rules.name ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -358,7 +451,14 @@ def test_create_percentage_discount( == discount_from_db.name == f"{catalogue_promotion_without_rules.name}: {rule.name}" ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -441,7 +541,14 @@ def test_create_percentage_discount_multiple_quantity_in_lines( assert discount_from_info.currency == discount_from_db.currency == "USD" discount_name = f"{catalogue_promotion_without_rules.name}: {rule.name}" assert discount_from_info.name == discount_from_db.name == discount_name - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -449,127 +556,6 @@ def test_create_percentage_discount_multiple_quantity_in_lines( assert not checkout_line_info.discounts -def test_create_discount_multiple_rules_applied( - checkout_info, checkout_lines_info, catalogue_promotion_without_rules -): - # given - line_info1 = checkout_lines_info[0] - product_line1 = line_info1.product - - reward_value_1 = Decimal("2") - reward_value_2 = Decimal("10") - rule_1, rule_2 = PromotionRule.objects.bulk_create( - [ - PromotionRule( - name="Percentage promotion rule 1", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.FIXED, - reward_value=reward_value_1, - catalogue_predicate={ - "productPredicate": { - "ids": [graphene.Node.to_global_id("Product", product_line1.id)] - } - }, - ), - PromotionRule( - name="Percentage promotion rule 2", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.PERCENTAGE, - reward_value=reward_value_2, - catalogue_predicate={ - "variantPredicate": { - "ids": [ - graphene.Node.to_global_id( - "ProductVariant", line_info1.variant.id - ) - ] - } - }, - ), - ] - ) - - rule_1.channels.add(line_info1.channel) - rule_2.channels.add(line_info1.channel) - - listing = line_info1.channel_listing - discount_amount_2 = reward_value_2 / 100 * listing.price.amount - discounted_price = listing.price.amount - reward_value_1 - discount_amount_2 - listing.discounted_price_amount = discounted_price - listing.save(update_fields=["discounted_price_amount"]) - - ( - listing_promotion_rule_1, - listing_promotion_rule_2, - ) = VariantChannelListingPromotionRule.objects.bulk_create( - [ - VariantChannelListingPromotionRule( - variant_channel_listing=listing, - promotion_rule=rule_1, - discount_amount=reward_value_1, - currency=line_info1.channel.currency_code, - ), - VariantChannelListingPromotionRule( - variant_channel_listing=listing, - promotion_rule=rule_2, - discount_amount=discount_amount_2, - currency=line_info1.channel.currency_code, - ), - ] - ) - - line_info1.rules_info = [ - VariantPromotionRuleInfo( - rule=rule_1, - variant_listing_promotion_rule=listing_promotion_rule_1, - promotion=catalogue_promotion_without_rules, - promotion_translation=None, - rule_translation=None, - ), - VariantPromotionRuleInfo( - rule=rule_2, - variant_listing_promotion_rule=listing_promotion_rule_2, - promotion=catalogue_promotion_without_rules, - promotion_translation=None, - rule_translation=None, - ), - ] - - # when - create_or_update_discount_objects_from_promotion_for_checkout( - checkout_info, checkout_lines_info - ) - - # then - assert len(line_info1.discounts) == 2 - discount_for_rule_1 = line_info1.line.discounts.get(promotion_rule=rule_1) - discount_for_rule_2 = line_info1.line.discounts.get(promotion_rule=rule_2) - - assert discount_for_rule_1.line == line_info1.line - assert discount_for_rule_2.line == line_info1.line - - assert discount_for_rule_1.type == DiscountType.PROMOTION - assert discount_for_rule_2.type == DiscountType.PROMOTION - - assert discount_for_rule_1.value_type == RewardValueType.FIXED - assert discount_for_rule_2.value_type == RewardValueType.PERCENTAGE - - assert discount_for_rule_1.value == reward_value_1 - assert discount_for_rule_2.value == reward_value_2 - - assert discount_for_rule_1.amount_value == reward_value_1 - assert discount_for_rule_2.amount_value == discount_amount_2 - - assert discount_for_rule_1.currency == "USD" - assert discount_for_rule_2.currency == "USD" - - assert discount_for_rule_1.promotion_rule == rule_1 - assert discount_for_rule_2.promotion_rule == rule_2 - - for checkout_line_info in checkout_lines_info[1:]: - assert not checkout_line_info.discounts - - def test_two_promotions_applied_to_two_different_lines( checkout_info, checkout_lines_info, catalogue_promotion_without_rules ): @@ -691,7 +677,14 @@ def test_two_promotions_applied_to_two_different_lines( == discount_from_db_1.name == f"{catalogue_promotion_without_rules.name}: {rule_1.name}" ) - assert discount_from_info_1.reason == discount_from_db_1.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info_1.reason + == discount_from_db_1.reason + == f"Promotion: {promotion_id}" + ) assert ( discount_from_info_1.promotion_rule == discount_from_db_1.promotion_rule @@ -722,7 +715,14 @@ def test_two_promotions_applied_to_two_different_lines( == discount_from_db_2.name == f"{catalogue_promotion_without_rules.name}: {rule_2.name}" ) - assert discount_from_info_2.reason == discount_from_db_2.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info_2.reason + == discount_from_db_2.reason + == f"Promotion: {promotion_id}" + ) assert ( discount_from_info_2.promotion_rule == discount_from_db_2.promotion_rule @@ -815,7 +815,14 @@ def test_create_percentage_discount_1_cent_variant_on_10_percentage_discount( == discount_from_db.name == f"{catalogue_promotion_without_rules.name}: {rule.name}" ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -2103,10 +2110,12 @@ def call_before_creating_discount_object(*args, **kwargs): ) with before_after.before( - "saleor.discount.utils._create_or_update_checkout_discount", + "saleor.discount.utils.create_checkout_discount_objects_for_order_promotions", call_before_creating_discount_object, ): - create_discount_objects_for_order_promotions(checkout_info, checkout_lines_info) + create_checkout_discount_objects_for_order_promotions( + checkout_info, checkout_lines_info + ) # then discounts = list(checkout_info.checkout.discounts.all()) @@ -2114,16 +2123,14 @@ def call_before_creating_discount_object(*args, **kwargs): assert discounts[0].amount_value == reward_value -def test_create_or_update_checkout_discount_race_condition( +def test_create_or_update_order_discount_race_condition( checkout_info, checkout_lines_info, catalogue_promotion_without_rules, ): # given promotion = catalogue_promotion_without_rules - checkout = checkout_info.checkout channel = checkout_info.channel - currency = channel.currency_code reward_value = Decimal("2") rule = promotion.rules.create( @@ -2141,20 +2148,14 @@ def test_create_or_update_checkout_discount_race_condition( rule.channels.add(channel) def call_update(*args, **kwargs): - _create_or_update_checkout_discount( - checkout, + create_checkout_discount_objects_for_order_promotions( checkout_info, checkout_lines_info, - rule, - reward_value, - None, - currency, - promotion, - True, + save=True, ) with before_after.before( - "saleor.discount.utils.get_rule_translations", call_update + "saleor.discount.utils._set_checkout_base_prices", call_update ): call_update() @@ -2163,37 +2164,23 @@ def call_update(*args, **kwargs): assert len(discounts) == 1 -def test_create_or_update_checkout_discount_gift_reward_race_condition( +def test_create_or_update_order_discount_gift_reward_race_condition( checkout_info, checkout_lines_info, gift_promotion_rule, ): # given - rule = gift_promotion_rule - promotion = gift_promotion_rule.promotion checkout = checkout_info.checkout - channel = checkout_info.channel - currency = channel.currency_code - - variants = gift_promotion_rule.gifts.all() - variant_listings = ProductVariantChannelListing.objects.filter(variant__in=variants) - listing = max(list(variant_listings), key=lambda x: x.discounted_price_amount) def call_update(*args, **kwargs): - _create_or_update_checkout_discount( - checkout, + create_checkout_discount_objects_for_order_promotions( checkout_info, checkout_lines_info, - rule, - listing.discounted_price_amount, - listing, - currency, - promotion, - True, + save=True, ) with before_after.before( - "saleor.discount.utils.get_rule_translations", call_update + "saleor.discount.utils._set_checkout_base_prices", call_update ): call_update() @@ -2298,3 +2285,27 @@ def test_get_best_gift_reward_no_variants_in_channel(gift_promotion_rule, channe # then assert rule is None assert listing is None + + +def test_create_checkout_line_discount_objects_for_catalogue_promotions_race_condition( + checkout_with_item_on_promotion, + plugins_manager, +): + # given + checkout = checkout_with_item_on_promotion + CheckoutLineDiscount.objects.all().delete() + + # when + def call_before_creating_catalogue_line_discount(*args, **kwargs): + lines_info, _ = fetch_checkout_lines(checkout) + create_checkout_line_discount_objects_for_catalogue_promotions(lines_info) + + with before_after.before( + "saleor.discount.utils.prepare_line_discount_objects_for_catalogue_promotions", + call_before_creating_catalogue_line_discount, + ): + lines_info, _ = fetch_checkout_lines(checkout) + create_checkout_line_discount_objects_for_catalogue_promotions(lines_info) + + # then + assert CheckoutLineDiscount.objects.count() == 1 diff --git a/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py new file mode 100644 index 00000000000..86a3a48bdb0 --- /dev/null +++ b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py @@ -0,0 +1,523 @@ +from decimal import Decimal + +import graphene + +from ....order.fetch import fetch_draft_order_lines_info +from ....product.models import ( + ProductVariantChannelListing, + VariantChannelListingPromotionRule, +) +from ....warehouse.models import Stock +from ... import DiscountType, RewardType, RewardValueType +from ...models import OrderDiscount, OrderLineDiscount +from ...utils import create_or_update_discount_objects_from_promotion_for_order + + +def test_create_catalogue_discount_fixed( + order_with_lines, + catalogue_promotion_without_rules, +): + # given + order = order_with_lines + promotion = catalogue_promotion_without_rules + channel = order.channel + line_1 = order.lines.get(quantity=3) + + # prepare catalogue promotions + variant_1 = line_1.variant + reward_value = Decimal(3) + rule = promotion.rules.create( + name="Catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant_1.id)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(channel) + + listing = variant_1.channel_listings.get(channel=channel) + undiscounted_price = listing.price_amount + listing.discounted_price_amount = undiscounted_price - reward_value + listing.save(update_fields=["discounted_price_amount"]) + + currency = order.currency + VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule, + discount_amount=reward_value, + currency=currency, + ) + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + discount = OrderLineDiscount.objects.get() + assert discount.line == line_1 + assert discount.promotion_rule == rule + assert discount.type == DiscountType.PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(3) + assert discount.amount_value == reward_value * line_1.quantity == Decimal(9) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + + line = [line_info.line for line_info in lines_info if line_info.line == line_1][0] + assert line.base_unit_price_amount == Decimal(7) + + +def test_create_catalogue_discount_percentage( + order_with_lines, + catalogue_promotion_without_rules, +): + # given + order = order_with_lines + promotion = catalogue_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + channel = order.channel + line_1 = order.lines.get(quantity=3) + + variant_1 = line_1.variant + reward_value = Decimal(50) + rule = promotion.rules.create( + name="Catalogue rule percentage", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant_1.id)] + } + }, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=reward_value, + ) + rule.channels.add(channel) + + listing = variant_1.channel_listings.get(channel=channel) + undiscounted_price = listing.price_amount + discount_amount = undiscounted_price * reward_value / 100 + listing.discounted_price_amount = discount_amount + listing.save(update_fields=["discounted_price_amount"]) + + currency = order.currency + VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule, + discount_amount=discount_amount, + currency=currency, + ) + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + discount = OrderLineDiscount.objects.get() + assert discount.line == line_1 + assert discount.promotion_rule == rule + assert discount.type == DiscountType.PROMOTION + assert discount.value_type == RewardValueType.PERCENTAGE + assert discount.value == reward_value == Decimal(50) + assert discount.amount_value == discount_amount * line_1.quantity == Decimal(15) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + line = [line_info.line for line_info in lines_info if line_info.line == line_1][0] + assert line.base_unit_price_amount == Decimal(5) + + +def test_create_order_discount_subtotal_fixed( + order_with_lines, order_promotion_without_rules +): + # given + order = order_with_lines + channel = order.channel + promotion = order_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + reward_value = Decimal(25) + rule = promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseTotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(order.channel) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderDiscount.objects.count() == 1 + assert not OrderLineDiscount.objects.exists() + discount = OrderDiscount.objects.get() + assert discount.order == order + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(25) + assert discount.amount_value == reward_value == Decimal(25) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + +def test_create_order_discount_subtotal_percentage( + order_with_lines, order_promotion_without_rules +): + # given + order = order_with_lines + channel = order.channel + promotion = order_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + reward_value = Decimal(50) + rule = promotion.rules.create( + name="Percentage subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"eq": 70}} + }, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=reward_value, + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(order.channel) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderDiscount.objects.count() == 1 + assert not OrderLineDiscount.objects.exists() + discount = OrderDiscount.objects.get() + assert discount.order == order + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value_type == RewardValueType.PERCENTAGE + assert discount.value == reward_value == Decimal(50) + assert discount.amount_value == Decimal(35) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + +def test_create_order_discount_gift( + order_with_lines, order_promotion_without_rules, variant_with_many_stocks +): + # given + order = order_with_lines + variant = variant_with_many_stocks + channel = order.channel + promotion = order_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + rule = promotion.rules.create( + name="Gift subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_type=RewardType.GIFT, + ) + rule.channels.add(channel) + rule.gifts.set([variant]) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + lines = order.lines.all() + assert len(lines) == 3 + + gift_line = [line for line in lines if line.is_gift][0] + discount = OrderLineDiscount.objects.get() + assert discount.line == gift_line + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + listing = ProductVariantChannelListing.objects.filter( + channel=channel, variant=variant + ).first() + assert discount.value == listing.price_amount == Decimal(10) + assert discount.amount_value == Decimal(10) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + assert gift_line.quantity == 1 + assert gift_line.variant == variant + assert gift_line.total_price_gross_amount == Decimal(0) + assert gift_line.total_price_net_amount == Decimal(0) + assert gift_line.undiscounted_total_price_gross_amount == Decimal(0) + assert gift_line.undiscounted_total_price_net_amount == Decimal(0) + assert gift_line.unit_price_gross_amount == Decimal(0) + assert gift_line.unit_price_net_amount == Decimal(0) + assert gift_line.base_unit_price_amount == Decimal(0) + assert gift_line.unit_discount_amount == Decimal(0) + assert gift_line.unit_discount_type == RewardValueType.FIXED + assert gift_line.unit_discount_value == Decimal(0) + + +def test_multiple_rules_subtotal_and_catalogue_discount_applied( + draft_order_and_promotions, +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + lines_info = fetch_draft_order_lines_info(order) + discounted_variant_global_id = rule_catalogue.catalogue_predicate[ + "variantPredicate" + ]["ids"][0] + _, discounted_variant_id = graphene.Node.from_global_id( + discounted_variant_global_id + ) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + order.refresh_from_db() + assert OrderLineDiscount.objects.count() == 1 + line = order.lines.get(variant_id=discounted_variant_id) + catalogue_discount = line.discounts.first() + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.value == Decimal(3) + assert catalogue_discount.value == rule_catalogue.reward_value + assert catalogue_discount.amount_value == Decimal(6) + assert ( + catalogue_discount.amount_value == line.quantity * rule_catalogue.reward_value + ) + assert catalogue_discount.value_type == RewardValueType.FIXED + + assert OrderDiscount.objects.count() == 1 + order_discount = order.discounts.first() + assert order_discount.type == DiscountType.ORDER_PROMOTION + assert order_discount.amount_value == Decimal(25) + assert order_discount.amount_value == rule_total.reward_value + assert order_discount.value_type == RewardValueType.FIXED + + +def test_multiple_rules_gift_and_catalogue_discount_applied(draft_order_and_promotions): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + lines_info = fetch_draft_order_lines_info(order) + rule_total.reward_value = Decimal(0) + rule_total.save(update_fields=["reward_value"]) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + order.refresh_from_db() + # If gift reward applies and gift is discounted by catalogue promotion, + # do not create discount object for catalogue promotion. Instead, create discount + # object for gift promotion and set reward amount to undiscounted price + assert OrderLineDiscount.objects.count() == 2 + lines = order.lines.all() + assert len(lines) == 3 + gift_line = [line for line in lines if line.is_gift][0] + gift_discount = gift_line.discounts.get() + assert gift_discount.type == DiscountType.ORDER_PROMOTION + listing = ProductVariantChannelListing.objects.filter( + channel=order.channel, variant=gift_line.variant + ).first() + assert gift_discount.value == listing.price_amount + assert not gift_discount.value == listing.discounted_price_amount + assert gift_discount.value == Decimal(20) + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + assert not line_1.discounts.exists() + catalogue_discount = line_2.discounts.first() + assert catalogue_discount.type == DiscountType.PROMOTION + + assert not OrderDiscount.objects.exists() + + +def test_multiple_rules_no_discount_applied( + draft_order_and_promotions, product_variant_list +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + rule_total.order_predicate = { + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 100000}}} + } + rule_total.save(update_fields=["order_predicate"]) + rule_gift.order_predicate = { + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 100000}}} + } + rule_gift.save(update_fields=["order_predicate"]) + + line_2 = [line for line in order.lines.all() if line.quantity == 2][0] + discounted_variant = line_2.variant + listing = discounted_variant.channel_listings.get(channel=order.channel) + listing.discounted_price_amount = listing.price_amount + listing.variantlistingpromotionrule.all().delete() + listing.save(update_fields=["discounted_price_amount"]) + rule_catalogue.catalogue_predicate = { + "variantPredicate": { + "ids": [ + graphene.Node.to_global_id("ProductVariant", product_variant_list[0].id) + ] + } + } + rule_catalogue.save(update_fields=["catalogue_predicate"]) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert not OrderLineDiscount.objects.exists() + assert not OrderDiscount.objects.exists() + + +def test_update_catalogue_discount( + order_with_lines_and_catalogue_promotion, catalogue_promotion_without_rules +): + # given + order = order_with_lines_and_catalogue_promotion + promotion = catalogue_promotion_without_rules + + channel = order.channel + line = order.lines.get(quantity=3) + variant = line.variant + assert OrderLineDiscount.objects.count() == 1 + discount_to_update = line.discounts.get() + + reward_value = Decimal(6) + assert reward_value > promotion.rules.first().reward_value + rule = promotion.rules.create( + name="New catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(channel) + + variant_channel_listing = variant.channel_listings.get(channel=channel) + undiscounted_price = variant_channel_listing.price_amount + variant_channel_listing.discounted_price_amount = undiscounted_price - reward_value + variant_channel_listing.save(update_fields=["discounted_price_amount"]) + + variant_rule_listing = variant_channel_listing.variantlistingpromotionrule.get() + variant_rule_listing.discount_amount = reward_value + variant_rule_listing.promotion_rule = rule + variant_rule_listing.save(update_fields=["discount_amount", "promotion_rule"]) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + discount = OrderLineDiscount.objects.get() + assert discount_to_update.id == discount.id + assert discount.line == line + assert discount.promotion_rule == rule + assert discount.type == DiscountType.PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(6) + assert discount.amount_value == reward_value * line.quantity == Decimal(18) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + + line = [line_info.line for line_info in lines_info if line_info.line == line][0] + assert line.base_unit_price_amount == Decimal(4) + + +def test_update_order_discount_subtotal( + order_with_lines_and_order_promotion, order_promotion_without_rules +): + # given + order = order_with_lines_and_order_promotion + channel = order.channel + promotion = order_promotion_without_rules + + reward_value = Decimal(30) + assert reward_value > promotion.rules.first().reward_value + rule = promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseTotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(order.channel) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderDiscount.objects.count() == 1 + discount = order.discounts.get() + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(30) + assert discount.amount_value == reward_value == Decimal(30) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + + +def test_update_gift_discount_new_gift_available( + order_with_lines_and_gift_promotion, product_variant_list, warehouse +): + # given + order = order_with_lines_and_gift_promotion + variant = product_variant_list[0] + Stock.objects.create(product_variant=variant, warehouse=warehouse, quantity=100) + channel = order.channel + current_discount = OrderLineDiscount.objects.get() + rule = current_discount.promotion_rule + + gift_price = Decimal(50) + listing = variant.channel_listings.get(channel=channel) + listing.discounted_price_amount = gift_price + listing.price_amount = gift_price + listing.save(update_fields=["discounted_price_amount", "price_amount"]) + rule.gifts.add(variant) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + lines = order.lines.all() + assert len(lines) == 3 + + gift_line = [line for line in lines if line.is_gift][0] + discount = OrderLineDiscount.objects.get() + assert discount.line == gift_line + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value == gift_price == Decimal(50) + assert discount.amount_value == gift_price + assert discount.currency == channel.currency_code + + assert gift_line.quantity == 1 + assert gift_line.variant == variant diff --git a/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py b/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py index a090088b4c2..8e41130a536 100644 --- a/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py +++ b/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py @@ -2,7 +2,7 @@ from ... import RewardType, RewardValueType from ...models import PromotionRule -from ...utils import fetch_promotion_rules_for_checkout +from ...utils import fetch_promotion_rules_for_checkout_or_order def test_fetch_promotion_rules_for_checkout( @@ -23,7 +23,7 @@ def test_fetch_promotion_rules_for_checkout( ) # when - rules_per_promotion_id = fetch_promotion_rules_for_checkout(checkout) + rules_per_promotion_id = fetch_promotion_rules_for_checkout_or_order(checkout) # then assert len(rules_per_promotion_id) == 1 @@ -48,7 +48,7 @@ def test_fetch_promotion_rules_for_checkout_no_matching_rule( ) # when - rules_per_promotion_id = fetch_promotion_rules_for_checkout(checkout) + rules_per_promotion_id = fetch_promotion_rules_for_checkout_or_order(checkout) # then assert not rules_per_promotion_id @@ -77,7 +77,7 @@ def test_fetch_promotion_rules_for_checkout_relevant_channel_only( rule_2.channels.add(checkout_JPY.channel) # when - rules_per_promotion_id = fetch_promotion_rules_for_checkout(checkout_JPY) + rules_per_promotion_id = fetch_promotion_rules_for_checkout_or_order(checkout_JPY) # then assert len(rules_per_promotion_id) == 1 diff --git a/saleor/discount/utils.py b/saleor/discount/utils.py index cf27fb685ff..c3db0fb11a1 100644 --- a/saleor/discount/utils.py +++ b/saleor/discount/utils.py @@ -19,11 +19,14 @@ base_checkout_delivery_price, base_checkout_subtotal, ) -from ..checkout.fetch import CheckoutLineInfo, find_checkout_line_info -from ..checkout.models import Checkout, CheckoutLine +from ..checkout.fetch import CheckoutInfo, CheckoutLineInfo +from ..checkout.models import Checkout +from ..core.db.connection import allow_writer from ..core.exceptions import InsufficientStock from ..core.taxes import zero_money from ..core.utils.promo_code import InvalidPromoCode +from ..order.fetch import DraftOrderLineInfo +from ..order.models import Order, OrderLine from ..product.models import ( Product, ProductChannelListing, @@ -43,6 +46,8 @@ CheckoutLineDiscount, DiscountValueType, NotApplicable, + OrderDiscount, + OrderLineDiscount, Promotion, PromotionRule, Voucher, @@ -52,8 +57,6 @@ if TYPE_CHECKING: from ..account.models import User - from ..checkout.fetch import CheckoutInfo - from ..order.models import Order from ..plugins.manager import PluginsManager from ..product.managers import ProductVariantQueryset from ..product.models import VariantChannelListingPromotionRule @@ -282,25 +285,29 @@ def validate_voucher_for_checkout( ) -def validate_voucher_in_order(order: "Order"): +def validate_voucher_in_order( + order: "Order", lines: Iterable["OrderLine"], channel: "Channel" +): if not order.voucher: return + from ..order.utils import get_total_quantity + subtotal = order.subtotal - quantity = order.get_total_quantity() + quantity = get_total_quantity(lines) customer_email = order.get_customer_email() - tax_configuration = order.channel.tax_configuration + tax_configuration = channel.tax_configuration prices_entered_with_tax = tax_configuration.prices_entered_with_tax value = subtotal.gross if prices_entered_with_tax else subtotal.net validate_voucher( - order.voucher, value, quantity, customer_email, order.channel, order.user + order.voucher, value, quantity, customer_email, channel, order.user ) def validate_voucher( voucher: "Voucher", - total_price: TaxedMoney, + total_price: Money, quantity: int, customer_email: str, channel: Channel, @@ -350,76 +357,125 @@ def create_or_update_discount_objects_from_promotion_for_checkout( lines_info: Iterable["CheckoutLineInfo"], database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - create_discount_objects_for_catalogue_promotions(lines_info) - create_discount_objects_for_order_promotions( + create_checkout_line_discount_objects_for_catalogue_promotions(lines_info) + create_checkout_discount_objects_for_order_promotions( checkout_info, lines_info, database_connection_name=database_connection_name ) -def create_discount_objects_for_catalogue_promotions( - lines_info: Iterable["CheckoutLineInfo"], +def create_checkout_line_discount_objects_for_catalogue_promotions( + lines_info: Iterable[CheckoutLineInfo], +): + discount_data = prepare_line_discount_objects_for_catalogue_promotions(lines_info) + if not discount_data or not lines_info: + return + + ( + discounts_to_create_inputs, + discounts_to_update, + discount_to_remove, + updated_fields, + ) = discount_data + + new_line_discounts = [] + with allow_writer(): + with transaction.atomic(): + # Protect against potential thread race. CheckoutLine object can have only + # single catalogue discount applied. + checkout_id = lines_info[0].line.checkout_id # type: ignore[index] + _checkout_lock = list( + Checkout.objects.filter(pk=checkout_id).select_for_update(of=(["self"])) + ) + + if discount_ids_to_remove := [ + discount.id for discount in discount_to_remove + ]: + CheckoutLineDiscount.objects.filter( + id__in=discount_ids_to_remove + ).delete() + + if discounts_to_create_inputs: + new_line_discounts = [ + CheckoutLineDiscount(**input) + for input in discounts_to_create_inputs + ] + CheckoutLineDiscount.objects.bulk_create( + new_line_discounts, ignore_conflicts=True + ) + + if discounts_to_update and updated_fields: + CheckoutLineDiscount.objects.bulk_update( + discounts_to_update, updated_fields + ) + + _update_line_info_cached_discounts( + lines_info, new_line_discounts, discounts_to_update, discount_ids_to_remove + ) + + +def prepare_line_discount_objects_for_catalogue_promotions( + lines_info: Union[Iterable["CheckoutLineInfo"], Iterable["DraftOrderLineInfo"]], ): - line_discounts_to_create = [] - line_discounts_to_update = [] - line_discount_ids_to_remove = [] + line_discounts_to_create_inputs: list[dict] = [] + line_discounts_to_update: list[Union[CheckoutLineDiscount, OrderLineDiscount]] = [] + line_discounts_to_remove: list[Union[CheckoutLineDiscount, OrderLineDiscount]] = [] updated_fields: list[str] = [] + if not lines_info: + return + for line_info in lines_info: line = line_info.line + # get the existing catalogue discount for the line + discount_to_update = None + if discounts_to_update := line_info.get_catalogue_discounts(): + discount_to_update = discounts_to_update[0] + # Line should never have multiple catalogue discounts associated. Before + # introducing unique_type on discount models, there was such a possibility. + line_discounts_to_remove.extend(discounts_to_update[1:]) + + # manual line discount do not stack with other discounts + if [ + discount + for discount in line_info.discounts + if discount.type == DiscountType.MANUAL + ]: + line_discounts_to_remove.extend(discounts_to_update) + continue + # discount_amount based on the difference between discounted_price and price discount_amount = _get_discount_amount(line_info.channel_listing, line.quantity) - # get the existing discounts for the line - discounts_to_update = line_info.get_catalogue_discounts() - rule_id_to_discount = { - discount.promotion_rule_id: discount for discount in discounts_to_update - } - # delete all existing discounts if the line is not discounted or it is a gift if not discount_amount or line.is_gift: - ids_to_remove = [discount.id for discount in discounts_to_update] - if ids_to_remove: - line_discount_ids_to_remove.extend(ids_to_remove) - line_info.discounts = [ - discount - for discount in line_info.discounts - if discount.id not in ids_to_remove - ] + line_discounts_to_remove.extend(discounts_to_update) continue - # delete the discount objects that are not valid anymore - line_discount_ids_to_remove.extend( - _get_discounts_that_are_not_valid_anymore( - line_info.rules_info, - rule_id_to_discount, # type: ignore[arg-type] - line_info, - ) - ) - - for rule_info in line_info.rules_info: + if line_info.rules_info: + rule_info = line_info.rules_info[0] rule = rule_info.rule - discount_to_update = rule_id_to_discount.get(rule.id) rule_discount_amount = _get_rule_discount_amount( rule_info.variant_listing_promotion_rule, line.quantity ) discount_name = get_discount_name(rule, rule_info.promotion) translated_name = get_discount_translated_name(rule_info) + reason = _get_discount_reason(rule) if not discount_to_update: - line_discount = CheckoutLineDiscount( - line=line, - type=DiscountType.PROMOTION, - value_type=rule.reward_value_type, - value=rule.reward_value, - amount_value=rule_discount_amount, - currency=line.currency, - name=discount_name, - translated_name=translated_name, - reason=None, - promotion_rule=rule, - ) - line_discounts_to_create.append(line_discount) - line_info.discounts.append(line_discount) + line_discount_input = { + "line": line, + "type": DiscountType.PROMOTION, + "value_type": rule.reward_value_type, + "value": rule.reward_value, + "amount_value": rule_discount_amount, + "currency": line.currency, + "name": discount_name, + "translated_name": translated_name, + "reason": reason, + "promotion_rule": rule, + "unique_type": DiscountType.PROMOTION, + } + line_discounts_to_create_inputs.append(line_discount_input) else: _update_discount( rule, @@ -428,17 +484,17 @@ def create_discount_objects_for_catalogue_promotions( discount_to_update, updated_fields, ) - line_discounts_to_update.append(discount_to_update) + else: + # Fallback for unlike mismatch between discount_amount and rules_info + line_discounts_to_remove.extend(discounts_to_update) - if line_discounts_to_create: - CheckoutLineDiscount.objects.bulk_create(line_discounts_to_create) - if line_discounts_to_update and updated_fields: - CheckoutLineDiscount.objects.bulk_update( - line_discounts_to_update, updated_fields - ) - if line_discount_ids_to_remove: - CheckoutLineDiscount.objects.filter(id__in=line_discount_ids_to_remove).delete() + return ( + line_discounts_to_create_inputs, + line_discounts_to_update, + line_discounts_to_remove, + updated_fields, + ) def _get_discount_amount( @@ -458,20 +514,6 @@ def _get_discount_amount( return unit_discount * line_quantity -def _get_discounts_that_are_not_valid_anymore( - rules_info: list["VariantPromotionRuleInfo"], - rule_id_to_discount: dict[int, "CheckoutLineDiscount"], - line_info: "CheckoutLineInfo", -): - discount_ids = [] - rule_ids = {rule_info.rule.id for rule_info in rules_info} - for rule_id, discount in rule_id_to_discount.items(): - if rule_id not in rule_ids: - discount_ids.append(discount.id) - line_info.discounts.remove(discount) - return discount_ids - - def _get_rule_discount_amount( variant_listing_promotion_rule: Optional["VariantChannelListingPromotionRule"], line_quantity: int, @@ -488,6 +530,13 @@ def get_discount_name(rule: "PromotionRule", promotion: "Promotion"): return rule.name or promotion.name +def _get_discount_reason(rule: PromotionRule): + promotion = rule.promotion + if promotion.old_sale_id: + return f"Sale: {graphene.Node.to_global_id('Sale', promotion.old_sale_id)}" + return f"Promotion: {graphene.Node.to_global_id('Promotion', promotion.id)}" + + def get_discount_translated_name(rule_info: "VariantPromotionRuleInfo"): promotion_translation = rule_info.promotion_translation rule_translation = rule_info.rule_translation @@ -504,7 +553,9 @@ def _update_discount( rule: "PromotionRule", rule_info: "VariantPromotionRuleInfo", rule_discount_amount: Decimal, - discount_to_update: Union["CheckoutLineDiscount", "CheckoutDiscount"], + discount_to_update: Union[ + "CheckoutLineDiscount", "CheckoutDiscount", "OrderLineDiscount", "OrderDiscount" + ], updated_fields: list[str], ): if discount_to_update.promotion_rule_id != rule.id: @@ -537,9 +588,33 @@ def _update_discount( if discount_to_update.reason != reason: discount_to_update.reason = reason updated_fields.append("reason") + if hasattr(discount_to_update, "unique_type"): + if discount_to_update.unique_type is None: + discount_to_update.unique_type = DiscountType.PROMOTION + updated_fields.append("unique_type") -def create_discount_objects_for_order_promotions( +def _update_line_info_cached_discounts( + lines_info, new_line_discounts, updated_discounts, line_discount_ids_to_remove +): + if not any([new_line_discounts, updated_discounts, line_discount_ids_to_remove]): + return + + line_id_line_discounts_map = defaultdict(list) + for line_discount in new_line_discounts: + line_id_line_discounts_map[line_discount.line_id].append(line_discount) + + for line_info in lines_info: + line_info.discounts = [ + discount + for discount in line_info.discounts + if discount.id not in line_discount_ids_to_remove + ] + if discount := line_id_line_discounts_map.get(line_info.line.id): + line_info.discounts.extend(discount) + + +def create_checkout_discount_objects_for_order_promotions( checkout_info: "CheckoutInfo", lines_info: Iterable["CheckoutLineInfo"], *, @@ -557,42 +632,77 @@ def create_discount_objects_for_order_promotions( return channel = checkout_info.channel - rule_data = get_best_rule_for_checkout( - checkout, channel, checkout_info.get_country(), database_connection_name + rules = fetch_promotion_rules_for_checkout_or_order( + checkout, database_connection_name + ) + rule_data = get_best_rule( + rules=rules, + channel=channel, + country=checkout_info.get_country(), + subtotal=checkout.base_subtotal, + database_connection_name=database_connection_name, ) if not rule_data: _clear_checkout_discount(checkout_info, lines_info, save) return best_rule, best_discount_amount, gift_listing = rule_data - - _create_or_update_checkout_discount( - checkout, - checkout_info, - lines_info, - best_rule, - best_discount_amount, - gift_listing, - channel.currency_code, - best_rule.promotion, - save, + promotion = best_rule.promotion + currency = channel.currency_code + translation_language_code = checkout.language_code + promotion_translation, rule_translation = get_rule_translations( + promotion, best_rule, translation_language_code ) + rule_info = VariantPromotionRuleInfo( + rule=best_rule, + variant_listing_promotion_rule=None, + promotion=promotion, + promotion_translation=promotion_translation, + rule_translation=rule_translation, + ) + # gift rule has empty reward_value and reward_value_type + value_type = best_rule.reward_value_type or RewardValueType.FIXED + amount_value = gift_listing.price_amount if gift_listing else best_discount_amount + value = best_rule.reward_value or amount_value + discount_object_defaults = { + "promotion_rule": best_rule, + "value_type": value_type, + "value": value, + "amount_value": amount_value, + "currency": currency, + "name": get_discount_name(best_rule, promotion), + "translated_name": get_discount_translated_name(rule_info), + "reason": prepare_promotion_discount_reason(promotion, get_sale_id(promotion)), + } + if gift_listing: + _handle_gift_reward_for_checkout( + checkout_info, + lines_info, + gift_listing, + discount_object_defaults, + rule_info, + save, + ) + else: + _handle_order_promotion_for_checkout( + checkout_info, + lines_info, + discount_object_defaults, + rule_info, + save, + ) -def get_best_rule_for_checkout( - checkout: "Checkout", +def get_best_rule( + rules: Iterable["PromotionRule"], channel: "Channel", country: str, + subtotal: Money, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): RuleDiscount = namedtuple( "RuleDiscount", ["rule", "discount_amount", "gift_listing"] ) - subtotal = checkout.base_subtotal - rules = fetch_promotion_rules_for_checkout(checkout, database_connection_name) - if not rules: - return - currency_code = channel.currency_code rule_discounts: list[RuleDiscount] = [] gift_rules = [rule for rule in rules if rule.reward_type == RewardType.GIFT] @@ -607,12 +717,14 @@ def get_best_rule_for_checkout( rule_discounts.append(RuleDiscount(rule, discount_amount, None)) if gift_rules: - rule, gift_listing = _get_best_gift_reward( + best_gift_rule, gift_listing = _get_best_gift_reward( gift_rules, channel, country, database_connection_name ) - if rule and gift_listing: + if best_gift_rule and gift_listing: rule_discounts.append( - RuleDiscount(rule, gift_listing.discounted_price_amount, gift_listing) + RuleDiscount( + best_gift_rule, gift_listing.discounted_price_amount, gift_listing + ) ) if not rule_discounts: @@ -640,7 +752,8 @@ def _set_checkout_base_prices(checkout_info, lines_info): if is_update_needed: checkout.base_subtotal = subtotal checkout.base_total = total - checkout.save(update_fields=["base_total_amount", "base_subtotal_amount"]) + with allow_writer(): + checkout.save(update_fields=["base_total_amount", "base_subtotal_amount"]) def _clear_checkout_discount( @@ -776,71 +889,14 @@ def _get_available_for_purchase_variant_ids( return set(available_variant_ids) -def _create_or_update_checkout_discount( - checkout: "Checkout", - checkout_info: "CheckoutInfo", - lines_info: Iterable["CheckoutLineInfo"], - best_rule: "PromotionRule", - best_discount_amount: Decimal, - gift_listing: Optional[ProductVariantChannelListing], - currency_code: str, - promotion: "Promotion", - save: bool, -): - translation_language_code = checkout.language_code - promotion_translation, rule_translation = get_rule_translations( - promotion, best_rule, translation_language_code - ) - rule_info = VariantPromotionRuleInfo( - rule=best_rule, - variant_listing_promotion_rule=None, - promotion=promotion, - promotion_translation=promotion_translation, - rule_translation=rule_translation, - ) - # gift rule has empty reward_value and reward_value_type - value_type = best_rule.reward_value_type or RewardValueType.FIXED - amount_value = gift_listing.price_amount if gift_listing else best_discount_amount - value = best_rule.reward_value or amount_value - discount_object_defaults = { - "promotion_rule": best_rule, - "value_type": value_type, - "value": value, - "amount_value": amount_value, - "currency": currency_code, - "name": get_discount_name(best_rule, promotion), - "translated_name": get_discount_translated_name(rule_info), - "reason": prepare_promotion_discount_reason(promotion, get_sale_id(promotion)), - } - if gift_listing: - _handle_gift_reward( - checkout, - checkout_info, - lines_info, - gift_listing, - discount_object_defaults, - rule_info, - save, - ) - else: - _handle_order_promotion( - checkout, - checkout_info, - lines_info, - discount_object_defaults, - rule_info, - save, - ) - - -def _handle_order_promotion( - checkout: "Checkout", - checkout_info: "CheckoutInfo", - lines_info: Iterable["CheckoutLineInfo"], +def _handle_order_promotion_for_checkout( + checkout_info: CheckoutInfo, + lines_info: Iterable[CheckoutLineInfo], discount_object_defaults: dict, rule_info: VariantPromotionRuleInfo, - save: bool, + save: bool = False, ): + checkout = checkout_info.checkout discount_object, created = checkout.discounts.get_or_create( type=DiscountType.ORDER_PROMOTION, defaults=discount_object_defaults, @@ -860,6 +916,7 @@ def _handle_order_promotion( discount_object.save(update_fields=fields_to_update) checkout_info.discounts = [discount_object] + checkout = checkout_info.checkout checkout.discount_amount = discount_amount checkout.discount_name = discount_object.name checkout.translated_discount_name = discount_object.translated_name @@ -875,26 +932,29 @@ def _handle_order_promotion( delete_gift_line(checkout, lines_info) -def delete_gift_line(checkout: "Checkout", lines_info: Iterable["CheckoutLineInfo"]): +def delete_gift_line( + order_or_checkout: Union[Checkout, Order], + lines_info: Iterable[Union["CheckoutLineInfo", "DraftOrderLineInfo"]], +): if gift_line_infos := [line for line in lines_info if line.line.is_gift]: - CheckoutLine.objects.filter(checkout_id=checkout.pk, is_gift=True).delete() + order_or_checkout.lines.filter(is_gift=True).delete() # type: ignore[misc] for gift_line_info in gift_line_infos: lines_info.remove(gift_line_info) # type: ignore[attr-defined] -def _handle_gift_reward( - checkout: "Checkout", - checkout_info: "CheckoutInfo", - lines_info: Iterable["CheckoutLineInfo"], +@allow_writer() +def _handle_gift_reward_for_checkout( + checkout_info: CheckoutInfo, + lines_info: Iterable[CheckoutLineInfo], gift_listing: ProductVariantChannelListing, discount_object_defaults: dict, rule_info: VariantPromotionRuleInfo, - save: bool, + save: bool = False, ): with transaction.atomic(): - line, line_created = create_gift_line(checkout, gift_listing.variant_id) - line_discount = None - discount_created = False + line, line_created = create_gift_line( + checkout_info.checkout, gift_listing.variant_id + ) ( line_discount, discount_created, @@ -920,38 +980,38 @@ def _handle_gift_reward( line_discount.save(update_fields=fields_to_update) checkout_info.discounts = [] - checkout.discount_amount = Decimal("0") + checkout_info.checkout.discount_amount = Decimal("0") if save: - checkout.save(update_fields=["discount_amount"]) + checkout_info.checkout.save(update_fields=["discount_amount"]) if line_created: variant = gift_listing.variant - gift_line_info = CheckoutLineInfo( - line=line, - variant=variant, - channel_listing=gift_listing, - product=variant.product, - product_type=variant.product.product_type, - collections=[], - discounts=[line_discount], - rules_info=[rule_info], - channel=checkout_info.channel, - ) + init_values = { + "line": line, + "variant": variant, + "channel_listing": gift_listing, + "discounts": [line_discount], + "rules_info": [rule_info], + "channel": checkout_info.channel, + "product": variant.product, + "product_type": variant.product.product_type, + "collections": [], + } + + gift_line_info = CheckoutLineInfo(**init_values) lines_info.append(gift_line_info) # type: ignore[attr-defined] else: - line_info = find_checkout_line_info(lines_info, line.id) + line_info = next( + line_info for line_info in lines_info if line_info.line.pk == line.id + ) line_info.line = line line_info.discounts = [line_discount] -def create_gift_line(checkout: "Checkout", variant_id: int): - defaults = { - "variant_id": variant_id, - "quantity": 1, - "currency": checkout.currency, - } - line, created = CheckoutLine.objects.get_or_create( - checkout=checkout, is_gift=True, defaults=defaults +def create_gift_line(order_or_checkout: Union[Checkout, Order], variant_id: int): + defaults = _get_defaults_for_gift_line(order_or_checkout, variant_id) + line, created = order_or_checkout.lines.get_or_create( + is_gift=True, defaults=defaults ) if not created: fields_to_update = [] @@ -965,6 +1025,29 @@ def create_gift_line(checkout: "Checkout", variant_id: int): return line, created +def _get_defaults_for_gift_line( + order_or_checkout: Union[Checkout, Order], variant_id: int +): + if isinstance(order_or_checkout, Checkout): + return { + "variant_id": variant_id, + "quantity": 1, + "currency": order_or_checkout.currency, + } + else: + return { + "variant_id": variant_id, + "quantity": 1, + "currency": order_or_checkout.currency, + "unit_price_net_amount": Decimal(0), + "unit_price_gross_amount": Decimal(0), + "total_price_net_amount": Decimal(0), + "total_price_gross_amount": Decimal(0), + "is_shipping_required": True, + "is_gift_card": False, + } + + def get_variants_to_promotion_rules_map( variant_qs: "ProductVariantQueryset", ) -> dict[int, list[PromotionRuleInfo]]: @@ -1013,39 +1096,43 @@ def get_variants_to_promotion_rules_map( return rules_info_per_variant -def fetch_promotion_rules_for_checkout( - checkout: Checkout, +def fetch_promotion_rules_for_checkout_or_order( + instance: Union["Checkout", "Order"], database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): from ..graphql.discount.utils import PredicateObjectType, filter_qs_by_predicate applicable_rules = [] promotions = Promotion.objects.active() - checkout_channel_id = checkout.channel_id - PromotionRuleChannels = PromotionRule.channels.through.objects.filter( - channel_id=checkout_channel_id - ) rules = ( PromotionRule.objects.using(database_connection_name) - .filter( - Exists(promotions.filter(id=OuterRef("promotion_id"))), - Exists(PromotionRuleChannels.filter(promotionrule_id=OuterRef("id"))), - ) + .filter(Exists(promotions.filter(id=OuterRef("promotion_id")))) .exclude(order_predicate={}) + .prefetch_related("channels") ) + rule_to_channel_ids_map = _get_rule_to_channel_ids_map(rules) - currency = checkout.currency - checkout_qs = Checkout.objects.using(database_connection_name).filter( - pk=checkout.pk + channel_id = instance.channel_id + currency = instance.channel.currency_code + qs = instance._meta.model.objects.using(database_connection_name).filter( # type: ignore[attr-defined] # noqa: E501 + pk=instance.pk ) for rule in rules.iterator(): - checkouts = filter_qs_by_predicate( + rule_channel_ids = rule_to_channel_ids_map.get(rule.id, []) + if channel_id not in rule_channel_ids: + continue + predicate_type = ( + PredicateObjectType.CHECKOUT + if isinstance(instance, Checkout) + else PredicateObjectType.ORDER + ) + objects = filter_qs_by_predicate( rule.order_predicate, - checkout_qs, - PredicateObjectType.CHECKOUT, + qs, + predicate_type, currency, ) - if checkouts.exists(): + if objects.exists(): applicable_rules.append(rule) return applicable_rules @@ -1121,6 +1208,299 @@ def update_rule_variant_relation( ) +def create_or_update_discount_objects_from_promotion_for_order( + order: "Order", + lines_info: Iterable["DraftOrderLineInfo"], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +): + create_order_line_discount_objects_for_catalogue_promotions(lines_info) + create_order_discount_objects_for_order_promotions( + order, lines_info, database_connection_name=database_connection_name + ) + _copy_unit_discount_data_to_order_line(lines_info) + + +def create_order_line_discount_objects_for_catalogue_promotions( + lines_info: Iterable[DraftOrderLineInfo], +): + discount_data = prepare_line_discount_objects_for_catalogue_promotions(lines_info) + if not discount_data or not lines_info: + return + + ( + discounts_to_create_inputs, + discounts_to_update, + discount_to_remove, + updated_fields, + ) = discount_data + + new_line_discounts = [] + with allow_writer(): + with transaction.atomic(): + # Protect against potential thread race. OrderLine object can have only + # single catalogue discount applied. + order_id = lines_info[0].line.order_id # type: ignore[index] + _order_lock = list( + Order.objects.filter(id=order_id).select_for_update(of=(["self"])) + ) + + if discount_ids_to_remove := [ + discount.id for discount in discount_to_remove + ]: + OrderLineDiscount.objects.filter(id__in=discount_ids_to_remove).delete() + + if discounts_to_create_inputs: + new_line_discounts = [ + OrderLineDiscount(**input) for input in discounts_to_create_inputs + ] + OrderLineDiscount.objects.bulk_create( + new_line_discounts, ignore_conflicts=True + ) + + if discounts_to_update and updated_fields: + OrderLineDiscount.objects.bulk_update( + discounts_to_update, updated_fields + ) + + _update_line_info_cached_discounts( + lines_info, new_line_discounts, discounts_to_update, discount_ids_to_remove + ) + + affected_line_ids = [ + discount_line.line.id + for discount_line in new_line_discounts + + discounts_to_update + + discount_to_remove + ] + modified_lines_info = [ + line_info for line_info in lines_info if line_info.line.id in affected_line_ids + ] + # base unit price must reflect all actual catalogue discounts + _update_base_unit_price_amount(modified_lines_info) + + +def _copy_unit_discount_data_to_order_line(lines_info: Iterable[DraftOrderLineInfo]): + for line_info in lines_info: + if discounts := line_info.discounts: + line = line_info.line + discount_amount = sum([discount.amount_value for discount in discounts]) + unit_discount_amount = discount_amount / line.quantity + discount_reason = "; ".join( + [discount.reason for discount in discounts if discount.reason] + ) + discount_type = ( + discounts[0].value_type + if len(discounts) == 1 + else DiscountValueType.FIXED + ) + discount_value = ( + discounts[0].value if len(discounts) == 1 else unit_discount_amount + ) + + line.unit_discount_amount = unit_discount_amount + line.unit_discount_reason = discount_reason + line.unit_discount_type = discount_type + line.unit_discount_value = discount_value + + +def _update_base_unit_price_amount(lines_info: Iterable[DraftOrderLineInfo]): + for line_info in lines_info: + line = line_info.line + base_unit_price = line.undiscounted_base_unit_price_amount + for discount in line_info.discounts: + unit_discount = discount.amount_value / line.quantity + base_unit_price -= unit_discount + line.base_unit_price_amount = max(base_unit_price, Decimal(0)) + + +def create_order_discount_objects_for_order_promotions( + order: "Order", + lines_info: Iterable["DraftOrderLineInfo"], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +): + from ..order.base_calculations import base_order_subtotal + from ..order.utils import get_order_country + + # If voucher is set or manual discount applied, then skip order promotions + if order.voucher_code or order.discounts.filter(type=DiscountType.MANUAL): + _clear_order_discount(order, lines_info) + return + + # The base prices are required for order promotion discount qualification. + _set_order_base_prices(order, lines_info) + + lines = [line_info.line for line_info in lines_info] + subtotal = base_order_subtotal(order, lines) + channel = order.channel + rules = fetch_promotion_rules_for_checkout_or_order(order, database_connection_name) + rule_data = get_best_rule( + rules=rules, + channel=channel, + country=get_order_country(order), + subtotal=subtotal, + database_connection_name=database_connection_name, + ) + if not rule_data: + _clear_order_discount(order, lines_info) + return + + best_rule, best_discount_amount, gift_listing = rule_data + promotion = best_rule.promotion + currency = channel.currency_code + translation_language_code = order.language_code + promotion_translation, rule_translation = get_rule_translations( + promotion, best_rule, translation_language_code + ) + rule_info = VariantPromotionRuleInfo( + rule=best_rule, + variant_listing_promotion_rule=None, + promotion=best_rule.promotion, + promotion_translation=promotion_translation, + rule_translation=rule_translation, + ) + # gift rule has empty reward_value and reward_value_type + value_type = best_rule.reward_value_type or RewardValueType.FIXED + amount_value = gift_listing.price_amount if gift_listing else best_discount_amount + value = best_rule.reward_value or amount_value + discount_object_defaults = { + "promotion_rule": best_rule, + "value_type": value_type, + "value": value, + "amount_value": amount_value, + "currency": currency, + "name": get_discount_name(best_rule, promotion), + "translated_name": get_discount_translated_name(rule_info), + "reason": prepare_promotion_discount_reason(promotion, get_sale_id(promotion)), + } + if gift_listing: + _handle_gift_reward_for_order( + order, + lines_info, + gift_listing, + discount_object_defaults, + rule_info, + ) + else: + _handle_order_promotion_for_order( + order, + lines_info, + discount_object_defaults, + rule_info, + ) + + +def _clear_order_discount( + order_or_checkout: Union[Checkout, Order], + lines_info: Iterable[DraftOrderLineInfo], +): + with transaction.atomic(): + delete_gift_line(order_or_checkout, lines_info) + order_or_checkout.discounts.filter(type=DiscountType.ORDER_PROMOTION).delete() + + +def _set_order_base_prices(order: Order, lines_info: Iterable[DraftOrderLineInfo]): + """Set base order prices that includes only catalogue discounts.""" + from ..order.base_calculations import base_order_subtotal + + lines = [line_info.line for line_info in lines_info] + subtotal = base_order_subtotal(order, lines) + shipping_price = order.base_shipping_price + total = subtotal + shipping_price + + update_fields = [] + if order.subtotal != TaxedMoney(net=subtotal, gross=subtotal): + order.subtotal = TaxedMoney(net=subtotal, gross=subtotal) + update_fields.extend(["subtotal_net_amount", "subtotal_gross_amount"]) + if order.total != TaxedMoney(net=total, gross=total): + order.total = TaxedMoney(net=total, gross=total) + update_fields.extend(["total_net_amount", "total_gross_amount"]) + + if update_fields: + with allow_writer(): + order.save(update_fields=update_fields) + + +def _handle_order_promotion_for_order( + order: Order, + lines_info: Iterable[DraftOrderLineInfo], + discount_object_defaults: dict, + rule_info: VariantPromotionRuleInfo, +): + discount_object, created = order.discounts.get_or_create( + type=DiscountType.ORDER_PROMOTION, + defaults=discount_object_defaults, + ) + discount_amount = discount_object_defaults["amount_value"] + + if not created: + fields_to_update: list[str] = [] + _update_discount( + discount_object_defaults["promotion_rule"], + rule_info, + discount_amount, + discount_object, + fields_to_update, + ) + if fields_to_update: + discount_object.save(update_fields=fields_to_update) + + delete_gift_line(order, lines_info) + + +@allow_writer() +def _handle_gift_reward_for_order( + order: Order, + lines_info: Iterable[DraftOrderLineInfo], + gift_listing: ProductVariantChannelListing, + discount_object_defaults: dict, + rule_info: VariantPromotionRuleInfo, +): + with transaction.atomic(): + line, line_created = create_gift_line(order, gift_listing.variant_id) + ( + line_discount, + discount_created, + ) = OrderLineDiscount.objects.get_or_create( + type=DiscountType.ORDER_PROMOTION, + line=line, + defaults=discount_object_defaults, + ) + + if not discount_created: + fields_to_update = [] + if line_discount.line_id != line.id: + line_discount.line = line + fields_to_update.append("line_id") + _update_discount( + discount_object_defaults["promotion_rule"], + rule_info, + discount_object_defaults["amount_value"], + line_discount, + fields_to_update, + ) + if fields_to_update: + line_discount.save(update_fields=fields_to_update) + + if line_created: + variant = gift_listing.variant + init_values = { + "line": line, + "variant": variant, + "channel_listing": gift_listing, + "discounts": [line_discount], + "rules_info": [rule_info], + "channel": order.channel_id, + } + gift_line_info = DraftOrderLineInfo(**init_values) + lines_info.append(gift_line_info) # type: ignore[attr-defined] + else: + line_info = next( + line_info for line_info in lines_info if line_info.line.pk == line.id + ) + line_info.line = line + line_info.discounts = [line_discount] + + def get_active_catalogue_promotion_rules( allow_replica: bool = False, ) -> "QuerySet[PromotionRule]": diff --git a/saleor/graphql/account/mutations/authentication/refresh_token.py b/saleor/graphql/account/mutations/authentication/refresh_token.py index cf4fb63484b..b461fffc36d 100644 --- a/saleor/graphql/account/mutations/authentication/refresh_token.py +++ b/saleor/graphql/account/mutations/authentication/refresh_token.py @@ -58,14 +58,13 @@ def get_refresh_token( cls, info: ResolveInfo, refresh_token: Optional[str] = None ) -> Optional[str]: request = info.context - refresh_token = refresh_token or request.COOKIES.get( - JWT_REFRESH_TOKEN_COOKIE_NAME, None - ) + if refresh_token is None: + refresh_token = request.COOKIES.get(JWT_REFRESH_TOKEN_COOKIE_NAME, None) return refresh_token @classmethod def clean_refresh_token(cls, refresh_token): - if not refresh_token: + if refresh_token is None: raise ValidationError( { "refresh_token": ValidationError( diff --git a/saleor/graphql/account/mutations/authentication/set_password.py b/saleor/graphql/account/mutations/authentication/set_password.py index c319421f0e2..2ef7414ce8f 100644 --- a/saleor/graphql/account/mutations/authentication/set_password.py +++ b/saleor/graphql/account/mutations/authentication/set_password.py @@ -6,6 +6,7 @@ from .....account import events as account_events from .....account import models from .....account.error_codes import AccountErrorCode +from .....core.db.connection import allow_writer from .....order.utils import match_orders_with_new_user from ....core import ResolveInfo from ....core.context import disallow_replica_in_context @@ -35,6 +36,7 @@ class Meta: error_type_field = "account_errors" @classmethod + @allow_writer() def mutate( # type: ignore[override] cls, root, info: ResolveInfo, /, *, email, password, token ): diff --git a/saleor/graphql/account/resolvers.py b/saleor/graphql/account/resolvers.py index 046c0282dc3..ab1ec5fc5a7 100644 --- a/saleor/graphql/account/resolvers.py +++ b/saleor/graphql/account/resolvers.py @@ -178,11 +178,13 @@ def resolve_address_validation_rules( @traced_resolver -def resolve_payment_sources(_info, user: models.User, manager, channel_slug: str): - stored_customer_accounts = ( +def resolve_payment_sources( + _info, user: models.User, manager, channel_slug: Optional[str] +): + stored_customer_accounts = [ (gtw.id, fetch_customer_id(user, gtw.id)) for gtw in gateway.list_gateways(manager, channel_slug) - ) + ] return list( chain( *[ diff --git a/saleor/graphql/account/tests/benchmark/test_permission_group.py b/saleor/graphql/account/tests/benchmark/test_permission_group.py index db890373a96..1e684cb1cb2 100644 --- a/saleor/graphql/account/tests/benchmark/test_permission_group.py +++ b/saleor/graphql/account/tests/benchmark/test_permission_group.py @@ -348,7 +348,7 @@ def test_groups_for_federation_query_count( ], } - with django_assert_num_queries(2): + with django_assert_num_queries(1): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -363,7 +363,7 @@ def test_groups_for_federation_query_count( ], } - with django_assert_num_queries(2): + with django_assert_num_queries(1): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 3 diff --git a/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py b/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py index e969dbcbf49..f0baffa79d4 100644 --- a/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py +++ b/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py @@ -1,5 +1,6 @@ from datetime import datetime +import pytest from django.urls import reverse from freezegun import freeze_time @@ -269,3 +270,25 @@ def test_refresh_token_when_user_deactivated_token(api_client, customer_user): assert not data["token"] assert len(errors) == 1 assert errors[0]["code"] == AccountErrorCode.JWT_INVALID_TOKEN.name + + +@pytest.mark.parametrize("token", ["incorrect-token", ""]) +def test_refresh_token_incorrect_token_provided(api_client, customer_user, token): + # given + csrf_token = _get_new_csrf_token() + refresh_token = create_refresh_token(customer_user, {"csrfToken": csrf_token}) + api_client.cookies[JWT_REFRESH_TOKEN_COOKIE_NAME] = refresh_token + api_client.cookies[JWT_REFRESH_TOKEN_COOKIE_NAME]["httponly"] = True + + variables = {"token": token, "csrf_token": csrf_token} + + # when + response = api_client.post_graphql(MUTATION_TOKEN_REFRESH, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["tokenRefresh"] + errors = data["errors"] + assert not data.get("token") + assert len(errors) == 1 + assert errors[0]["code"] == AccountErrorCode.JWT_DECODE_ERROR.name diff --git a/saleor/graphql/account/tests/queries/test_me.py b/saleor/graphql/account/tests/queries/test_me.py index d012191d319..2ed22c3e10c 100644 --- a/saleor/graphql/account/tests/queries/test_me.py +++ b/saleor/graphql/account/tests/queries/test_me.py @@ -375,7 +375,9 @@ def test_me_query_stored_payment_methods( ) # then - mocked_list_stored_payment_methods.assert_called_once_with(request_data) + mocked_list_stored_payment_methods.assert_called_once_with( + request_data, channel_slug=channel_USD.slug + ) content = get_graphql_content(response) data = content["data"]["me"] diff --git a/saleor/graphql/account/types.py b/saleor/graphql/account/types.py index ebf62b0600a..11502adaa7d 100644 --- a/saleor/graphql/account/types.py +++ b/saleor/graphql/account/types.py @@ -64,6 +64,7 @@ from .dataloaders import ( AccessibleChannelsByGroupIdLoader, AccessibleChannelsByUserIdLoader, + AddressByIdLoader, CustomerEventsByUserLoader, RestrictedChannelAccessByUserIdLoader, ThumbnailByUserIdSizeAndFormatLoader, @@ -706,7 +707,9 @@ def get_stored_payment_methods(data): user=root, channel=channel_obj, ) - return manager.list_stored_payment_methods(request_data) + return manager.list_stored_payment_methods( + request_data, channel_slug=channel + ) return Promise.all( [ @@ -715,6 +718,20 @@ def get_stored_payment_methods(data): ] ).then(get_stored_payment_methods) + @staticmethod + def resolve_default_billing_address(root: models.User, info: ResolveInfo): + if root.default_billing_address_id: + return AddressByIdLoader(info.context).load(root.default_billing_address_id) + return None + + @staticmethod + def resolve_default_shipping_address(root: models.User, info: ResolveInfo): + if root.default_shipping_address_id: + return AddressByIdLoader(info.context).load( + root.default_shipping_address_id + ) + return None + class UserCountableConnection(CountableConnection): class Meta: @@ -731,7 +748,7 @@ class ChoiceValue(graphene.ObjectType): "\n\nMany fields in the JSON refer to address fields by one-letter " "abbreviations. These are defined as follows:\n\n" "- `N`: Name\n" - "- `O`: Organisation\n" + "- `O`: Organization\n" "- `A`: Street Address Line(s)\n" "- `D`: Dependent locality (may be an inner-city district or a suburb)\n" "- `C`: City or Locality\n" diff --git a/saleor/graphql/app/types.py b/saleor/graphql/app/types.py index 36de29bb9c7..58861dccc69 100644 --- a/saleor/graphql/app/types.py +++ b/saleor/graphql/app/types.py @@ -545,7 +545,7 @@ class App(ModelObjectType[models.App]): ) version = graphene.String(description="Version number of the app.") access_token = graphene.String( - description="JWT token used to authenticate by thridparty app." + description="JWT token used to authenticate by third-party app." ) author = graphene.String( description=("The App's author name." + ADDED_IN_313 + PREVIEW_FEATURE) diff --git a/saleor/graphql/attribute/descriptions.py b/saleor/graphql/attribute/descriptions.py index e313204b35b..7a938fd06b1 100644 --- a/saleor/graphql/attribute/descriptions.py +++ b/saleor/graphql/attribute/descriptions.py @@ -47,7 +47,7 @@ class AttributeValueDescriptions: + RICH_CONTENT ) PLAIN_TEXT = ( - "Represents the text of the attribute value, plain text without formating." + "Represents the text of the attribute value, plain text without formatting." ) BOOLEAN = "Represents the boolean value of the attribute value." DATE = "Represents the date value of the attribute value." diff --git a/saleor/graphql/channel/mutations/channel_create.py b/saleor/graphql/channel/mutations/channel_create.py index 3f40ad844d7..c4488f40f43 100644 --- a/saleor/graphql/channel/mutations/channel_create.py +++ b/saleor/graphql/channel/mutations/channel_create.py @@ -91,7 +91,7 @@ class OrderSettingsInput(BaseInputObjectType): automatically_fulfill_non_shippable_gift_card = graphene.Boolean( required=False, description="When enabled, all non-shippable gift card orders " - "will be fulfilled automatically. By defualt set to True.", + "will be fulfilled automatically. By default set to True.", ) expire_orders_after = Minute( required=False, diff --git a/saleor/graphql/channel/types.py b/saleor/graphql/channel/types.py index 71af5463a48..7238fa33b58 100644 --- a/saleor/graphql/channel/types.py +++ b/saleor/graphql/channel/types.py @@ -259,7 +259,7 @@ class OrderSettings(ObjectType): allow_unpaid_orders = graphene.Boolean( required=True, description=( - "Determine if it is possible to place unpdaid order by calling " + "Determine if it is possible to place unpaid order by calling " "`checkoutComplete` mutation." + ADDED_IN_315 + PREVIEW_FEATURE ), ) diff --git a/saleor/graphql/checkout/dataloaders.py b/saleor/graphql/checkout/dataloaders.py index 8220bcee1ca..80c44df819d 100644 --- a/saleor/graphql/checkout/dataloaders.py +++ b/saleor/graphql/checkout/dataloaders.py @@ -22,6 +22,7 @@ get_checkout_lines_problems, get_checkout_problems, ) +from ...core.db.connection import allow_writer_in_context from ...discount import VoucherType from ...discount.interface import VariantPromotionRuleInfo from ...payment.models import TransactionItem @@ -91,6 +92,7 @@ def with_checkout_lines(results): channel_pks = [checkout.channel_id for checkout in checkouts] + @allow_writer_in_context(self.context) def with_variants_products_collections(results): ( variants, @@ -554,16 +556,17 @@ def with_checkout_info(results): for listing in channel_listings if listing.channel_id == channel.id ] - update_delivery_method_lists_for_checkout_info( - checkout_info, - shipping_method, - collection_point, - shipping_address, - checkout_lines, - manager, - shipping_method_listings, - database_connection_name=self.database_connection_name, - ) + with allow_writer_in_context(self.context): + update_delivery_method_lists_for_checkout_info( + checkout_info, + shipping_method, + collection_point, + shipping_address, + checkout_lines, + manager, + shipping_method_listings, + database_connection_name=self.database_connection_name, + ) checkout_info_map[key] = checkout_info return [checkout_info_map[key] for key in keys] diff --git a/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py b/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py index b07aae60497..2b785d976aa 100644 --- a/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py +++ b/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py @@ -256,8 +256,8 @@ def _update_delivery_method( and collection_point.click_and_collect_option == WarehouseClickAndCollectOption.LOCAL_STOCK ): - checkout.shipping_address = collection_point.address - checkout_info.shipping_address = collection_point.address + checkout.shipping_address = collection_point.address.get_copy() + checkout_info.shipping_address = checkout.shipping_address checkout_fields_to_update += ["shipping_address"] invalidate_prices_updated_fields = invalidate_checkout( checkout_info, lines, manager, save=False diff --git a/saleor/graphql/checkout/mutations/utils.py b/saleor/graphql/checkout/mutations/utils.py index c80c36cfb69..c39a3a2d83c 100644 --- a/saleor/graphql/checkout/mutations/utils.py +++ b/saleor/graphql/checkout/mutations/utils.py @@ -33,7 +33,11 @@ from ....core.exceptions import InsufficientStock, PermissionDenied from ....discount import DiscountType, DiscountValueType from ....discount.models import CheckoutLineDiscount, PromotionRule -from ....discount.utils import create_gift_line, get_best_rule_for_checkout +from ....discount.utils import ( + create_gift_line, + fetch_promotion_rules_for_checkout_or_order, + get_best_rule, +) from ....permission.enums import CheckoutPermissions from ....product import models as product_models from ....product.models import ProductChannelListing, ProductVariant @@ -442,6 +446,7 @@ def group_lines_input_data_on_update( variant_id = cast(str, line.get("variant_id")) line_id = cast(str, line.get("line_id")) + line_db_id, variant_db_id = None, None if line_id: _, line_db_id = graphene.Node.from_global_id(line_id) @@ -452,7 +457,7 @@ def group_lines_input_data_on_update( ) if not line_db_id: - line_data = checkout_lines_data_map[variant_db_id] + line_data = checkout_lines_data_map[variant_db_id] # type: ignore[index] line_data.variant_id = variant_db_id else: line_data = checkout_lines_data_map[line_db_id] @@ -532,7 +537,7 @@ def find_variant_id_when_line_parameter_used( def apply_gift_reward_if_applicable_on_checkout_creation( - checkout: "Checkout", + checkout: "models.Checkout", database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ) -> None: """Apply gift reward if applicable on newly created checkout. @@ -549,9 +554,13 @@ def apply_gift_reward_if_applicable_on_checkout_creation( return _set_checkout_base_subtotal_and_total_on_checkout_creation(checkout) - - best_rule_data = get_best_rule_for_checkout( - checkout, checkout.channel, checkout.get_country(), database_connection_name + rules = fetch_promotion_rules_for_checkout_or_order(checkout) + best_rule_data = get_best_rule( + rules, + checkout.channel, + checkout.get_country(), + checkout.base_subtotal, + database_connection_name, ) if not best_rule_data: return @@ -560,22 +569,21 @@ def apply_gift_reward_if_applicable_on_checkout_creation( if not gift_listing: return - amount_value = gift_listing.price_amount with transaction.atomic(): line, _line_created = create_gift_line(checkout, gift_listing.variant_id) CheckoutLineDiscount.objects.create( type=DiscountType.ORDER_PROMOTION, line=line, - amount_value=amount_value, + amount_value=best_discount_amount, value_type=DiscountValueType.FIXED, - value=amount_value, + value=best_discount_amount, promotion_rule=best_rule, currency=checkout.currency, ) def _set_checkout_base_subtotal_and_total_on_checkout_creation( - checkout: "Checkout", + checkout: "models.Checkout", ): """Calculate and set base subtotal and total for newly created checkout.""" variants_id = [line.variant_id for line in checkout.lines.all()] diff --git a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py index 7dc79c2d0e4..5d0c084f38b 100644 --- a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py +++ b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py @@ -9,8 +9,12 @@ from .....checkout.fetch import fetch_checkout_info, fetch_checkout_lines from .....checkout.models import Checkout from .....checkout.utils import add_variants_to_checkout, set_external_shipping_id +from .....discount import RewardValueType +from .....discount.models import CheckoutLineDiscount, PromotionRule from .....plugins.manager import get_plugins_manager -from .....product.models import ProductVariant, ProductVariantChannelListing +from .....product.models import Product, ProductVariant, ProductVariantChannelListing +from .....product.utils.variant_prices import update_discounted_prices_for_promotion +from .....product.utils.variants import fetch_variants_for_promotion_rules from .....warehouse.models import Stock from ....core.utils import to_global_id_or_none from ....tests.utils import get_graphql_content @@ -413,7 +417,7 @@ def test_create_checkout_with_reservations( } } - with django_assert_num_queries(66): + with django_assert_num_queries(72): response = api_client.post_graphql(query, variables) assert get_graphql_content(response)["data"]["checkoutCreate"] assert Checkout.objects.first().lines.count() == 1 @@ -431,7 +435,7 @@ def test_create_checkout_with_reservations( } } - with django_assert_num_queries(66): + with django_assert_num_queries(72): response = api_client.post_graphql(query, variables) assert get_graphql_content(response)["data"]["checkoutCreate"] assert Checkout.objects.first().lines.count() == 10 @@ -562,7 +566,7 @@ def test_create_checkout_with_order_promotion( } # when - with django_assert_num_queries(89): + with django_assert_num_queries(77): response = user_api_client.post_graphql(MUTATION_CHECKOUT_CREATE, variables) # then @@ -817,7 +821,7 @@ def test_update_checkout_lines_with_reservations( reservation_length=5, ) - with django_assert_num_queries(81): + with django_assert_num_queries(91): variant_id = graphene.Node.to_global_id("ProductVariant", variants[0].pk) variables = { "id": to_global_id_or_none(checkout), @@ -831,7 +835,7 @@ def test_update_checkout_lines_with_reservations( assert not data["errors"] # Updating multiple lines in checkout has same query count as updating one - with django_assert_num_queries(81): + with django_assert_num_queries(91): variables = { "id": to_global_id_or_none(checkout), "lines": [], @@ -1076,7 +1080,7 @@ def test_add_checkout_lines_with_reservations( new_lines.append({"quantity": 2, "variantId": variant_id}) # Adding multiple lines to checkout has same query count as adding one - with django_assert_num_queries(80): + with django_assert_num_queries(90): variables = { "id": Node.to_global_id("Checkout", checkout.pk), "lines": [new_lines[0]], @@ -1089,7 +1093,7 @@ def test_add_checkout_lines_with_reservations( checkout.lines.exclude(id=line.id).delete() - with django_assert_num_queries(80): + with django_assert_num_queries(90): variables = { "id": Node.to_global_id("Checkout", checkout.pk), "lines": new_lines, @@ -1101,6 +1105,140 @@ def test_add_checkout_lines_with_reservations( assert not data["errors"] +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_add_checkout_lines_catalogue_discount_applies( + user_api_client, + catalogue_promotion_without_rules, + checkout, + channel_USD, + django_assert_num_queries, + count_queries, + variant_with_many_stocks, +): + # given + Stock.objects.update(quantity=100) + variant = variant_with_many_stocks + variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk) + + # prepare promotion with 50% discount + promotion = catalogue_promotion_without_rules + catalogue_predicate = {"variantPredicate": {"ids": [variant_id]}} + rule = promotion.rules.create( + name="Catalogue rule percentage 50", + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=Decimal(50), + ) + rule.channels.add(channel_USD) + fetch_variants_for_promotion_rules(PromotionRule.objects.all()) + + # update prices + update_discounted_prices_for_promotion(Product.objects.all()) + + variables = { + "id": to_global_id_or_none(checkout), + "lines": [{"variantId": variant_id, "quantity": 3}], + "channelSlug": checkout.channel.slug, + } + + # when + with django_assert_num_queries(82): + response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["checkoutLinesAdd"] + assert not data["errors"] + assert checkout.lines.count() == 1 + assert CheckoutLineDiscount.objects.count() == 1 + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_add_checkout_lines_multiple_catalogue_discount_applies( + user_api_client, + catalogue_promotion_without_rules, + checkout, + channel_USD, + django_assert_num_queries, + count_queries, + product_variant_list, + warehouse, +): + # given + variants = product_variant_list + variant_global_ids = [variant.get_global_id() for variant in variants] + + channel_listing = variants[2].channel_listings.first() + channel_listing.channel = channel_USD + channel_listing.currency = channel_USD.currency_code + channel_listing.save(update_fields=["channel_id", "currency"]) + + Stock.objects.bulk_create( + [ + Stock(product_variant=variant, warehouse=warehouse, quantity=1000) + for variant in variants + ] + ) + + # create many rules + promotion = catalogue_promotion_without_rules + rules = [] + catalogue_predicate = {"variantPredicate": {"ids": variant_global_ids}} + for idx in range(5): + reward_value = 2 + idx + rules.append( + PromotionRule( + name=f"Catalogue rule fixed {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(reward_value), + ) + ) + for idx in range(5): + reward_value = idx * 10 + 25 + rules.append( + PromotionRule( + name=f"Catalogue rule percentage {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=Decimal(reward_value), + ) + ) + rules = PromotionRule.objects.bulk_create(rules) + for rule in rules: + rule.channels.add(channel_USD) + fetch_variants_for_promotion_rules(PromotionRule.objects.all()) + + # update prices + update_discounted_prices_for_promotion(Product.objects.all()) + + variables = { + "id": to_global_id_or_none(checkout), + "lines": [ + {"variantId": variant_global_ids[0], "quantity": 4}, + {"variantId": variant_global_ids[1], "quantity": 5}, + {"variantId": variant_global_ids[2], "quantity": 6}, + {"variantId": variant_global_ids[3], "quantity": 7}, + ], + "channelSlug": checkout.channel.slug, + } + + # when + with django_assert_num_queries(82): + response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["checkoutLinesAdd"] + assert not data["errors"] + assert checkout.lines.count() == 4 + assert CheckoutLineDiscount.objects.count() == 4 + + @pytest.mark.django_db @pytest.mark.count_queries(autouse=False) def test_add_checkout_lines_order_discount_applies( @@ -1125,7 +1263,7 @@ def test_add_checkout_lines_order_discount_applies( } # when - with django_assert_num_queries(75): + with django_assert_num_queries(85): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then @@ -1159,7 +1297,7 @@ def test_add_checkout_lines_gift_discount_applies( } # when - with django_assert_num_queries(101): + with django_assert_num_queries(112): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py index 21b68047422..c7727866827 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py @@ -19,7 +19,7 @@ from .....core.exceptions import InsufficientStock, InsufficientStockData from .....core.taxes import TaxError, zero_money, zero_taxed_money from .....discount import DiscountType, DiscountValueType, RewardValueType -from .....discount.models import CheckoutLineDiscount, PromotionRule +from .....discount.models import CheckoutLineDiscount from .....giftcard import GiftCardEvents from .....giftcard.models import GiftCard, GiftCardEvent from .....order import OrderOrigin, OrderStatus @@ -29,7 +29,6 @@ from .....payment.interface import GatewayResponse from .....payment.model_helpers import get_subtotal from .....plugins.manager import PluginsManager, get_plugins_manager -from .....product.models import VariantChannelListingPromotionRule from .....tests.utils import flush_post_commit_hooks from .....warehouse.models import Reservation, Stock, WarehouseClickAndCollectOption from .....warehouse.tests.utils import get_available_quantity_for_stock @@ -1679,186 +1678,6 @@ def test_checkout_complete_product_on_old_sale( ).exists(), "Checkout should have been deleted" -def test_checkout_complete_multiple_rules_applied( - user_api_client, - checkout_with_item, - catalogue_promotion_without_rules, - payment_dummy, - address, - shipping_method, -): - # given - checkout = checkout_with_item - checkout.shipping_address = address - checkout.shipping_method = shipping_method - checkout.billing_address = address - checkout.metadata_storage.store_value_in_metadata(items={"accepted": "true"}) - checkout.metadata_storage.store_value_in_private_metadata( - items={"accepted": "false"} - ) - checkout.save() - checkout.metadata_storage.save() - - checkout_line = checkout.lines.first() - checkout_line_quantity = checkout_line.quantity - checkout_line_variant = checkout_line.variant - - channel = checkout.channel - - reward_value_1 = Decimal("2") - reward_value_2 = Decimal("10") - rule_1, rule_2 = PromotionRule.objects.bulk_create( - [ - PromotionRule( - name="Percentage promotion rule 1", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.FIXED, - reward_value=reward_value_1, - catalogue_predicate={ - "productPredicate": { - "ids": [ - graphene.Node.to_global_id( - "Product", checkout_line_variant.product_id - ) - ] - } - }, - ), - PromotionRule( - name="Percentage promotion rule 2", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.PERCENTAGE, - reward_value=reward_value_2, - catalogue_predicate={ - "variantPredicate": { - "ids": [ - graphene.Node.to_global_id( - "ProductVariant", checkout_line_variant.id - ) - ] - } - }, - ), - ] - ) - - rule_1.channels.add(channel) - rule_2.channels.add(channel) - - variant_channel_listing = checkout_line_variant.channel_listings.get( - channel=channel - ) - discount_amount_2 = reward_value_2 / 100 * variant_channel_listing.price.amount - discounted_price = ( - variant_channel_listing.price.amount - reward_value_1 - discount_amount_2 - ) - variant_channel_listing.discounted_price_amount = discounted_price - variant_channel_listing.save(update_fields=["discounted_price_amount"]) - - VariantChannelListingPromotionRule.objects.bulk_create( - [ - VariantChannelListingPromotionRule( - variant_channel_listing=variant_channel_listing, - promotion_rule=rule_1, - discount_amount=reward_value_1, - currency=channel.currency_code, - ), - VariantChannelListingPromotionRule( - variant_channel_listing=variant_channel_listing, - promotion_rule=rule_2, - discount_amount=discount_amount_2, - currency=channel.currency_code, - ), - ] - ) - - CheckoutLineDiscount.objects.bulk_create( - [ - CheckoutLineDiscount( - line=checkout_line, - type=DiscountType.PROMOTION, - value_type=DiscountValueType.FIXED, - amount_value=reward_value_1, - currency=channel.currency_code, - promotion_rule=rule_1, - ), - CheckoutLineDiscount( - line=checkout_line, - type=DiscountType.PROMOTION, - value_type=DiscountValueType.FIXED, - amount_value=discount_amount_2, - currency=channel.currency_code, - promotion_rule=rule_2, - ), - ] - ) - - manager = get_plugins_manager(allow_replica=False) - lines, _ = fetch_checkout_lines(checkout) - checkout_info = fetch_checkout_info(checkout, lines, manager) - - total = calculations.checkout_total( - manager=manager, - checkout_info=checkout_info, - lines=lines, - address=address, - ) - payment = payment_dummy - payment.is_active = True - payment.order = None - payment.total = total.gross.amount - payment.currency = total.gross.currency - payment.checkout = checkout - payment.save() - assert not payment.transactions.exists() - - orders_count = Order.objects.count() - variables = { - "id": to_global_id_or_none(checkout), - "redirectUrl": "https://www.example.com", - } - - # when - response = user_api_client.post_graphql(MUTATION_CHECKOUT_COMPLETE, variables) - - # then - content = get_graphql_content(response) - data = content["data"]["checkoutComplete"] - assert not data["errors"] - - order_token = data["order"]["token"] - order_id = data["order"]["id"] - assert Order.objects.count() == orders_count + 1 - order = Order.objects.first() - assert str(order.id) == order_token - assert order_id == graphene.Node.to_global_id("Order", order.id) - assert order.metadata == checkout.metadata_storage.metadata - assert order.private_metadata == checkout.metadata_storage.private_metadata - - order_line = order.lines.first() - subtotal = get_subtotal(order.lines.all(), order.currency) - assert order.subtotal == subtotal - assert data["order"]["subtotal"]["gross"]["amount"] == subtotal.gross.amount - assert order.total == total - assert order.undiscounted_total == total + ( - order_line.undiscounted_total_price - order_line.total_price - ) - assert order_line.discounts.count() == 2 - - assert checkout_line_quantity == order_line.quantity - assert checkout_line_variant == order_line.variant - assert order.shipping_address == address - assert order.shipping_method == checkout.shipping_method - assert order.payments.exists() - order_payment = order.payments.first() - assert order_payment == payment - assert payment.transactions.count() == 1 - - assert not Checkout.objects.filter( - pk=checkout.pk - ).exists(), "Checkout should have been deleted" - - def test_checkout_with_voucher_on_specific_product_complete_with_product_on_promotion( user_api_client, checkout_with_item_and_voucher_specific_products, @@ -3105,7 +2924,7 @@ def test_complete_checkout_for_local_click_and_collect( order_count = Order.objects.count() checkout = checkout_with_item_for_cc checkout.collection_point = warehouse_for_cc - checkout.shipping_address = None + checkout.shipping_address = warehouse_for_cc.address checkout.save(update_fields=["collection_point", "shipping_address"]) variables = { @@ -3147,7 +2966,8 @@ def test_complete_checkout_for_local_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(payment.currency) assert order.lines.count() == 1 @@ -3215,7 +3035,8 @@ def test_complete_checkout_for_global_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(payment.currency) assert order.lines.count() == 1 diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py index a7c39822c05..c553436e272 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py @@ -2773,7 +2773,8 @@ def test_complete_checkout_for_local_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(order.channel.currency_code) assert order.lines.count() == 1 @@ -2833,7 +2834,8 @@ def test_complete_checkout_for_global_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(order.channel.currency_code) assert order.lines.count() == 1 diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py b/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py index 6a9f7f01f94..c806867976d 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py @@ -553,6 +553,7 @@ def test_checkout_delivery_method_update_valid_method_not_all_shipping_data_for_ ) errors = data["errors"] assert checkout.shipping_address == delivery_method.address + assert checkout.shipping_address_id != delivery_method.address.id assert not errors assert getattr(checkout, attribute_name) == delivery_method diff --git a/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py b/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py index a74f429ad2c..c2b3dccc05b 100644 --- a/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py +++ b/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py @@ -15,7 +15,7 @@ from .....checkout.models import Checkout, CheckoutLine from .....core.taxes import TaxError, zero_money, zero_taxed_money from .....discount import DiscountType, DiscountValueType, RewardValueType -from .....discount.models import CheckoutLineDiscount, Promotion +from .....discount.models import CheckoutLineDiscount from .....giftcard import GiftCardEvents from .....giftcard.models import GiftCard, GiftCardEvent from .....order import OrderOrigin, OrderStatus @@ -966,11 +966,54 @@ def test_order_from_checkout_voucher_not_increase_uses_on_preprocess_creation_fa assert code.used == 0 +MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS = """ +mutation orderCreateFromCheckout($id: ID!){ + orderCreateFromCheckout(id: $id){ + order{ + id + total { + currency + net { + amount + } + gross { + amount + } + } + lines { + unitDiscount { + amount + } + unitDiscountType + unitDiscountValue + isGift + quantity + } + discounts { + amount { + amount + } + valueType + type + } + } + errors{ + field + message + code + variants + } + } +} +""" + + def test_order_from_checkout_on_catalogue_promotion( app_api_client, checkout_with_item_on_promotion, permission_handle_checkouts, permission_manage_checkouts, + permission_manage_orders, address, shipping_method, ): @@ -987,9 +1030,13 @@ def test_order_from_checkout_on_catalogue_promotion( # when response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, + MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS, variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], + permissions=[ + permission_handle_checkouts, + permission_manage_checkouts, + permission_manage_orders, + ], ) # then @@ -997,22 +1044,30 @@ def test_order_from_checkout_on_catalogue_promotion( data = content["data"]["orderCreateFromCheckout"] assert not data["errors"] - order = Order.objects.first() - assert order.status == OrderStatus.UNCONFIRMED - assert order.origin == OrderOrigin.CHECKOUT - assert not order.original - - assert order.lines.count() == 1 - line = order.lines.first() - assert line.sale_id - assert line.unit_discount_reason - assert line.discounts.count() == 1 - discount = line.discounts.first() + order_db = Order.objects.first() + assert order_db.status == OrderStatus.UNCONFIRMED + assert order_db.origin == OrderOrigin.CHECKOUT + assert not order_db.original + + assert order_db.lines.count() == 1 + line_db = order_db.lines.first() + assert line_db.sale_id + assert line_db.unit_discount_reason + assert line_db.discounts.count() == 1 + discount = line_db.discounts.first() assert discount.promotion_rule assert ( - discount.amount_value == (order.undiscounted_total - order.total).gross.amount + discount.amount_value + == (order_db.undiscounted_total - order_db.total).gross.amount ) - assert not order.discounts.first() + assert not order_db.discounts.first() + + assert not data["order"]["discounts"] + assert len(data["order"]["lines"]) == 1 + line = data["order"]["lines"][0] + assert line["unitDiscount"]["amount"] == discount.amount_value / line["quantity"] + assert line["unitDiscountType"] == RewardValueType.FIXED.upper() + assert line["unitDiscountValue"] == discount.amount_value / line["quantity"] def test_order_from_checkout_on_order_promotion( @@ -1020,6 +1075,7 @@ def test_order_from_checkout_on_order_promotion( checkout_with_item_and_order_discount, permission_handle_checkouts, permission_manage_checkouts, + permission_manage_orders, address, shipping_method, ): @@ -1036,9 +1092,13 @@ def test_order_from_checkout_on_order_promotion( # when response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, + MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS, variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], + permissions=[ + permission_handle_checkouts, + permission_manage_checkouts, + permission_manage_orders, + ], ) # then @@ -1059,6 +1119,12 @@ def test_order_from_checkout_on_order_promotion( ) assert order_discount.type == DiscountType.ORDER_PROMOTION + discounts = data["order"]["discounts"] + assert len(discounts) == 1 + assert discounts[0]["amount"]["amount"] == order_discount.amount_value + assert discounts[0]["type"] == DiscountType.ORDER_PROMOTION.upper() + assert discounts[0]["valueType"] == DiscountValueType.FIXED.upper() + def test_order_from_checkout_on_gift_promotion( app_api_client, @@ -1066,6 +1132,7 @@ def test_order_from_checkout_on_gift_promotion( gift_promotion_rule, permission_handle_checkouts, permission_manage_checkouts, + permission_manage_orders, address, shipping_method, ): @@ -1083,9 +1150,13 @@ def test_order_from_checkout_on_gift_promotion( # when response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, + MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS, variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], + permissions=[ + permission_handle_checkouts, + permission_manage_checkouts, + permission_manage_orders, + ], ) # then @@ -1101,10 +1172,23 @@ def test_order_from_checkout_on_gift_promotion( assert not order.discounts.all() assert order.lines.count() == line_count gift_line = order.lines.get(is_gift=True) + gift_price = gift_line.variant.channel_listings.get( + channel=checkout.channel + ).discounted_price_amount assert gift_line.discounts.count() == 1 line_discount = gift_line.discounts.first() assert line_discount.promotion_rule == gift_promotion_rule assert line_discount.type == DiscountType.ORDER_PROMOTION + assert line_discount.amount_value == gift_price + assert line_discount.value == gift_price + + assert not data["order"]["discounts"] + lines = data["order"]["lines"] + assert len(lines) == 2 + gift_line_api = [line for line in lines if line["isGift"]][0] + assert gift_line_api["unitDiscount"]["amount"] == gift_price + assert gift_line_api["unitDiscountValue"] == gift_price + assert gift_line_api["unitDiscountType"] == RewardValueType.FIXED.upper() def test_order_from_checkout_on_catalogue_and_gift_promotion( @@ -1197,91 +1281,6 @@ def test_order_from_checkout_on_catalogue_and_gift_promotion( ).gross.amount == top_price + line_discount.amount_value -def test_order_from_checkout_multiple_rules_applied( - app_api_client, - checkout_with_item_on_promotion, - permission_handle_checkouts, - permission_manage_checkouts, - address, - shipping_method, -): - # given - checkout = checkout_with_item_on_promotion - checkout.shipping_address = address - checkout.shipping_method = shipping_method - checkout.billing_address = address - checkout.save() - - channel = checkout.channel - - line = checkout.lines.first() - variant = line.variant - variant_channel_listing = variant.channel_listings.get(channel=channel) - - reward_value_2 = Decimal("10.00") - promotion = Promotion.objects.first() - rule_2 = promotion.rules.create( - name="Percentage promotion rule 2", - reward_value_type=RewardValueType.PERCENTAGE, - reward_value=reward_value_2, - catalogue_predicate={ - "variantPredicate": { - "ids": [graphene.Node.to_global_id("ProductVariant", line.variant_id)] - } - }, - ) - rule_2.channels.add(channel) - - discount_amount_2 = reward_value_2 / 100 * variant_channel_listing.price.amount - - variant_channel_listing.variantlistingpromotionrule.create( - promotion_rule=rule_2, - discount_amount=discount_amount_2, - currency=channel.currency_code, - ) - CheckoutLineDiscount.objects.create( - line=line, - type=DiscountType.PROMOTION, - value_type=DiscountValueType.PERCENTAGE, - amount_value=discount_amount_2, - currency=channel.currency_code, - promotion_rule=rule_2, - ) - - variant_channel_listing.discounted_price_amount = ( - variant_channel_listing.discounted_price_amount - discount_amount_2 - ) - variant_channel_listing.save(update_fields=["discounted_price_amount"]) - - variables = { - "id": graphene.Node.to_global_id("Checkout", checkout.pk), - } - - # when - response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, - variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], - ) - - # then - content = get_graphql_content(response) - data = content["data"]["orderCreateFromCheckout"] - assert not data["errors"] - - order = Order.objects.first() - - assert order.status == OrderStatus.UNCONFIRMED - assert order.origin == OrderOrigin.CHECKOUT - assert not order.original - - assert order.lines.count() == 1 - line = order.lines.first() - assert line.sale_id == graphene.Node.to_global_id("Promotion", promotion.pk) - assert line.unit_discount_reason - assert line.discounts.count() == 2 - - @pytest.mark.integration def test_order_from_checkout_without_inventory_tracking( app_api_client, diff --git a/saleor/graphql/checkout/tests/test_checkout.py b/saleor/graphql/checkout/tests/test_checkout.py index 2d449948662..3d342aa9cf5 100644 --- a/saleor/graphql/checkout/tests/test_checkout.py +++ b/saleor/graphql/checkout/tests/test_checkout.py @@ -2708,7 +2708,9 @@ def test_checkout_with_stored_payment_methods_empty_response( # then content = get_graphql_content(response) - mocked_list_stored_payment_methods.assert_called_once_with(request_data) + mocked_list_stored_payment_methods.assert_called_once_with( + request_data, channel_slug=checkout_with_prices.channel.slug + ) assert content["data"]["checkout"]["storedPaymentMethods"] == [] @@ -2779,7 +2781,9 @@ def test_checkout_with_stored_payment_methods( # then content = get_graphql_content(response) - mocked_list_stored_payment_methods.assert_called_once_with(request_data) + mocked_list_stored_payment_methods.assert_called_once_with( + request_data, channel_slug=checkout_with_prices.channel.slug + ) assert content["data"]["checkout"]["storedPaymentMethods"] == [ { "id": payment_method_id, diff --git a/saleor/graphql/checkout/types.py b/saleor/graphql/checkout/types.py index 9980516c7d4..bfccba5d092 100644 --- a/saleor/graphql/checkout/types.py +++ b/saleor/graphql/checkout/types.py @@ -11,6 +11,7 @@ ) from ...checkout.calculations import fetch_checkout_data from ...checkout.utils import get_valid_collection_points_for_checkout +from ...core.db.connection import allow_writer_in_context from ...core.taxes import zero_money, zero_taxed_money from ...core.utils.lazyobjects import unwrap_lazy from ...graphql.core.context import get_database_connection_name @@ -288,6 +289,7 @@ def with_checkout(data): checkout.token ) + @allow_writer_in_context(info.context) def calculate_line_unit_price(data): checkout_info, lines = data database_connection_name = get_database_connection_name(info.context) @@ -365,6 +367,7 @@ def with_checkout(data): checkout.token ) + @allow_writer_in_context(info.context) def calculate_line_total_price(data): (checkout_info, lines) = data database_connection_name = get_database_connection_name(info.context) @@ -854,20 +857,26 @@ def with_checkout_info(checkout_info): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_methods(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) + def with_checkout_info(checkout_info): + return unwrap_lazy(checkout_info.all_shipping_methods) + return ( CheckoutInfoByCheckoutTokenLoader(info.context) .load(root.token) - .then(lambda checkout_info: unwrap_lazy(checkout_info.all_shipping_methods)) + .then(with_checkout_info) ) @staticmethod def resolve_delivery_method(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) + def with_checkout_info(checkout_info): + return checkout_info.delivery_method_info.delivery_method + return ( CheckoutInfoByCheckoutTokenLoader(info.context) .load(root.token) - .then( - lambda checkout_info: checkout_info.delivery_method_info.delivery_method - ) + .then(with_checkout_info) ) @staticmethod @@ -885,6 +894,7 @@ def calculate_quantity(lines): @traced_resolver @prevent_sync_event_circular_query def resolve_total_price(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def calculate_total_price(data): address, lines, checkout_info, manager = data database_connection_name = get_database_connection_name(info.context) @@ -904,6 +914,7 @@ def calculate_total_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_subtotal_price(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def calculate_subtotal_price(data): address, lines, checkout_info, manager = data database_connection_name = get_database_connection_name(info.context) @@ -922,6 +933,7 @@ def calculate_subtotal_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_price(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def calculate_shipping_price(data): address, lines, checkout_info, manager = data database_connection_name = get_database_connection_name(info.context) @@ -944,22 +956,28 @@ def resolve_lines(root: models.Checkout, info: ResolveInfo): @traced_resolver @prevent_sync_event_circular_query def resolve_available_shipping_methods(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) + def with_checkout_info(checkout_info): + return checkout_info.valid_shipping_methods + return ( CheckoutInfoByCheckoutTokenLoader(info.context) .load(root.token) - .then(lambda checkout_info: checkout_info.valid_shipping_methods) + .then(with_checkout_info) ) @staticmethod @traced_resolver def resolve_available_collection_points(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def get_available_collection_points(lines): database_connection_name = get_database_connection_name(info.context) - return get_valid_collection_points_for_checkout( + result = get_valid_collection_points_for_checkout( lines, root.channel_id, database_connection_name=database_connection_name, ) + return list(result) return ( CheckoutLinesInfoByCheckoutTokenLoader(info.context) @@ -979,12 +997,12 @@ def resolve_available_payment_gateways( ) def get_available_payment_gateways(results): - (checkout, lines_info) = results + (checkout_info, lines_info) = results return manager.list_payment_gateways( currency=root.currency, - checkout_info=checkout, + checkout_info=checkout_info, checkout_lines=lines_info, - channel_slug=root.channel.slug, + channel_slug=checkout_info.channel.slug, ) return Promise.all([checkout_info, checkout_lines_info]).then( @@ -1025,20 +1043,21 @@ def resolve_language_code(root, _info): def resolve_stock_reservation_expires( root: models.Checkout, info: ResolveInfo, site ): - if not is_reservation_enabled(site.settings): - return None - - def get_oldest_stock_reservation_expiration_date(reservations): - if not reservations: + with allow_writer_in_context(info.context): + if not is_reservation_enabled(site.settings): return None - return min(reservation.reserved_until for reservation in reservations) + def get_oldest_stock_reservation_expiration_date(reservations): + if not reservations: + return None - return ( - StocksReservationsByCheckoutTokenLoader(info.context) - .load(root.token) - .then(get_oldest_stock_reservation_expiration_date) - ) + return min(reservation.reserved_until for reservation in reservations) + + return ( + StocksReservationsByCheckoutTokenLoader(info.context) + .load(root.token) + .then(get_oldest_stock_reservation_expiration_date) + ) @staticmethod @one_of_permissions_required( @@ -1269,7 +1288,9 @@ def _resolve_stored_payment_methods(data): user=user, channel=channel, ) - return manager.list_stored_payment_methods(request_data) + return manager.list_stored_payment_methods( + request_data, channel_slug=channel.slug + ) manager = get_plugin_manager_promise(info.context) channel_loader = ChannelByIdLoader(info.context).load(root.channel_id) diff --git a/saleor/graphql/core/connection.py b/saleor/graphql/core/connection.py index d109f749486..d50b1f020cd 100644 --- a/saleor/graphql/core/connection.py +++ b/saleor/graphql/core/connection.py @@ -345,13 +345,16 @@ def create_connection_slice( ) args["sort_by"] = sort_by - slice = connection_from_queryset_slice( - queryset, - args, - connection_type, - edge_type or connection_type.Edge, - pageinfo_type or graphene.relay.PageInfo, - ) + from ...core.db.connection import allow_writer_in_context + + with allow_writer_in_context(info.context): + slice = connection_from_queryset_slice( + queryset, + args, + connection_type, + edge_type or connection_type.Edge, + pageinfo_type or graphene.relay.PageInfo, + ) if isinstance(iterable, ChannelQsContext): edges_with_context = [] diff --git a/saleor/graphql/core/dataloaders.py b/saleor/graphql/core/dataloaders.py index 48b464f87e1..329470986c7 100644 --- a/saleor/graphql/core/dataloaders.py +++ b/saleor/graphql/core/dataloaders.py @@ -7,6 +7,7 @@ from promise import Promise from promise.dataloader import DataLoader as BaseLoader +from ...core.db.connection import allow_writer_in_context from ...thumbnail.models import Thumbnail from ...thumbnail.utils import get_thumbnail_format from . import SaleorContext @@ -47,7 +48,10 @@ def batch_load_fn( # pylint: disable=method-hidden ) as scope: span = scope.span span.set_tag(opentracing.tags.COMPONENT, "dataloaders") - results = self.batch_load(keys) + + with allow_writer_in_context(self.context): + results = self.batch_load(keys) + if not isinstance(results, Promise): return Promise.resolve(results) return results diff --git a/saleor/graphql/core/mutations.py b/saleor/graphql/core/mutations.py index 023e65772e6..5e5e3f46747 100644 --- a/saleor/graphql/core/mutations.py +++ b/saleor/graphql/core/mutations.py @@ -27,6 +27,7 @@ from graphene.types.mutation import MutationOptions from graphql.error import GraphQLError +from ...core.db.connection import allow_writer from ...core.error_codes import MetadataErrorCode from ...core.exceptions import PermissionDenied from ...core.utils.events import call_event @@ -511,6 +512,7 @@ def check_permissions( return one_of_permissions_or_auth_filter_required(context, all_permissions) @classmethod + @allow_writer() def mutate(cls, root, info: ResolveInfo, **data): disallow_replica_in_context(info.context) setup_context_user(info.context) @@ -1043,6 +1045,7 @@ def perform_mutation( # type: ignore[override] return count, errors @classmethod + @allow_writer() def mutate(cls, root, info: ResolveInfo, **data): disallow_replica_in_context(info.context) setup_context_user(info.context) diff --git a/saleor/graphql/core/schema.py b/saleor/graphql/core/schema.py index 3e6f856b4a0..fac28232cb8 100644 --- a/saleor/graphql/core/schema.py +++ b/saleor/graphql/core/schema.py @@ -1,5 +1,6 @@ import graphene +from ..core.descriptions import DEPRECATED_IN_3X_FIELD from ..core.doc_category import DOC_CATEGORY_TAXES from ..core.fields import BaseField from ..plugins.dataloaders import get_plugin_manager_promise @@ -13,6 +14,7 @@ class CoreQueries(graphene.ObjectType): NonNullList(TaxType), description="List of all tax rates available from tax gateway.", doc_category=DOC_CATEGORY_TAXES, + deprecation_reason=f"{DEPRECATED_IN_3X_FIELD} Use `taxClasses` field instead.", ) def resolve_tax_types(self, info: ResolveInfo): diff --git a/saleor/graphql/core/types/common.py b/saleor/graphql/core/types/common.py index f8715a54e21..253adb83ec4 100644 --- a/saleor/graphql/core/types/common.py +++ b/saleor/graphql/core/types/common.py @@ -259,7 +259,7 @@ class CheckoutError(Error): code = CheckoutErrorCode(description="The error code.", required=True) variants = NonNullList( graphene.ID, - description="List of varint IDs which causes the error.", + description="List of variant IDs which causes the error.", required=False, ) lines = NonNullList( @@ -408,7 +408,7 @@ class PermissionGroupError(Error): ) channels = NonNullList( graphene.ID, - description="List of chnnels IDs which causes the error.", + description="List of channels IDs which causes the error.", required=False, ) diff --git a/saleor/graphql/discount/tests/benchmark/test_promotion_create.py b/saleor/graphql/discount/tests/benchmark/test_promotion_create.py index 6f7d0bcaa4e..a2c4f35bc24 100644 --- a/saleor/graphql/discount/tests/benchmark/test_promotion_create.py +++ b/saleor/graphql/discount/tests/benchmark/test_promotion_create.py @@ -196,7 +196,7 @@ def test_promotion_create_order_promotion( } # when - with django_assert_num_queries(37): + with django_assert_num_queries(36): response = staff_api_client.post_graphql(PROMOTION_CREATE_MUTATION, variables) # then diff --git a/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py b/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py index b16ff78476f..718b36f042f 100644 --- a/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py +++ b/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py @@ -50,7 +50,7 @@ def test_gift_promotion_delete( } # when - with django_assert_num_queries(25): + with django_assert_num_queries(24): content = get_graphql_content( staff_api_client.post_graphql(PROMOTION_DELETE_MUTATION, variables) ) diff --git a/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py b/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py index a4e1217291e..7a54e10bd21 100644 --- a/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py +++ b/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py @@ -133,7 +133,7 @@ def test_promotion_rule_create_gift( } # when - with django_assert_num_queries(17): + with django_assert_num_queries(16): content = get_graphql_content( staff_api_client.post_graphql(PROMOTION_RULE_CREATE_MUTATION, variables) ) diff --git a/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py b/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py index c24b6855e96..c46c3439f0e 100644 --- a/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py +++ b/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py @@ -29,7 +29,7 @@ def test_voucher_code_bulk_delete_queries( variables = {"ids": ids[:1]} # when - with django_assert_num_queries(10): + with django_assert_num_queries(9): response = staff_api_client.post_graphql( VOUCHER_CODE_BULK_DELETE_MUTATION, variables ) @@ -38,7 +38,7 @@ def test_voucher_code_bulk_delete_queries( variables = {"ids": ids[1:]} - with django_assert_num_queries(10): + with django_assert_num_queries(9): response = staff_api_client.post_graphql( VOUCHER_CODE_BULK_DELETE_MUTATION, variables ) diff --git a/saleor/graphql/discount/utils.py b/saleor/graphql/discount/utils.py index a40740bc52b..6bac0d3b236 100644 --- a/saleor/graphql/discount/utils.py +++ b/saleor/graphql/discount/utils.py @@ -10,6 +10,7 @@ from ...checkout.models import Checkout from ...discount.models import Promotion, PromotionRule from ...discount.utils import update_rule_variant_relation +from ...order.models import Order from ...product.managers import ProductsQueryset, ProductVariantQueryset from ...product.models import ( Category, @@ -20,6 +21,7 @@ ) from ..checkout.filters import CheckoutDiscountedObjectWhere from ..core.connection import where_filter_qs +from ..order.filters import OrderDiscountedObjectWhere from ..product.filters import ( CategoryWhere, CollectionWhere, @@ -283,6 +285,10 @@ def _handle_predicate( return _handle_checkout_predicate( result_qs, base_qs, predicate_data, operator, currency ) + elif predicate_type == PredicateObjectType.ORDER: + return _handle_order_predicate( + result_qs, base_qs, predicate_data, operator, currency + ) def _handle_catalogue_predicate( @@ -327,6 +333,32 @@ def _handle_checkout_predicate( return result_qs +def _handle_order_predicate( + result_qs: QuerySet, + base_qs: QuerySet, + predicate_data: dict[str, Union[dict, str, list, bool]], + operator, + currency: Optional[str] = None, +): + predicate_data = _predicate_to_snake_case(predicate_data) + if predicate := predicate_data.get("discounted_object_predicate"): + if currency: + predicate["currency"] = currency + + orders = where_filter_qs( + Order.objects.filter(pk__in=base_qs.values("pk")), + {}, + OrderDiscountedObjectWhere, + predicate, + None, + ) + if operator == Operators.AND: + result_qs &= orders + else: + result_qs |= orders + return result_qs + + def _predicate_to_snake_case(obj: Any) -> Any: if isinstance(obj, dict): data = {} diff --git a/saleor/graphql/executor.py b/saleor/graphql/executor.py new file mode 100644 index 00000000000..ddafe3a2e99 --- /dev/null +++ b/saleor/graphql/executor.py @@ -0,0 +1,23 @@ +from graphql.execution import executor + +original_complete_value_catching_error = executor.complete_value_catching_error + + +def _patched_complete_value_catching_error(*args, **kwargs): + info = args[3] + from saleor.core.db.connection import allow_writer_in_context + + with allow_writer_in_context(info.context): + return original_complete_value_catching_error(*args, **kwargs) + + +def patch_executor(): + """Patch `complete_value_catching_error` function to allow writer DB in mutations. + + The `complete_value_catching_error` function is called when resolving a field in + GraphQL. This patch wraps each call with `allow_writer_in_context` context manager. + This allows to use writer DB in resolvers, when they are called via mutation, while + they will still raise or log error when a resolver is run in a query. + """ + + executor.complete_value_catching_error = _patched_complete_value_catching_error diff --git a/saleor/graphql/meta/mutations/base.py b/saleor/graphql/meta/mutations/base.py index 0ecd4926a4b..e125f013d37 100644 --- a/saleor/graphql/meta/mutations/base.py +++ b/saleor/graphql/meta/mutations/base.py @@ -5,6 +5,7 @@ from ....checkout import models as checkout_models from ....checkout.models import Checkout from ....core import models +from ....core.db.connection import allow_writer from ....core.error_codes import MetadataErrorCode from ....core.exceptions import PermissionDenied from ....discount import models as discount_models @@ -168,6 +169,7 @@ def check_permissions(cls, context, permissions=None, **data): return super().check_permissions(context, permissions) @classmethod + @allow_writer() def mutate(cls, root, info: ResolveInfo, **data): try: type_name, object_pk = cls.get_object_type_name_and_pk(data) diff --git a/saleor/graphql/notifications/mutations/external_notification_trigger.py b/saleor/graphql/notifications/mutations/external_notification_trigger.py index e4ed876f923..8875888b571 100644 --- a/saleor/graphql/notifications/mutations/external_notification_trigger.py +++ b/saleor/graphql/notifications/mutations/external_notification_trigger.py @@ -32,7 +32,7 @@ class ExternalNotificationTriggerInput(graphene.InputObjectType): extra_payload = JSONString( description=( "Additional payload that will be merged with " - "the one based on the bussines object ID." + "the one based on the business object ID." ) ) external_event_type = graphene.String( diff --git a/saleor/graphql/order/filters.py b/saleor/graphql/order/filters.py index aeb1fffe94e..aa6056ae404 100644 --- a/saleor/graphql/order/filters.py +++ b/saleor/graphql/order/filters.py @@ -2,6 +2,7 @@ import django_filters import graphene +from django.core.exceptions import ValidationError from django.db.models import Exists, OuterRef, Q from django.utils import timezone from graphql.error import GraphQLError @@ -12,6 +13,7 @@ from ...order.search import search_orders from ...payment import ChargeStatus from ...product.models import ProductVariant +from ..channel.filters import get_currency_from_filter_data from ..core.filters import ( GlobalIDMultipleChoiceFilter, ListObjectTypeFilter, @@ -20,9 +22,10 @@ ) from ..core.types import DateRangeInput, DateTimeRangeInput from ..core.utils import from_global_id_or_error +from ..discount.filters import DiscountedObjectWhere from ..payment.enums import PaymentChargeStatusEnum from ..utils import resolve_global_ids_to_primary_keys -from ..utils.filters import filter_range_field +from ..utils.filters import filter_range_field, filter_where_by_numeric_field from .enums import OrderAuthorizeStatusEnum, OrderChargeStatusEnum, OrderStatusFilter @@ -230,3 +233,28 @@ def is_valid(self): message="'ids' and 'numbers` are not allowed to use together in filter." ) return super().is_valid() + + +class OrderDiscountedObjectWhere(DiscountedObjectWhere): + class Meta: + model = Order + fields = ["subtotal_net_amount", "total_net_amount"] + + def filter_base_subtotal_price(self, queryset, name, value): + currency = get_currency_from_filter_data(self.data) + return _filter_price(queryset, name, "subtotal_net_amount", value, currency) + + def filter_base_total_price(self, queryset, name, value): + currency = get_currency_from_filter_data(self.data) + return _filter_price(queryset, name, "total_net_amount", value, currency) + + +def _filter_price(qs, _, field_name, value, currency): + # We will have single channel/currency as the rule can be applied only + # on channels with the same currencies + if not currency: + raise ValidationError( + "You must provide a currency to filter by price field.", code="required" + ) + qs = qs.filter(currency=currency) + return filter_where_by_numeric_field(qs, field_name, value) diff --git a/saleor/graphql/order/mutations/draft_order_complete.py b/saleor/graphql/order/mutations/draft_order_complete.py index 08c610d6546..46e0ca5000a 100644 --- a/saleor/graphql/order/mutations/draft_order_complete.py +++ b/saleor/graphql/order/mutations/draft_order_complete.py @@ -111,7 +111,7 @@ def perform_mutation( # type: ignore[override] cls.validate_order(order) country = get_order_country(order) - validate_draft_order(order, country, manager) + validate_draft_order(order, order.lines.all(), country, manager) with traced_atomic_transaction(): cls.update_user_fields(order) order.status = OrderStatus.UNFULFILLED diff --git a/saleor/graphql/order/mutations/draft_order_create.py b/saleor/graphql/order/mutations/draft_order_create.py index 6dfe4c639b0..4e61972b67a 100644 --- a/saleor/graphql/order/mutations/draft_order_create.py +++ b/saleor/graphql/order/mutations/draft_order_create.py @@ -481,7 +481,7 @@ def _commit_changes( ) @classmethod - def should_invalidate_prices(cls, instance, cleaned_input, is_new_instance) -> bool: + def should_invalidate_prices(cls, cleaned_input, is_new_instance) -> bool: # Force price recalculation for all new instances return is_new_instance @@ -565,7 +565,7 @@ def _save_draft_order( "display_gross_prices", ] ) - if cls.should_invalidate_prices(instance, cleaned_input, is_new_instance): + if cls.should_invalidate_prices(cleaned_input, is_new_instance): invalidate_order_prices(instance) updated_fields.extend(["should_refresh_prices"]) recalculate_order_weight(instance) diff --git a/saleor/graphql/order/mutations/draft_order_update.py b/saleor/graphql/order/mutations/draft_order_update.py index d835ff65998..83cf777f575 100644 --- a/saleor/graphql/order/mutations/draft_order_update.py +++ b/saleor/graphql/order/mutations/draft_order_update.py @@ -51,14 +51,13 @@ def get_instance(cls, info: ResolveInfo, **data): return instance @classmethod - def should_invalidate_prices(cls, instance, cleaned_input, is_new_instance) -> bool: + def should_invalidate_prices(cls, cleaned_input, *args) -> bool: return any( field in cleaned_input for field in [ "shipping_address", "billing_address", "shipping_method", - "lines", "voucher", ] ) diff --git a/saleor/graphql/order/mutations/order_update.py b/saleor/graphql/order/mutations/order_update.py index ec5e8dee2d1..06012422d54 100644 --- a/saleor/graphql/order/mutations/order_update.py +++ b/saleor/graphql/order/mutations/order_update.py @@ -84,7 +84,7 @@ def get_instance(cls, info: ResolveInfo, **data): return instance @classmethod - def should_invalidate_prices(cls, instance, cleaned_input, is_new_instance) -> bool: + def should_invalidate_prices(cls, cleaned_input, *args) -> bool: return any( cleaned_input.get(field) is not None for field in ["shipping_address", "billing_address"] @@ -101,7 +101,7 @@ def save(cls, info: ResolveInfo, instance, cleaned_input): *prepare_order_search_vector_value(instance) ) manager = get_plugin_manager_promise(info.context).get() - if cls.should_invalidate_prices(instance, cleaned_input, False): + if cls.should_invalidate_prices(cleaned_input): invalidate_order_prices(instance) instance.save() diff --git a/saleor/graphql/order/tests/mutations/test_draft_order_complete.py b/saleor/graphql/order/tests/mutations/test_draft_order_complete.py index e568f56cf21..82267764ad1 100644 --- a/saleor/graphql/order/tests/mutations/test_draft_order_complete.py +++ b/saleor/graphql/order/tests/mutations/test_draft_order_complete.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from decimal import Decimal from unittest.mock import patch import graphene @@ -10,10 +11,13 @@ from .....core import EventDeliveryStatus from .....core.models import EventDelivery +from .....core.prices import quantize_price from .....core.taxes import zero_taxed_money +from .....discount import DiscountValueType from .....discount.models import VoucherCustomer from .....order import OrderOrigin, OrderStatus from .....order import events as order_events +from .....order.calculations import fetch_order_prices_if_expired from .....order.error_codes import OrderErrorCode from .....order.interface import OrderTaxedPricesData from .....order.models import OrderEvent @@ -1193,3 +1197,212 @@ def side_effect(order, *args, **kwargs): order.refresh_from_db() assert not order.should_refresh_prices assert order.tax_error == "Empty tax data." + + +DRAFT_ORDER_COMPLETE_WITH_DISCOUNTS_MUTATION = """ + mutation draftComplete($id: ID!) { + draftOrderComplete(id: $id) { + errors { + field + code + message + } + order { + id + total { + net { + amount + } + } + discounts { + amount { + amount + } + valueType + type + reason + } + lines { + id + quantity + totalPrice { + net { + amount + } + } + unitDiscount { + amount + } + unitDiscountValue + unitDiscountReason + unitDiscountType + isGift + } + } + } + } + """ + + +def test_draft_order_complete_with_catalogue_and_order_discount( + staff_api_client, + permission_group_manage_orders, + staff_user, + draft_order_and_promotions, + plugins_manager, +): + # given + Allocation.objects.all().delete() + permission_group_manage_orders.user_set.add(staff_api_client.user) + + order, rule_catalogue, rule_total, _ = draft_order_and_promotions + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + order_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_total.promotion_id + ) + rule_catalogue_value = rule_catalogue.reward_value + rule_total_value = rule_total.reward_value + + currency = order.currency + order_id = graphene.Node.to_global_id("Order", order.id) + variables = {"id": order_id} + fetch_order_prices_if_expired(order, plugins_manager, force_update=True) + + # when + response = staff_api_client.post_graphql( + DRAFT_ORDER_COMPLETE_WITH_DISCOUNTS_MUTATION, variables + ) + + # then + content = get_graphql_content(response) + order_data = content["data"]["draftOrderComplete"]["order"] + + assert len(order_data["discounts"]) == 1 + + order_discount = order_data["discounts"][0] + assert order_discount["amount"]["amount"] == 25.00 == rule_total_value + assert order_discount["reason"] == f"Promotion: {order_promotion_id}" + assert order_discount["amount"]["amount"] == 25.00 == rule_total_value + assert order_discount["valueType"] == DiscountValueType.FIXED.upper() + + lines_db = order.lines.all() + line_1_db = [line for line in lines_db if line.quantity == 3][0] + line_2_db = [line for line in lines_db if line.quantity == 2][0] + line_1_base_total = line_1_db.quantity * line_1_db.base_unit_price_amount + line_2_base_total = line_2_db.quantity * line_2_db.base_unit_price_amount + base_total = line_1_base_total + line_2_base_total + line_1_order_discount_portion = rule_total_value * line_1_base_total / base_total + line_2_order_discount_portion = rule_total_value - line_1_order_discount_portion + + lines = order_data["lines"] + line_1 = [line for line in lines if line["quantity"] == 3][0] + line_2 = [line for line in lines if line["quantity"] == 2][0] + line_1_total = quantize_price( + line_1_db.undiscounted_total_price_net_amount - line_1_order_discount_portion, + currency, + ) + assert line_1["totalPrice"]["net"]["amount"] == float(line_1_total) + assert line_1["unitDiscount"]["amount"] == 0.00 + assert line_1["unitDiscountReason"] is None + assert line_1["unitDiscountValue"] == 0.00 + + line_2_total = quantize_price( + line_2_db.undiscounted_total_price_net_amount + - rule_catalogue_value * line_2_db.quantity + - line_2_order_discount_portion, + currency, + ) + assert line_2["totalPrice"]["net"]["amount"] == float(line_2_total) + assert line_2["unitDiscount"]["amount"] == rule_catalogue_value + assert line_2["unitDiscountReason"] == f"Promotion: {catalogue_promotion_id}" + assert line_2["unitDiscountType"] == DiscountValueType.FIXED.upper() + assert line_2["unitDiscountValue"] == rule_catalogue_value + + total = ( + order.undiscounted_total_net_amount + - line_2["quantity"] * rule_catalogue_value + - rule_total_value + ) + assert order_data["total"]["net"]["amount"] == total + assert total == line_2_total + line_1_total + order.base_shipping_price_amount + + +def test_draft_order_complete_with_catalogue_and_gift_discount( + staff_api_client, + permission_group_manage_orders, + staff_user, + draft_order_and_promotions, + plugins_manager, +): + # given + Allocation.objects.all().delete() + permission_group_manage_orders.user_set.add(staff_api_client.user) + + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + rule_total.reward_value = Decimal(0) + rule_total.save(update_fields=["reward_value"]) + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + gift_promotion_id = graphene.Node.to_global_id("Promotion", rule_gift.promotion_id) + rule_catalogue_value = rule_catalogue.reward_value + + currency = order.currency + order_id = graphene.Node.to_global_id("Order", order.id) + variables = {"id": order_id} + fetch_order_prices_if_expired(order, plugins_manager, force_update=True) + + # when + response = staff_api_client.post_graphql( + DRAFT_ORDER_COMPLETE_WITH_DISCOUNTS_MUTATION, variables + ) + + # then + content = get_graphql_content(response) + order_data = content["data"]["draftOrderComplete"]["order"] + assert not order_data["discounts"] + + lines_db = order.lines.all() + line_1_db = [line for line in lines_db if line.quantity == 3][0] + line_2_db = [line for line in lines_db if line.quantity == 2][0] + gift_line_db = [line for line in lines_db if line.is_gift][0] + gift_price = gift_line_db.variant.channel_listings.get( + channel=order.channel + ).price_amount + + lines = order_data["lines"] + assert len(lines) == 3 + line_1 = [line for line in lines if line["quantity"] == 3][0] + line_2 = [line for line in lines if line["quantity"] == 2][0] + gift_line = [line for line in lines if line["isGift"] is True][0] + + line_1_total = line_1_db.undiscounted_total_price_net_amount + assert line_1["totalPrice"]["net"]["amount"] == line_1_total + assert line_1["unitDiscount"]["amount"] == 0.00 + assert line_1["unitDiscountReason"] is None + assert line_1["unitDiscountValue"] == 0.00 + + line_2_total = quantize_price( + line_2_db.undiscounted_total_price_net_amount + - rule_catalogue_value * line_2_db.quantity, + currency, + ) + assert line_2["totalPrice"]["net"]["amount"] == line_2_total + assert line_2["unitDiscount"]["amount"] == rule_catalogue_value + assert line_2["unitDiscountReason"] == f"Promotion: {catalogue_promotion_id}" + assert line_2["unitDiscountType"] == DiscountValueType.FIXED.upper() + assert line_2["unitDiscountValue"] == rule_catalogue_value + + assert gift_line["totalPrice"]["net"]["amount"] == 0.00 + assert gift_line["unitDiscount"]["amount"] == gift_price + assert gift_line["unitDiscountReason"] == f"Promotion: {gift_promotion_id}" + assert gift_line["unitDiscountType"] == DiscountValueType.FIXED.upper() + assert gift_line["unitDiscountValue"] == gift_price + + total = ( + order.undiscounted_total_net_amount - rule_catalogue_value * line_2_db.quantity + ) + assert order_data["total"]["net"]["amount"] == total + assert total == line_2_total + line_1_total + order.base_shipping_price_amount diff --git a/saleor/graphql/order/tests/mutations/test_draft_order_create.py b/saleor/graphql/order/tests/mutations/test_draft_order_create.py index bb5d8d14d33..9f6c74c85ff 100644 --- a/saleor/graphql/order/tests/mutations/test_draft_order_create.py +++ b/saleor/graphql/order/tests/mutations/test_draft_order_create.py @@ -8,8 +8,9 @@ from prices import Money from .....checkout import AddressType +from .....core.prices import quantize_price from .....core.taxes import TaxError, zero_taxed_money -from .....discount import DiscountType, DiscountValueType +from .....discount import DiscountType, DiscountValueType, RewardType, RewardValueType from .....discount.models import VoucherChannelListing, VoucherCustomer from .....order import OrderStatus from .....order import events as order_events @@ -18,6 +19,7 @@ from .....payment.model_helpers import get_subtotal from .....product.models import ProductVariant from .....tax import TaxCalculationStrategy +from .....tests.utils import round_up from ....tests.utils import assert_no_permission, get_graphql_content DRAFT_ORDER_CREATE_MUTATION = """ @@ -40,6 +42,14 @@ amount } discountName + discounts { + amount { + amount + } + valueType + type + reason + } redirectUrl lines { productName @@ -116,6 +126,7 @@ unitDiscountReason unitDiscountType unitDiscountValue + isGift } } } @@ -2289,7 +2300,7 @@ def test_draft_order_create_with_custom_price_in_order_line( assert order_line_1.undiscounted_base_unit_price_amount == expected_price_variant_1 -def test_draft_order_create_product_on_promotion( +def test_draft_order_create_product_catalogue_promotion( staff_api_client, permission_group_manage_orders, staff_user, @@ -2415,7 +2426,7 @@ def test_draft_order_create_product_on_promotion( assert event_parameters["lines"][0]["quantity"] == quantity -def test_draft_order_create_product_on_promotion_flat_taxes( +def test_draft_order_create_product_catalogue_promotion_flat_taxes( staff_api_client, permission_group_manage_orders, staff_user, @@ -2544,3 +2555,247 @@ def test_draft_order_create_product_on_promotion_flat_taxes( assert event_parameters["lines"][0]["item"] == str(order_lines[0]) assert event_parameters["lines"][0]["line_pk"] == str(order_lines[0].pk) assert event_parameters["lines"][0]["quantity"] == quantity + + +def test_draft_order_create_order_promotion_flat_rates( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + order_promotion_rule, + variant_with_many_stocks, + channel_USD, +): + # given + query = DRAFT_ORDER_CREATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + currency = channel_USD.currency_code + + tc = channel_USD.tax_configuration + tc.country_exceptions.all().delete() + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.prices_entered_with_tax = False + tc.save() + tax_rate = Decimal("1.23") + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_value_type == RewardValueType.PERCENTAGE + reward_value = rule.reward_value + assert rule.reward_value == Decimal("25") + + variant = variant_with_many_stocks + user_id = graphene.Node.to_global_id("User", customer_user.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + + quantity = 4 + variant_list = [ + {"variantId": variant_id, "quantity": quantity}, + ] + + # calculate expected values + variant_price = variant.channel_listings.get( + channel=channel_USD + ).discounted_price_amount + undiscounted_subtotal_net = Decimal(quantity * variant_price) + discount_amount = quantize_price( + reward_value / 100 * undiscounted_subtotal_net, currency + ) + subtotal_net = undiscounted_subtotal_net - discount_amount + subtotal_gross = quantize_price(tax_rate * subtotal_net, currency) + shipping_price_net = shipping_method.channel_listings.get( + channel=channel_USD + ).price_amount + shipping_price_gross = quantize_price(tax_rate * shipping_price_net, currency) + total_gross = quantize_price(subtotal_gross + shipping_price_gross, currency) + + shipping_address = graphql_address_data + shipping_id = graphene.Node.to_global_id("ShippingMethod", shipping_method.id) + channel_id = graphene.Node.to_global_id("Channel", channel_USD.id) + redirect_url = "https://www.example.com" + + variables = { + "input": { + "user": user_id, + "lines": variant_list, + "billingAddress": shipping_address, + "shippingAddress": shipping_address, + "shippingMethod": shipping_id, + "channelId": channel_id, + "redirectUrl": redirect_url, + } + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderCreate"]["errors"] + order = content["data"]["draftOrderCreate"]["order"] + assert order["status"] == OrderStatus.DRAFT.upper() + assert order["subtotal"]["gross"]["amount"] == float(subtotal_gross) + assert order["total"]["gross"]["amount"] == float(total_gross) + assert order["shippingPrice"]["gross"]["amount"] == float(shipping_price_gross) + + assert len(order["discounts"]) == 1 + assert order["discounts"][0]["amount"]["amount"] == discount_amount + assert order["discounts"][0]["reason"] == f"Promotion: {promotion_id}" + assert order["discounts"][0]["type"] == DiscountType.ORDER_PROMOTION.upper() + assert order["discounts"][0]["valueType"] == RewardValueType.PERCENTAGE.upper() + + assert len(order["lines"]) == 1 + assert order["lines"][0]["quantity"] == quantity + assert order["lines"][0]["totalPrice"]["gross"]["amount"] == float(subtotal_gross) + assert order["lines"][0]["undiscountedUnitPrice"]["gross"]["amount"] == float( + quantize_price(undiscounted_subtotal_net * tax_rate / quantity, currency) + ) + assert order["lines"][0]["unitPrice"]["gross"]["amount"] == float( + round_up(subtotal_gross / quantity) + ) + + order_db = Order.objects.get() + assert order_db.total_gross_amount == total_gross + assert order_db.subtotal_gross_amount == subtotal_gross + assert order_db.shipping_price_gross_amount == shipping_price_gross + + line_db = order_db.lines.get() + assert line_db.total_price_gross_amount == subtotal_gross + assert line_db.undiscounted_unit_price_gross_amount == quantize_price( + undiscounted_subtotal_net * tax_rate / quantity, currency + ) + assert line_db.unit_price_net_amount == quantize_price( + subtotal_net / quantity, currency + ) + assert line_db.unit_price_gross_amount == round_up(subtotal_gross / quantity) + + discount_db = order_db.discounts.get() + assert discount_db.amount_value == discount_amount + assert discount_db.reason == f"Promotion: {promotion_id}" + assert discount_db.value == reward_value + assert discount_db.value_type == RewardValueType.PERCENTAGE + + +def test_draft_order_create_gift_promotion_flat_rates( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + gift_promotion_rule, + variant_with_many_stocks, + channel_USD, +): + # given + query = DRAFT_ORDER_CREATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + currency = channel_USD.currency_code + + tc = channel_USD.tax_configuration + tc.country_exceptions.all().delete() + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.prices_entered_with_tax = False + tc.save() + tax_rate = Decimal("1.23") + + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_type == RewardType.GIFT + + variant = variant_with_many_stocks + user_id = graphene.Node.to_global_id("User", customer_user.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + + quantity = 4 + variant_list = [ + {"variantId": variant_id, "quantity": quantity}, + ] + + # calculate expected values + variant_price = variant.channel_listings.get( + channel=channel_USD + ).discounted_price_amount + subtotal_net = quantity * variant_price + subtotal_gross = quantize_price(tax_rate * subtotal_net, currency) + shipping_price_net = shipping_method.channel_listings.get( + channel=channel_USD + ).price_amount + shipping_price_gross = quantize_price(tax_rate * shipping_price_net, currency) + total_gross = quantize_price(subtotal_gross + shipping_price_gross, currency) + + shipping_address = graphql_address_data + shipping_id = graphene.Node.to_global_id("ShippingMethod", shipping_method.id) + channel_id = graphene.Node.to_global_id("Channel", channel_USD.id) + redirect_url = "https://www.example.com" + + variables = { + "input": { + "user": user_id, + "lines": variant_list, + "billingAddress": shipping_address, + "shippingAddress": shipping_address, + "shippingMethod": shipping_id, + "channelId": channel_id, + "redirectUrl": redirect_url, + } + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderCreate"]["errors"] + order = content["data"]["draftOrderCreate"]["order"] + + assert order["status"] == OrderStatus.DRAFT.upper() + assert order["subtotal"]["gross"]["amount"] == float(subtotal_gross) + assert Decimal(order["total"]["gross"]["amount"]) == float(total_gross) + assert Decimal(order["shippingPrice"]["gross"]["amount"]) == float( + shipping_price_gross + ) + + assert not order["discounts"] + + assert len(order["lines"]) == 2 + line = [line for line in order["lines"] if line["quantity"] == 4][0] + gift_line = [line for line in order["lines"] if line["isGift"]][0] + + assert line["totalPrice"]["gross"]["amount"] == float(subtotal_gross) + assert line["undiscountedUnitPrice"]["gross"]["amount"] == float( + subtotal_gross / quantity + ) + assert line["unitPrice"]["gross"]["amount"] == float(subtotal_gross / quantity) + assert line["unitDiscount"]["amount"] == 0.00 + + order_db = Order.objects.get() + assert order_db.total_gross_amount == total_gross + assert order_db.subtotal_gross_amount == subtotal_gross + assert order_db.shipping_price_gross_amount == shipping_price_gross + + lines_db = order_db.lines.all() + assert len(lines_db) == 2 + gift_line_db = [line for line in lines_db if line.is_gift][0] + gift_price = gift_line_db.variant.channel_listings.get( + channel=channel_USD + ).price_amount + + assert gift_line_db.total_price_gross_amount == Decimal(0) + assert gift_line_db.undiscounted_unit_price_gross_amount == Decimal(0) + assert gift_line_db.unit_price_gross_amount == Decimal(0) + assert gift_line_db.base_unit_price_amount == Decimal(0) + assert gift_line_db.unit_discount_value == gift_price + + assert gift_line["totalPrice"]["gross"]["amount"] == 0.00 + assert gift_line["undiscountedUnitPrice"]["gross"]["amount"] == 0.00 + assert gift_line["unitPrice"]["gross"]["amount"] == 0.00 + assert gift_line["unitDiscount"]["amount"] == gift_price + assert gift_line["unitDiscountReason"] == f"Promotion: {promotion_id}" + assert gift_line["unitDiscountType"] == RewardValueType.FIXED.upper() + assert gift_line["unitDiscountValue"] == gift_price + + discount_db = gift_line_db.discounts.get() + assert discount_db.amount_value == gift_price + assert discount_db.reason == f"Promotion: {promotion_id}" + assert discount_db.type == DiscountType.ORDER_PROMOTION diff --git a/saleor/graphql/order/tests/mutations/test_draft_order_update.py b/saleor/graphql/order/tests/mutations/test_draft_order_update.py index e568cc5c84b..3670562fb3f 100644 --- a/saleor/graphql/order/tests/mutations/test_draft_order_update.py +++ b/saleor/graphql/order/tests/mutations/test_draft_order_update.py @@ -1,9 +1,10 @@ import graphene +import pytest from prices import TaxedMoney from .....core.prices import quantize_price from .....core.taxes import zero_money -from .....discount import DiscountType, DiscountValueType +from .....discount import DiscountType, DiscountValueType, RewardValueType from .....order import OrderStatus from .....order.error_codes import OrderErrorCode from .....order.models import OrderEvent @@ -66,16 +67,58 @@ gross { amount } + net { + amount + } } subtotal { gross { amount } + net { + amount + } } undiscountedTotal { gross { amount } + net { + amount + } + } + discounts { + amount { + amount + } + valueType + type + reason + } + lines { + quantity + unitDiscount { + amount + } + undiscountedUnitPrice { + net { + amount + } + } + unitPrice { + net { + amount + } + } + totalPrice { + net { + amount + } + } + unitDiscountReason + unitDiscountType + unitDiscountValue + isGift } } } @@ -537,6 +580,234 @@ def test_draft_order_update_voucher_including_drafts_in_voucher_usage_invalid_co assert error["field"] == "voucher" +def test_draft_order_update_add_voucher_code_remove_order_promotion( + staff_api_client, + permission_group_manage_orders, + order_with_lines_and_order_promotion, + voucher, +): + # given + order = order_with_lines_and_order_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + order_discount = order.discounts.get() + assert order_discount.type == DiscountType.ORDER_PROMOTION + + discount_amount = voucher.channel_listings.get(channel=order.channel).discount_value + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": voucher.codes.first().code, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(order_discount._meta.model.DoesNotExist): + order_discount.refresh_from_db() + + order.refresh_from_db() + voucher_discount = order.discounts.get() + assert voucher_discount.amount_value == discount_amount + assert voucher_discount.value == discount_amount + assert voucher_discount.type == DiscountType.VOUCHER + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_amount + ) + + +def test_draft_order_update_add_voucher_code_remove_gift_promotion( + staff_api_client, + permission_group_manage_orders, + order_with_lines_and_gift_promotion, + voucher, +): + # given + order = order_with_lines_and_gift_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + + assert order.lines.count() == 3 + gift_line = order.lines.get(is_gift=True) + gift_discount = gift_line.discounts.get() + + discount_amount = voucher.channel_listings.get(channel=order.channel).discount_value + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": voucher.codes.first().code, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(gift_line._meta.model.DoesNotExist): + gift_line.refresh_from_db() + + with pytest.raises(gift_discount._meta.model.DoesNotExist): + gift_discount.refresh_from_db() + + order.refresh_from_db() + assert order.lines.count() == 2 + voucher_discount = order.discounts.get() + assert voucher_discount.amount_value == discount_amount + assert voucher_discount.value == discount_amount + assert voucher_discount.type == DiscountType.VOUCHER + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_amount + ) + + +def test_draft_order_update_remove_voucher_code_add_order_promotion( + staff_api_client, + permission_group_manage_orders, + draft_order, + voucher, + order_promotion_rule, +): + # given + order = draft_order + order.voucher = voucher + order.save(update_fields=["voucher"]) + + voucher_listing = voucher.channel_listings.get(channel=order.channel) + discount_amount = voucher_listing.discount_value + voucher_discount = order.discounts.create( + voucher=voucher, + value=discount_amount, + type=DiscountType.VOUCHER, + ) + + order.total_gross_amount -= discount_amount + order.total_net_amount -= discount_amount + order.save(update_fields=["total_net_amount", "total_gross_amount"]) + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": None, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(voucher_discount._meta.model.DoesNotExist): + voucher_discount.refresh_from_db() + + order.refresh_from_db() + order_discount = order.discounts.get() + reward_value = order_promotion_rule.reward_value + assert order_discount.value == reward_value + assert order_discount.value_type == order_promotion_rule.reward_value_type + + undiscounted_subtotal = ( + order.undiscounted_total_net_amount - order.base_shipping_price_amount + ) + assert order_discount.amount.amount == reward_value / 100 * undiscounted_subtotal + assert ( + order.total_net_amount + == order.undiscounted_total_net_amount - order_discount.amount.amount + ) + + +def test_draft_order_update_remove_voucher_code_add_gift_promotion( + staff_api_client, + permission_group_manage_orders, + draft_order, + voucher, + gift_promotion_rule, +): + # given + order = draft_order + order.voucher = voucher + order.save(update_fields=["voucher"]) + assert order.lines.count() == 2 + + voucher_listing = voucher.channel_listings.get(channel=order.channel) + discount_amount = voucher_listing.discount_value + voucher_discount = order.discounts.create( + voucher=voucher, + value=discount_amount, + type=DiscountType.VOUCHER, + ) + + order.total_gross_amount -= discount_amount + order.total_net_amount -= discount_amount + order.save(update_fields=["total_net_amount", "total_gross_amount"]) + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": None, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(voucher_discount._meta.model.DoesNotExist): + voucher_discount.refresh_from_db() + + order.refresh_from_db() + assert order.lines.count() == 3 + assert not order.discounts.exists() + + gift_line = order.lines.filter(is_gift=True).first() + gift_discount = gift_line.discounts.get() + gift_price = gift_line.variant.channel_listings.get( + channel=order.channel + ).price_amount + + assert gift_discount.value == gift_price + assert gift_discount.amount.amount == gift_price + assert gift_discount.value_type == DiscountValueType.FIXED + + assert order.total_net_amount == order.undiscounted_total_net_amount + + def test_draft_order_update_with_non_draft_order( staff_api_client, permission_group_manage_orders, order_with_lines, voucher ): @@ -1163,3 +1434,115 @@ def test_draft_order_update_no_shipping_method_channel_listings( assert len(errors) == 1 assert errors[0]["code"] == OrderErrorCode.SHIPPING_METHOD_NOT_APPLICABLE.name assert errors[0]["field"] == "shippingMethod" + + +def test_draft_order_update_order_promotion( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + variant_with_many_stocks, + channel_USD, + draft_order, + order_promotion_rule, +): + # given + query = DRAFT_ORDER_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_value_type == RewardValueType.PERCENTAGE + reward_value = rule.reward_value + + variables = { + "id": graphene.Node.to_global_id("Order", draft_order.pk), + "input": { + "billingAddress": graphql_address_data, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderUpdate"]["errors"] + draft_order.refresh_from_db() + undiscounted_total = draft_order.undiscounted_total_net_amount + shipping_price = draft_order.base_shipping_price_amount + order = content["data"]["draftOrderUpdate"]["order"] + assert len(order["discounts"]) == 1 + discount_amount = reward_value / 100 * (undiscounted_total - shipping_price) + assert order["discounts"][0]["amount"]["amount"] == discount_amount + assert order["discounts"][0]["reason"] == f"Promotion: {promotion_id}" + assert order["discounts"][0]["type"] == DiscountType.ORDER_PROMOTION.upper() + assert order["discounts"][0]["valueType"] == RewardValueType.PERCENTAGE.upper() + + assert ( + order["subtotal"]["net"]["amount"] + == undiscounted_total - discount_amount - shipping_price + ) + assert order["total"]["net"]["amount"] == undiscounted_total - discount_amount + assert order["undiscountedTotal"]["net"]["amount"] == undiscounted_total + + +def test_draft_order_update_gift_promotion( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + variant_with_many_stocks, + channel_USD, + draft_order, + gift_promotion_rule, +): + # given + query = DRAFT_ORDER_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + variables = { + "id": graphene.Node.to_global_id("Order", draft_order.pk), + "input": { + "billingAddress": graphql_address_data, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderUpdate"]["errors"] + + gift_line_db = [line for line in draft_order.lines.all() if line.is_gift][0] + gift_price = gift_line_db.variant.channel_listings.get( + channel=draft_order.channel + ).price_amount + + order = content["data"]["draftOrderUpdate"]["order"] + lines = order["lines"] + assert len(lines) == 3 + gift_line = [line for line in lines if line["isGift"]][0] + + assert gift_line["totalPrice"]["net"]["amount"] == 0.00 + assert gift_line["unitDiscount"]["amount"] == gift_price + assert gift_line["unitDiscountReason"] == f"Promotion: {promotion_id}" + assert gift_line["unitDiscountType"] == RewardValueType.FIXED.upper() + assert gift_line["unitDiscountValue"] == gift_price + + assert ( + order["subtotal"]["net"]["amount"] + == draft_order.undiscounted_total_net_amount + - draft_order.base_shipping_price_amount + ) + assert order["total"]["net"]["amount"] == draft_order.undiscounted_total_net_amount + assert ( + order["undiscountedTotal"]["net"]["amount"] + == draft_order.undiscounted_total_net_amount + ) diff --git a/saleor/graphql/order/tests/mutations/test_order_discount.py b/saleor/graphql/order/tests/mutations/test_order_discount.py index 9ac84ea9298..e33cdcbca41 100644 --- a/saleor/graphql/order/tests/mutations/test_order_discount.py +++ b/saleor/graphql/order/tests/mutations/test_order_discount.py @@ -7,7 +7,7 @@ from prices import Money, TaxedMoney, fixed_discount, percentage_discount from .....core.prices import quantize_price -from .....discount import DiscountValueType +from .....discount import DiscountType, DiscountValueType from .....order import OrderEvents, OrderStatus from .....order.error_codes import OrderErrorCode from .....order.interface import OrderTaxedPricesData @@ -280,6 +280,116 @@ def test_add_fixed_order_discount_to_order_by_app( assert discount_data["amount_value"] == str(order_discount.amount.amount) +def test_add_manual_discount_to_order_with_order_discount( + order_with_lines_and_order_promotion, + staff_api_client, + permission_group_manage_orders, +): + """Order discount should be deleted in a favour of manual discount.""" + # given + order = order_with_lines_and_order_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + order_discount = order.discounts.get() + + permission_group_manage_orders.user_set.add(staff_api_client.user) + discount_value = Decimal("10.00") + + variables = { + "orderId": graphene.Node.to_global_id("Order", order.pk), + "input": { + "valueType": DiscountValueTypeEnum.FIXED.name, + "value": discount_value, + }, + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_ADD, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountAdd"] + order.refresh_from_db() + assert not data["errors"] + + with pytest.raises(order_discount._meta.model.DoesNotExist): + order_discount.refresh_from_db() + + assert order.discounts.count() == 1 + manual_discount = order.discounts.get() + + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.FIXED + assert manual_discount.amount.amount == discount_value + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_value + ) + assert ( + order.shipping_price_net_amount + order.subtotal_net_amount + == order.total_net_amount + ) + + +def test_add_manual_discount_to_order_with_gift_discount( + order_with_lines_and_gift_promotion, + staff_api_client, + permission_group_manage_orders, +): + """Order discount should be deleted in a favour of manual discount.""" + # given + order = order_with_lines_and_gift_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + + assert order.lines.count() == 3 + gift_line = order.lines.filter(is_gift=True).first() + gift_discount = gift_line.discounts.get() + + permission_group_manage_orders.user_set.add(staff_api_client.user) + discount_value = Decimal("10.00") + + variables = { + "orderId": graphene.Node.to_global_id("Order", order.pk), + "input": { + "valueType": DiscountValueTypeEnum.FIXED.name, + "value": discount_value, + }, + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_ADD, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountAdd"] + order.refresh_from_db() + assert not data["errors"] + + assert order.lines.count() == 2 + + with pytest.raises(gift_line._meta.model.DoesNotExist): + gift_line.refresh_from_db() + + with pytest.raises(gift_discount._meta.model.DoesNotExist): + gift_discount.refresh_from_db() + + assert order.discounts.count() == 1 + manual_discount = order.discounts.get() + + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.FIXED + assert manual_discount.amount.amount == discount_value + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_value + ) + assert ( + order.shipping_price_net_amount + order.subtotal_net_amount + == order.total_net_amount + ) + + ORDER_DISCOUNT_UPDATE = """ mutation OrderDiscountUpdate($discountId: ID!, $input: OrderDiscountCommonInput!){ orderDiscountUpdate(discountId:$discountId, input: $input){ @@ -598,6 +708,9 @@ def test_update_percentage_order_discount_to_order_by_app( orderDiscountDelete(discountId: $discountId){ order{ id + discounts { + id + } } errors{ field @@ -642,8 +755,8 @@ def test_delete_order_discount_from_order( errors = data["errors"] assert len(errors) == 0 - assert order.undiscounted_total == current_undiscounted_total - assert order.total == current_undiscounted_total + assert order.undiscounted_total.net == current_undiscounted_total.net + assert order.total.net == current_undiscounted_total.net event = order.events.get() assert event.type == OrderEvents.ORDER_DISCOUNT_DELETED @@ -747,8 +860,8 @@ def test_delete_order_discount_from_order_by_app( errors = data["errors"] assert len(errors) == 0 - assert order.undiscounted_total == current_undiscounted_total - assert order.total == current_undiscounted_total + assert order.undiscounted_total.net == current_undiscounted_total.net + assert order.total.net == current_undiscounted_total.net event = order.events.get() assert event.type == OrderEvents.ORDER_DISCOUNT_DELETED @@ -756,6 +869,92 @@ def test_delete_order_discount_from_order_by_app( assert order.search_vector +def test_delete_manual_discount_from_order_with_subtotal_promotion( + draft_order_with_fixed_discount_order, + staff_api_client, + permission_group_manage_orders, + order_promotion_rule, +): + # given + order = draft_order_with_fixed_discount_order + manual_discount = draft_order_with_fixed_discount_order.discounts.get() + + permission_group_manage_orders.user_set.add(staff_api_client.user) + variables = { + "discountId": graphene.Node.to_global_id("OrderDiscount", manual_discount.pk), + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_DELETE, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountDelete"] + assert not data["errors"] + + with pytest.raises(manual_discount._meta.model.DoesNotExist): + manual_discount.refresh_from_db() + + order.refresh_from_db() + order_discount = order.discounts.get() + reward_value = order_promotion_rule.reward_value + assert order_discount.value == reward_value + assert order_discount.value_type == order_promotion_rule.reward_value_type + + undiscounted_subtotal = ( + order.undiscounted_total_net_amount - order.base_shipping_price_amount + ) + assert order_discount.amount.amount == reward_value / 100 * undiscounted_subtotal + assert ( + order.total_net_amount + == order.undiscounted_total_net_amount - order_discount.amount.amount + ) + + +def test_delete_manual_discount_from_order_with_gift_promotion( + draft_order_with_fixed_discount_order, + staff_api_client, + permission_group_manage_orders, + gift_promotion_rule, +): + # given + order = draft_order_with_fixed_discount_order + manual_discount = draft_order_with_fixed_discount_order.discounts.get() + assert order.lines.count() == 2 + + permission_group_manage_orders.user_set.add(staff_api_client.user) + variables = { + "discountId": graphene.Node.to_global_id("OrderDiscount", manual_discount.pk), + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_DELETE, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountDelete"] + assert not data["errors"] + + with pytest.raises(manual_discount._meta.model.DoesNotExist): + manual_discount.refresh_from_db() + + order.refresh_from_db() + assert order.lines.count() == 3 + assert not order.discounts.exists() + + gift_line = order.lines.filter(is_gift=True).first() + gift_discount = gift_line.discounts.get() + gift_price = gift_line.variant.channel_listings.get( + channel=order.channel + ).price_amount + + assert gift_discount.value == gift_price + assert gift_discount.amount.amount == gift_price + assert gift_discount.value_type == DiscountValueType.FIXED + + assert order.total_net_amount == order.undiscounted_total_net_amount + + ORDER_LINE_DISCOUNT_UPDATE = """ mutation OrderLineDiscountUpdate($input: OrderDiscountCommonInput!, $orderLineId: ID!){ orderLineDiscountUpdate(orderLineId: $orderLineId, input: $input){ @@ -798,14 +997,23 @@ def test_update_order_line_discount( line_to_discount.undiscounted_total_price = total_price line_to_discount.save() + line_to_discount.discounts.create( + value_type="fixed", + value=0, + amount_value=0, + name="Manual line discount", + type="manual", + ) + line_price_before_discount = line_to_discount.unit_price value = Decimal("5") + value_type = DiscountValueTypeEnum.FIXED reason = "New reason for unit discount" variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line_to_discount.pk), "input": { - "valueType": DiscountValueTypeEnum.FIXED.name, + "valueType": value_type.name, "value": value, "reason": reason, }, @@ -866,9 +1074,16 @@ def test_update_order_line_discount( discount_data = line_data.get("discount") assert discount_data["value"] == str(value) - assert discount_data["value_type"] == DiscountValueTypeEnum.FIXED.value + assert discount_data["value_type"] == value_type.value assert discount_data["amount_value"] == str(unit_discount.amount) + line_discount = line_to_discount.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert line_discount.amount_value == value * line_to_discount.quantity + def test_update_order_line_discount_by_user_no_channel_access( draft_order_with_fixed_discount_order, @@ -916,11 +1131,12 @@ def test_update_order_line_discount_by_app( line_to_discount = order.lines.first() value = Decimal("5") + value_type = DiscountValueTypeEnum.FIXED reason = "New reason for unit discount" variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line_to_discount.pk), "input": { - "valueType": DiscountValueTypeEnum.FIXED.name, + "valueType": value_type.name, "value": value, "reason": reason, }, @@ -953,9 +1169,16 @@ def test_update_order_line_discount_by_app( discount_data = line_data.get("discount") assert discount_data["value"] == str(value) - assert discount_data["value_type"] == DiscountValueTypeEnum.FIXED.value + assert discount_data["value_type"] == value_type.value assert discount_data["amount_value"] == str(unit_discount.amount) + line_discount = line_to_discount.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert line_discount.amount_value == value * line_to_discount.quantity + @pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) def test_update_order_line_discount_line_with_discount( @@ -1002,11 +1225,12 @@ def test_update_order_line_discount_line_with_discount( line_undiscounted_price = line_to_discount.undiscounted_unit_price value = Decimal("50") + value_type = DiscountValueTypeEnum.PERCENTAGE reason = "New reason for unit discount" variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line_to_discount.pk), "input": { - "valueType": DiscountValueTypeEnum.PERCENTAGE.name, + "valueType": value_type.name, "value": value, "reason": reason, }, @@ -1022,7 +1246,6 @@ def test_update_order_line_discount_line_with_discount( data = content["data"]["orderLineDiscountUpdate"] line_to_discount.refresh_from_db() - errors = data["errors"] assert not errors @@ -1047,13 +1270,64 @@ def test_update_order_line_discount_line_with_discount( discount_data = line_data.get("discount") assert discount_data["value"] == str(value) - assert discount_data["value_type"] == DiscountValueTypeEnum.PERCENTAGE.value + assert discount_data["value_type"] == value_type.value assert discount_data["amount_value"] == str(unit_discount.amount) assert discount_data["old_value"] == str(line_discount_value_before_update) assert discount_data["old_value_type"] == DiscountValueTypeEnum.FIXED.value assert discount_data["old_amount_value"] == str(line_discount_amount_before_update) + line_discount = line_to_discount.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert ( + line_discount.amount_value + == line_to_discount.unit_discount_amount * line_to_discount.quantity + ) + + +def test_update_order_line_discount_line_with_catalogue_promotion( + order_with_lines_and_catalogue_promotion, + staff_api_client, + permission_group_manage_orders, +): + # given + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = order_with_lines_and_catalogue_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + line = order.lines.get(quantity=3) + assert line.discounts.filter(type=DiscountType.PROMOTION).exists() + + value = Decimal("5") + value_type = DiscountValueTypeEnum.FIXED + reason = "Manual fixed line discount" + variables = { + "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), + "input": { + "valueType": value_type.name, + "value": value, + "reason": reason, + }, + } + + # when + response = staff_api_client.post_graphql(ORDER_LINE_DISCOUNT_UPDATE, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineDiscountUpdate"] + assert not data["errors"] + + line_discount = line.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert line_discount.amount_value == value * line.quantity + def test_update_order_line_discount_order_is_not_draft( draft_order_with_fixed_discount_order, @@ -1141,6 +1415,13 @@ def test_delete_discount_from_order_line( line.unit_discount_value = Decimal("2.5") line.save() + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=Decimal("2.5"), + currency=order.currency, + ) + variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), } @@ -1169,6 +1450,8 @@ def test_delete_discount_from_order_line( line_data = lines[0] assert line_data.get("line_pk") == str(line.pk) + assert not line.discounts.exists() + @patch("saleor.plugins.manager.PluginsManager.calculate_order_line_unit") @patch("saleor.plugins.manager.PluginsManager.calculate_order_line_total") @@ -1254,6 +1537,13 @@ def test_delete_discount_from_order_line_by_app( line.unit_discount_value = Decimal("2.5") line.save() + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=Decimal("2.5"), + currency=order.currency, + ) + variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), } @@ -1287,6 +1577,8 @@ def test_delete_discount_from_order_line_by_app( line_data = lines[0] assert line_data.get("line_pk") == str(line.pk) + assert not line.discounts.exists() + def test_delete_order_line_discount_order_is_not_draft( draft_order_with_fixed_discount_order, @@ -1321,3 +1613,42 @@ def test_delete_order_line_discount_order_is_not_draft( assert error["code"] == OrderErrorCode.CANNOT_DISCOUNT.name assert line.unit_discount_amount == Decimal("2.5") + + +def test_delete_order_line_discount_line_with_catalogue_promotion( + order_with_lines, + staff_api_client, + permission_group_manage_orders, + catalogue_promotion, +): + # given + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = order_with_lines + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + line = order.lines.get(quantity=3) + + manual_reward_value = Decimal(1) + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=manual_reward_value, + amount_value=manual_reward_value * line.quantity, + currency=order.currency, + reason="Manual line discount", + ) + + variables = { + "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), + } + + # when + response = staff_api_client.post_graphql(ORDER_LINE_DISCOUNT_REMOVE, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineDiscountRemove"] + assert not data["errors"] + # Deleting manual discount should result in creating catalogue discount in this case + # https://github.com/saleor/saleor/issues/15517 + # assert line.discounts.filter(type=DiscountType.PROMOTION).exists() diff --git a/saleor/graphql/order/tests/mutations/test_order_line_update.py b/saleor/graphql/order/tests/mutations/test_order_line_update.py index 9a9e09697c3..19ad0526f9c 100644 --- a/saleor/graphql/order/tests/mutations/test_order_line_update.py +++ b/saleor/graphql/order/tests/mutations/test_order_line_update.py @@ -1,8 +1,10 @@ +from decimal import Decimal from unittest.mock import patch import graphene import pytest +from .....discount import DiscountType, RewardValueType from .....order import OrderStatus from .....order import events as order_events from .....order.error_codes import OrderErrorCode @@ -24,6 +26,12 @@ orderLine { id quantity + unitDiscount { + amount + } + unitDiscountType + unitDiscountValue + isGift } order { total { @@ -31,6 +39,11 @@ amount } } + discounts { + amount { + amount + } + } } } } @@ -417,3 +430,97 @@ def test_order_line_update_quantity_gift( assert len(errors) == 1 assert errors[0]["field"] == "id" assert errors[0]["code"] == OrderErrorCode.NON_EDITABLE_GIFT_LINE.name + + +def test_order_line_update_order_promotion( + draft_order, + staff_api_client, + permission_group_manage_orders, + order_promotion_rule, +): + # given + query = ORDER_LINE_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = draft_order + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + reward_value = Decimal("25") + assert rule.reward_value == reward_value + assert rule.reward_value_type == RewardValueType.PERCENTAGE + + order.lines.last().delete() + line = order.lines.first() + line_id = graphene.Node.to_global_id("OrderLine", line.id) + variant = line.variant + variant_channel_listing = variant.channel_listings.get(channel=order.channel) + quantity = 4 + undiscounted_subtotal = quantity * variant_channel_listing.discounted_price_amount + expected_discount = round(reward_value / 100 * undiscounted_subtotal, 2) + + variables = {"lineId": line_id, "quantity": quantity} + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineUpdate"] + + discounts = data["order"]["discounts"] + assert len(discounts) == 1 + assert discounts[0]["amount"]["amount"] == expected_discount + + discount_db = order.discounts.get() + assert discount_db.promotion_rule == rule + assert discount_db.amount_value == expected_discount + assert discount_db.type == DiscountType.ORDER_PROMOTION + assert discount_db.reason == f"Promotion: {promotion_id}" + + +def test_order_line_update_gift_promotion( + draft_order, + staff_api_client, + permission_group_manage_orders, + gift_promotion_rule, +): + # given + query = ORDER_LINE_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = draft_order + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + order.lines.last().delete() + line = order.lines.first() + line_id = graphene.Node.to_global_id("OrderLine", line.id) + quantity = 4 + + variables = {"lineId": line_id, "quantity": quantity} + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineUpdate"] + + line = data["orderLine"] + assert line["quantity"] == quantity + assert line["unitDiscount"]["amount"] == 0 + assert line["unitDiscountValue"] == 0 + + gift_line_db = order.lines.get(is_gift=True) + gift_price = gift_line_db.variant.channel_listings.get( + channel=order.channel + ).price_amount + assert gift_line_db.unit_discount_amount == gift_price + assert gift_line_db.unit_price_gross_amount == Decimal(0) + + assert not data["order"]["discounts"] + + discount = gift_line_db.discounts.get() + assert discount.promotion_rule == rule + assert discount.amount_value == gift_price + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" diff --git a/saleor/graphql/order/tests/mutations/test_order_lines_create.py b/saleor/graphql/order/tests/mutations/test_order_lines_create.py index da5680b5891..9c320976a7c 100644 --- a/saleor/graphql/order/tests/mutations/test_order_lines_create.py +++ b/saleor/graphql/order/tests/mutations/test_order_lines_create.py @@ -57,6 +57,12 @@ currency } } + unitDiscount { + amount + } + unitDiscountType + unitDiscountValue + isGift } order { total { @@ -64,6 +70,11 @@ amount } } + discounts { + amount { + amount + } + } } } } @@ -611,7 +622,6 @@ def test_order_lines_create_variant_on_promotion( line_data = data["orderLines"][0] assert line_data["productSku"] == variant.sku assert line_data["quantity"] == quantity - assert line_data["quantity"] == quantity assert ( line_data["unitPrice"]["gross"]["amount"] @@ -644,6 +654,152 @@ def test_order_lines_create_variant_on_promotion( ) +@pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) +@patch("saleor.plugins.manager.PluginsManager.draft_order_updated") +@patch("saleor.plugins.manager.PluginsManager.order_updated") +def test_order_lines_create_order_promotion( + order_updated_webhook_mock, + draft_order_updated_webhook_mock, + status, + order_with_lines, + permission_group_manage_orders, + staff_api_client, + variant_with_many_stocks, + order_promotion_rule, +): + # given + query = ORDER_LINES_CREATE_MUTATION + + order = order_with_lines + order.status = status + order.save(update_fields=["status"]) + order.lines.all().delete() + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_value_type == RewardValueType.PERCENTAGE + assert rule.reward_value == Decimal("25") + + variant = variant_with_many_stocks + quantity = 5 + order_id = graphene.Node.to_global_id("Order", order.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + variant_channel_listing = variant.channel_listings.get(channel=order.channel) + expected_discount = round( + quantity * variant_channel_listing.discounted_price.amount * Decimal(0.25), 2 + ) + expected_unit_discount = round(expected_discount / quantity, 2) + + variables = {"orderId": order_id, "variantId": variant_id, "quantity": quantity} + permission_group_manage_orders.user_set.add(staff_api_client.user) + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + assert_proper_webhook_called_once( + order, status, draft_order_updated_webhook_mock, order_updated_webhook_mock + ) + assert OrderEvent.objects.count() == 1 + assert OrderEvent.objects.last().type == order_events.OrderEvents.ADDED_PRODUCTS + content = get_graphql_content(response) + data = content["data"]["orderLinesCreate"] + + line_data = data["orderLines"][0] + assert line_data["productSku"] == variant.sku + assert line_data["quantity"] == quantity + assert line_data["unitDiscount"]["amount"] == 0.00 + assert ( + line_data["unitPrice"]["gross"]["amount"] + == variant_channel_listing.price_amount - expected_unit_discount + ) + assert ( + line_data["unitPrice"]["net"]["amount"] + == variant_channel_listing.price_amount - expected_unit_discount + ) + + line = order.lines.get(product_sku=variant.sku) + assert line.unit_discount_amount == 0 + assert ( + line.unit_price_gross_amount + == variant_channel_listing.discounted_price.amount - expected_unit_discount + ) + + assert len(data["order"]["discounts"]) == 1 + discount = data["order"]["discounts"][0] + assert discount["amount"]["amount"] == expected_discount + + discount = order.discounts.get() + assert discount.promotion_rule == rule + assert discount.amount_value == expected_discount + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + +@pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) +@patch("saleor.plugins.manager.PluginsManager.draft_order_updated") +@patch("saleor.plugins.manager.PluginsManager.order_updated") +def test_order_lines_create_gift_promotion( + order_updated_webhook_mock, + draft_order_updated_webhook_mock, + status, + order_with_lines, + permission_group_manage_orders, + staff_api_client, + variant_with_many_stocks, + gift_promotion_rule, +): + # given + query = ORDER_LINES_CREATE_MUTATION + + order = order_with_lines + order.status = status + order.save(update_fields=["status"]) + order.lines.all().delete() + + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + variant = variant_with_many_stocks + quantity = 5 + order_id = graphene.Node.to_global_id("Order", order.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + + variables = {"orderId": order_id, "variantId": variant_id, "quantity": quantity} + permission_group_manage_orders.user_set.add(staff_api_client.user) + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + assert_proper_webhook_called_once( + order, status, draft_order_updated_webhook_mock, order_updated_webhook_mock + ) + assert OrderEvent.objects.count() == 1 + assert OrderEvent.objects.last().type == order_events.OrderEvents.ADDED_PRODUCTS + content = get_graphql_content(response) + data = content["data"]["orderLinesCreate"] + + lines = data["orderLines"] + # gift line is not returned + assert len(lines) == 1 + + gift_line_db = order.lines.get(is_gift=True) + gift_price = gift_line_db.variant.channel_listings.get( + channel=order.channel + ).price_amount + assert gift_line_db.unit_discount_amount == gift_price + assert gift_line_db.unit_price_gross_amount == Decimal(0) + + assert not data["order"]["discounts"] + + discount = gift_line_db.discounts.get() + assert discount.promotion_rule == rule + assert discount.amount_value == gift_price + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + @pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) @patch("saleor.plugins.manager.PluginsManager.draft_order_updated") @patch("saleor.plugins.manager.PluginsManager.order_updated") diff --git a/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py b/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py index d7c15b56ac7..de79b9412f3 100644 --- a/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py +++ b/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py @@ -3,6 +3,7 @@ import graphene import pytest +from django.core.exceptions import ValidationError from freezegun import freeze_time from .....core.postgres import FlatConcatSearchVector @@ -13,7 +14,9 @@ prepare_order_search_vector_value, update_order_search_vector, ) +from ....core.connection import where_filter_qs from ....tests.utils import get_graphql_content +from ...filters import OrderDiscountedObjectWhere @pytest.fixture @@ -218,3 +221,243 @@ def test_draft_orders_query_with_filter_search( response = staff_api_client.post_graphql(draft_orders_query_with_filter, variables) content = get_graphql_content(response) assert content["data"]["draftOrders"]["totalCount"] == count + + +@pytest.mark.parametrize(("gte", "count"), [(20, 1), (0, 1), (500, 0), (20.01, 0)]) +def test_draft_orders_query_with_filter_base_total_price_range(draft_order, gte, count): + # given + order = draft_order + currency = order.currency + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_total_price": { + "range": { + "gte": gte, + } + }, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +@pytest.mark.parametrize(("gte", "count"), [(20, 1), (0, 1), (500, 0), (20.01, 0)]) +def test_draft_orders_query_with_filter_base_subtotal_price_range( + draft_order, gte, count +): + # given + order = draft_order + currency = order.currency + order.subtotal_net_amount = Decimal("20") + order.save(update_fields=["subtotal_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_subtotal_price": { + "range": { + "gte": gte, + } + }, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +@pytest.mark.parametrize( + ("one_of", "count"), [([1, 20, 70], 1), ([3, 20.1], 0), ([-3, 0], 0)] +) +def test_draft_orders_query_with_filter_base_total_price_one_of( + draft_order, one_of, count +): + # given + order = draft_order + currency = order.currency + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_total_price": {"one_of": one_of}, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +@pytest.mark.parametrize( + ("one_of", "count"), [([1, 20, 70], 1), ([3, 20.1], 0), ([-3, 0], 0)] +) +def test_draft_orders_query_with_filter_base_subtotal_price_one_of( + draft_order, one_of, count +): + # given + order = draft_order + currency = order.currency + order.subtotal_net_amount = Decimal("20") + order.save(update_fields=["subtotal_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_subtotal_price": {"one_of": one_of}, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +def test_draft_orders_query_with_filter_base_total_price_missing_currency(draft_order): + # given + order = draft_order + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "base_total_price": { + "range": { + "gte": 20, + } + }, + } + + # when + with pytest.raises(ValidationError) as validation_error: + where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert validation_error.value.code == "required" + + +def test_draft_orders_query_with_filter_base_subtotal_price_missing_currency( + draft_order, +): + # given + order = draft_order + order.subtotal_net_amount = Decimal("20") + order.save(update_fields=["subtotal_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "base_subtotal_price": { + "range": { + "gte": 20, + } + }, + } + + # when + with pytest.raises(ValidationError) as validation_error: + where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert validation_error.value.code == "required" + + +def test_draft_orders_query_with_filter_price_with_and_or(draft_order): + # given + order = draft_order + currency = order.currency + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "AND": [ + { + "OR": [ + { + "currency": currency, + "base_total_price": { + "range": { + "gte": 20, + } + }, + }, + { + "currency": currency, + "base_total_price": { + "range": { + "gte": 10, + } + }, + }, + ] + } + ], + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == 1 diff --git a/saleor/graphql/order/tests/queries/test_order.py b/saleor/graphql/order/tests/queries/test_order.py index 3b25d606076..20e2890993f 100644 --- a/saleor/graphql/order/tests/queries/test_order.py +++ b/saleor/graphql/order/tests/queries/test_order.py @@ -349,7 +349,7 @@ def test_order_query( expected_methods = ShippingMethod.objects.applicable_shipping_methods( price=order.subtotal.gross, - weight=order.get_total_weight(), + weight=order.weight, country_code=order.shipping_address.country.code, channel_id=order.channel_id, ) @@ -997,7 +997,7 @@ def test_order_query_in_pln_channel( expected_methods = ShippingMethod.objects.applicable_shipping_methods( price=order.subtotal.gross, - weight=order.get_total_weight(), + weight=order.weight, country_code=order.shipping_address.country.code, channel_id=order.channel_id, ) diff --git a/saleor/graphql/order/tests/test_draft_order_validate.py b/saleor/graphql/order/tests/test_draft_order_validate.py index 1a7b5eef418..c5c49437733 100644 --- a/saleor/graphql/order/tests/test_draft_order_validate.py +++ b/saleor/graphql/order/tests/test_draft_order_validate.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from unittest.mock import Mock +from unittest.mock import patch import pytest import pytz @@ -15,7 +15,10 @@ def test_validate_draft_order(draft_order): # should not raise any errors assert ( validate_draft_order( - draft_order, "US", get_plugins_manager(allow_replica=False) + draft_order, + draft_order.lines.all(), + "US", + get_plugins_manager(allow_replica=False), ) is None ) @@ -27,7 +30,10 @@ def test_validate_draft_order_without_sku(draft_order): # should not raise any errors assert ( validate_draft_order( - draft_order, "US", get_plugins_manager(allow_replica=False) + draft_order, + draft_order.lines.all(), + "US", + get_plugins_manager(allow_replica=False), ) is None ) @@ -40,7 +46,9 @@ def test_validate_draft_order_wrong_shipping(draft_order): shipping_zone.save() assert order.shipping_address.country.code not in shipping_zone.countries with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Shipping method is not valid for chosen shipping address" assert e.value.error_dict["shipping"][0].message == msg @@ -48,7 +56,9 @@ def test_validate_draft_order_wrong_shipping(draft_order): def test_validate_draft_order_no_order_lines(order, shipping_method): order.shipping_method = shipping_method with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Could not create order without any products." assert e.value.error_dict["lines"][0].message == msg @@ -62,7 +72,9 @@ def test_validate_draft_order_non_existing_variant(draft_order): assert line.variant is None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Could not create orders with non-existing products." assert e.value.error_dict["lines"][0].message == msg @@ -77,7 +89,9 @@ def test_validate_draft_order_with_unpublished_product(draft_order): line.refresh_from_db() with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Can't finalize draft with unpublished product." error = e.value.error_dict["lines"][0] @@ -93,7 +107,9 @@ def test_validate_draft_order_with_unavailable_for_purchase_product(draft_order) line.refresh_from_db() with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Can't finalize draft with product unavailable for purchase." error = e.value.error_dict["lines"][0] @@ -113,7 +129,9 @@ def test_validate_draft_order_with_product_available_for_purchase_in_future( line.refresh_from_db() with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Can't finalize draft with product unavailable for purchase." error = e.value.error_dict["lines"][0] @@ -131,7 +149,9 @@ def test_validate_draft_order_out_of_stock_variant(draft_order): stock.save(update_fields=["quantity"]) with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Insufficient product stock." assert e.value.error_dict["lines"][0].message == msg @@ -141,7 +161,9 @@ def test_validate_draft_order_no_shipping_address(draft_order): order.shipping_address = None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["order"][0] assert error.message == "Can't finalize draft with no shipping address." assert error.code == OrderErrorCode.ORDER_NO_SHIPPING_ADDRESS.value @@ -152,7 +174,9 @@ def test_validate_draft_order_no_billing_address(draft_order): order.billing_address = None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["order"][0] assert error.message == "Can't finalize draft with no billing address." assert error.code == OrderErrorCode.BILLING_ADDRESS_NOT_SET.value @@ -163,35 +187,44 @@ def test_validate_draft_order_no_shipping_method(draft_order): order.shipping_method = None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["shipping"][0] assert error.message == "Shipping method is required." assert error.code == OrderErrorCode.SHIPPING_METHOD_REQUIRED.value -def test_validate_draft_order_no_shipping_method_shipping_not_required(draft_order): +@patch("saleor.graphql.order.utils.is_shipping_required") +def test_validate_draft_order_no_shipping_method_shipping_not_required( + mocked_is_shipping_required, draft_order +): order = draft_order order.shipping_method = None - required_mock = Mock(return_value=False) - order.is_shipping_required = required_mock + mocked_is_shipping_required.return_value = False assert ( - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) is None ) +@patch("saleor.graphql.order.utils.is_shipping_required") def test_validate_draft_order_no_shipping_address_no_method_shipping_not_required( + mocked_is_shipping_required, draft_order, ): order = draft_order order.shipping_method = None order.shipping_address = None - required_mock = Mock(return_value=False) - order.is_shipping_required = required_mock + mocked_is_shipping_required.return_value = False assert ( - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) is None ) @@ -203,7 +236,9 @@ def test_validate_draft_order_voucher(draft_order_with_voucher): # when & then with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["voucher"][0] assert error.code == OrderErrorCode.INVALID_VOUCHER.value diff --git a/saleor/graphql/order/types.py b/saleor/graphql/order/types.py index 3112f2065f3..57154d85fa7 100644 --- a/saleor/graphql/order/types.py +++ b/saleor/graphql/order/types.py @@ -13,6 +13,7 @@ from ...account.models import User as UserModel from ...checkout.utils import get_external_shipping_id from ...core.anonymize import obfuscate_address, obfuscate_email +from ...core.db.connection import allow_writer_in_context from ...core.prices import quantize_price from ...core.taxes import zero_money from ...discount import DiscountType @@ -24,6 +25,7 @@ from ...graphql.utils import get_user_or_app_from_context from ...graphql.warehouse.dataloaders import StockByIdLoader, WarehouseByIdLoader from ...order import OrderStatus, calculations, models +from ...order.calculations import fetch_order_prices_if_expired from ...order.models import FulfillmentStatus from ...order.utils import ( get_order_country, @@ -44,7 +46,6 @@ from ...permission.utils import has_one_of_permissions from ...product.models import ALL_PRODUCTS_PERMISSIONS, ProductMediaTypes from ...shipping.interface import ShippingMethodData -from ...shipping.models import ShippingMethodChannelListing from ...shipping.utils import convert_to_shipping_method_data from ...tax.utils import get_display_gross_prices from ...thumbnail.utils import ( @@ -932,6 +933,7 @@ def _resolve_thumbnail(result): @traced_resolver @prevent_sync_event_circular_query def resolve_unit_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_unit_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -956,6 +958,7 @@ def resolve_quantity_to_fulfill(root: models.OrderLine, info): @traced_resolver @prevent_sync_event_circular_query def resolve_undiscounted_unit_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_undiscounted_unit_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -989,6 +992,7 @@ def resolve_unit_discount(root: models.OrderLine, _info): @staticmethod @traced_resolver def resolve_tax_rate(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_tax_rate(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1009,6 +1013,7 @@ def _resolve_tax_rate(data): @traced_resolver @prevent_sync_event_circular_query def resolve_total_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_total_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1029,6 +1034,7 @@ def _resolve_total_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_undiscounted_total_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_undiscounted_total_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1521,7 +1527,11 @@ def resolve_token(root: models.Order, info): @staticmethod def resolve_discounts(root: models.Order, info): - return OrderDiscountsByOrderIDLoader(info.context).load(root.id) + def with_manager(manager): + fetch_order_prices_if_expired(root, manager) + return OrderDiscountsByOrderIDLoader(info.context).load(root.id) + + return get_plugin_manager_promise(info.context).then(with_manager) @staticmethod @traced_resolver @@ -1637,6 +1647,7 @@ def _resolve_shipping_address(data): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_price(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_shipping_price(data): lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1652,6 +1663,7 @@ def _resolve_shipping_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_tax_rate(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_shipping_tax_rate(data): lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1685,6 +1697,7 @@ def _resolve_actions(payments): @staticmethod @traced_resolver def resolve_subtotal(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_subtotal(data): order_lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1705,6 +1718,7 @@ def _resolve_subtotal(data): @prevent_sync_event_circular_query @plugin_manager_promise_callback def resolve_total(root: models.Order, info, manager): + @allow_writer_in_context(info.context) def _resolve_total(lines): database_connection_name = get_database_connection_name(info.context) return calculations.order_total( @@ -1719,6 +1733,7 @@ def _resolve_total(lines): @traced_resolver @prevent_sync_event_circular_query def resolve_undiscounted_total(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_undiscounted_total(lines_and_manager): lines, manager = lines_and_manager database_connection_name = get_database_connection_name(info.context) @@ -1920,21 +1935,26 @@ def resolve_status_display(root: models.Order, _info): def resolve_can_finalize(root: models.Order, info): if root.status == OrderStatus.DRAFT: - def _validate_draft_order(manager): + @allow_writer_in_context(info.context) + def _validate_draft_order(data): + lines, manager = data country = get_order_country(root) database_connection_name = get_database_connection_name(info.context) try: validate_draft_order( - root, - country, - manager, + order=root, + lines=lines, + country=country, + manager=manager, database_connection_name=database_connection_name, ) except ValidationError: return False return True - return get_plugin_manager_promise(info.context).then(_validate_draft_order) + lines = OrderLinesByOrderIdLoader(info.context).load(root.id) + manager = get_plugin_manager_promise(info.context) + return Promise.all([lines, manager]).then(_validate_draft_order) return True @staticmethod @@ -2007,14 +2027,21 @@ def wrap_shipping_method_with_channel_context(data): ).load((shipping_method.id, channel.slug)) ) - def calculate_price( - listing: Optional[ShippingMethodChannelListing], - ) -> Optional[ShippingMethodData]: + tax_class = None + if shipping_method.tax_class_id: + tax_class = TaxClassByIdLoader(info.context).load( + shipping_method.tax_class_id + ) + + def calculate_price(data) -> Optional[ShippingMethodData]: + listing, tax_class = data if not listing: return None - return convert_to_shipping_method_data(shipping_method, listing) + return convert_to_shipping_method_data( + shipping_method, listing, tax_class + ) - return listing.then(calculate_price) + return Promise.all([listing, tax_class]).then(calculate_price) shipping_method = ShippingMethodByIdLoader(info.context).load( int(root.shipping_method_id) @@ -2045,6 +2072,7 @@ def with_channel(data): channel, manager = data database_connection_name = get_database_connection_name(info.context) + @allow_writer_in_context(info.context) def with_listings(channel_listings): return get_valid_shipping_methods_for_order( root, @@ -2140,21 +2168,26 @@ def resolve_original(root: models.Order, _info): def resolve_errors(root: models.Order, info): if root.status == OrderStatus.DRAFT: - def _validate_order(manager): + @allow_writer_in_context(info.context) + def _validate_order(data): + lines, manager = data country = get_order_country(root) database_connection_name = get_database_connection_name(info.context) try: validate_draft_order( - root, - country, - manager, + order=root, + lines=lines, + country=country, + manager=manager, database_connection_name=database_connection_name, ) except ValidationError as e: return validation_error_to_error_type(e, OrderError) return [] - return get_plugin_manager_promise(info.context).then(_validate_order) + lines = OrderLinesByOrderIdLoader(info.context).load(root.id) + manager = get_plugin_manager_promise(info.context) + return Promise.all([lines, manager]).then(_validate_order) return [] diff --git a/saleor/graphql/order/utils.py b/saleor/graphql/order/utils.py index 53197fc0b85..0113fc487db 100644 --- a/saleor/graphql/order/utils.py +++ b/saleor/graphql/order/utils.py @@ -12,7 +12,11 @@ from ...discount.models import NotApplicable from ...discount.utils import validate_voucher_in_order from ...order.error_codes import OrderErrorCode -from ...order.utils import get_valid_shipping_methods_for_order +from ...order.utils import ( + get_total_quantity, + get_valid_shipping_methods_for_order, + is_shipping_required, +) from ...plugins.manager import PluginsManager from ...product.models import Product, ProductChannelListing, ProductVariant from ...shipping.interface import ShippingMethodData @@ -22,7 +26,7 @@ if TYPE_CHECKING: from ...channel.models import Channel - from ...order.models import Order + from ...order.models import Order, OrderLine from dataclasses import dataclass @@ -39,8 +43,9 @@ class OrderLineData: rules_info: Optional[Iterable[VariantPromotionRuleInfo]] = None -def validate_total_quantity(order: "Order", errors: T_ERRORS): - if order.get_total_quantity() == 0: +def validate_total_quantity(lines: Iterable["OrderLine"], errors: T_ERRORS): + total_quantity = get_total_quantity(lines) + if total_quantity == 0: errors["lines"].append( ValidationError( "Could not create order without any products.", @@ -79,6 +84,7 @@ def get_shipping_method_availability_error( def validate_shipping_method( order: "Order", + channel: "Channel", errors: T_ERRORS, manager: "PluginsManager", database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, @@ -103,7 +109,7 @@ def validate_shipping_method( code=OrderErrorCode.SHIPPING_METHOD_NOT_APPLICABLE.value, ) else: - listing = order.channel.shipping_method_listings.filter( + listing = channel.shipping_method_listings.filter( shipping_method=order.shipping_method ).last() if not listing: @@ -145,11 +151,13 @@ def validate_shipping_address(order: "Order", errors: T_ERRORS): def validate_order_lines( order: "Order", + lines: Iterable["OrderLine"], + channel: "Channel", country: str, errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - for line in order.lines.all(): + for line in lines: if line.variant is None: errors["lines"].append( ValidationError( @@ -162,7 +170,7 @@ def validate_order_lines( check_stock_and_preorder_quantity( line.variant, country, - order.channel.slug, + channel.slug, line.quantity, order_line=line, database_connection_name=database_connection_name, @@ -174,15 +182,16 @@ def validate_order_lines( def validate_variants_is_available( - order: "Order", + channel: "Channel", + lines: Iterable["OrderLine"], errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - variants_ids = {line.variant_id for line in order.lines.all()} + variants_ids = {line.variant_id for line in lines} try: validate_variants_available_in_channel( variants_ids, - order.channel_id, + channel.id, OrderErrorCode.NOT_AVAILABLE_IN_CHANNEL.value, database_connection_name=database_connection_name, ) @@ -191,15 +200,16 @@ def validate_variants_is_available( def validate_product_is_published( - order: "Order", + channel: "Channel", + lines: Iterable["OrderLine"], errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - variant_ids = [line.variant_id for line in order.lines.all()] + variant_ids = [line.variant_id for line in lines] unpublished_product = ( Product.objects.using(database_connection_name) .filter(variants__id__in=variant_ids) - .not_published(order.channel.slug) + .not_published(channel.slug) ) if unpublished_product.exists(): errors["lines"].append( @@ -272,18 +282,19 @@ def validate_variant_channel_listings( def validate_product_is_available_for_purchase( - order: "Order", + channel: "Channel", + lines: Iterable["OrderLine"], errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): invalid_lines = [] - for line in order.lines.all(): + for line in lines: variant = line.variant if not variant: continue product_channel_listing = ( ProductChannelListing.objects.using(database_connection_name) - .filter(channel_id=order.channel_id, product_id=variant.product_id) + .filter(channel_id=channel.id, product_id=variant.product_id) .first() ) if not ( @@ -311,10 +322,12 @@ def validate_channel_is_active(channel: "Channel", errors: T_ERRORS): ) -def _validate_voucher(order: "Order", errors: T_ERRORS): - if order.channel.include_draft_order_in_voucher_usage: +def _validate_voucher( + order: "Order", lines: Iterable["OrderLine"], channel: "Channel", errors: T_ERRORS +): + if channel.include_draft_order_in_voucher_usage: try: - validate_voucher_in_order(order) + validate_voucher_in_order(order, lines, channel) except NotApplicable as e: errors["voucher"].append( ValidationError( @@ -326,6 +339,7 @@ def _validate_voucher(order: "Order", errors: T_ERRORS): def validate_draft_order( order: "Order", + lines: Iterable["OrderLine"], country: str, manager: "PluginsManager", database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, @@ -341,18 +355,26 @@ def validate_draft_order( Returns a list of errors if any were found. """ + channel = order.channel + errors: T_ERRORS = defaultdict(list) validate_billing_address(order, errors) - if order.is_shipping_required(): + if is_shipping_required(lines): validate_shipping_address(order, errors) - validate_shipping_method(order, errors, manager, database_connection_name) - validate_total_quantity(order, errors) - validate_order_lines(order, country, errors, database_connection_name) - validate_channel_is_active(order.channel, errors) - validate_product_is_published(order, errors, database_connection_name) - validate_product_is_available_for_purchase(order, errors, database_connection_name) - validate_variants_is_available(order, errors, database_connection_name) - _validate_voucher(order, errors) + validate_shipping_method( + order, channel, errors, manager, database_connection_name + ) + validate_total_quantity(lines, errors) + validate_order_lines( + order, lines, channel, country, errors, database_connection_name + ) + validate_channel_is_active(channel, errors) + validate_product_is_published(channel, lines, errors, database_connection_name) + validate_product_is_available_for_purchase( + channel, lines, errors, database_connection_name + ) + validate_variants_is_available(channel, lines, errors, database_connection_name) + _validate_voucher(order, lines, channel, errors) if errors: raise ValidationError(errors) diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py index e31c3c699d1..47e81c59662 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py @@ -70,7 +70,7 @@ def _perform_mutation(cls, root, info, id, channel, data=None): manager = get_plugin_manager_promise(info.context).get() is_active = manager.is_event_active_for_any_plugin( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel.slug ) if not is_active: diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py index 535b13728a3..6a52443cfee 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py @@ -82,6 +82,7 @@ def _perform_mutation( payment_flow_to_support=payment_flow_to_support, ), PaymentMethodInitializeTokenizationErrorCode, + channel, ) return cls( diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py index e2164989a08..d5dbf3543e6 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py @@ -66,6 +66,7 @@ def _perform_mutation(cls, root, info, id, channel, data=None): id=id, user=user, channel=channel, data=data ), PaymentMethodProcessTokenizationErrorCode, + channel, ) return cls( result=response.result, data=response.data, errors=errors, id=response.id diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py index 5a2743c733a..f2eb1ac366d 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py @@ -72,7 +72,7 @@ def _perform_mutation(cls, root, info, id, channel): manager = get_plugin_manager_promise(info.context).get() is_active = manager.is_event_active_for_any_plugin( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel.slug ) if not is_active: raise ValidationError( diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/utils.py b/saleor/graphql/payment/mutations/stored_payment_methods/utils.py index a4670de05b7..edf19270ee6 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/utils.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/utils.py @@ -1,5 +1,6 @@ from django.core.exceptions import ValidationError +from .....channel.models import Channel from .....payment.interface import ( PaymentMethodTokenizationBaseRequestData, PaymentMethodTokenizationResponseData, @@ -14,9 +15,12 @@ def handle_payment_method_action( manager_func_name: str, request_data: PaymentMethodTokenizationBaseRequestData, error_type_class, + channel: "Channel", ) -> tuple[PaymentMethodTokenizationResponseData, list[dict]]: manager = get_plugin_manager_promise(info.context).get() - is_active = manager.is_event_active_for_any_plugin(manager_func_name) + is_active = manager.is_event_active_for_any_plugin( + manager_func_name, channel_slug=channel.slug + ) if not is_active: raise ValidationError( diff --git a/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py b/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py index 2b1c703856f..742172ea1b0 100644 --- a/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py +++ b/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py @@ -77,7 +77,7 @@ def test_payment_gateway_initialize_tokenization( assert response_data["data"] == expected_output_data mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_gateway_initialize_tokenization.assert_called_once_with( request_data=PaymentGatewayInitializeTokenizationRequestData( @@ -168,7 +168,7 @@ def test_payment_gateway_initialize_tokenization_not_app_or_plugin_subscribed_to ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel_USD.slug ) assert not mocked_payment_gateway_initialize_tokenization.called @@ -248,7 +248,7 @@ def test_payment_gateway_initialize_tokenization_failure_from_app( assert error["message"] == error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_gateway_initialize_tokenization.assert_called_once_with( request_data=PaymentGatewayInitializeTokenizationRequestData( diff --git a/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py b/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py index f514e352842..a2c647c680d 100644 --- a/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py +++ b/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py @@ -87,7 +87,7 @@ def test_payment_method_initialize_tokenization( assert response_data["data"] == expected_output_data assert response_data["id"] == expected_payment_method_id mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_initialize_tokenization" + "payment_method_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_initialize_tokenization.assert_called_once_with( request_data=PaymentMethodInitializeTokenizationRequestData( @@ -191,7 +191,7 @@ def test_payment_method_initialize_tokenization_not_app_or_plugin_subscribed_to_ ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_initialize_tokenization" + "payment_method_initialize_tokenization", channel_slug=channel_USD.slug ) assert not mocked_payment_method_initialize_tokenization.called @@ -275,7 +275,7 @@ def test_payment_method_initialize_tokenization_failure_from_app( assert error["message"] == error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_initialize_tokenization" + "payment_method_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_initialize_tokenization.assert_called_once_with( request_data=PaymentMethodInitializeTokenizationRequestData( diff --git a/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py b/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py index a54566fd2f8..02c429728be 100644 --- a/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py +++ b/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py @@ -83,7 +83,7 @@ def test_payment_method_process_tokenization( assert response_data["data"] == expected_output_data assert response_data["id"] == expected_payment_method_id mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_process_tokenization" + "payment_method_process_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_process_tokenization.assert_called_once_with( request_data=PaymentMethodProcessTokenizationRequestData( @@ -177,7 +177,7 @@ def test_payment_method_process_tokenization_not_app_or_plugin_subscribed_to_eve ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_process_tokenization" + "payment_method_process_tokenization", channel_slug=channel_USD.slug ) assert not mocked_payment_method_process_tokenization.called @@ -256,7 +256,7 @@ def test_payment_method_process_tokenization_failure_from_app( assert error["message"] == error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_process_tokenization" + "payment_method_process_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_process_tokenization.assert_called_once_with( request_data=PaymentMethodProcessTokenizationRequestData( diff --git a/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py b/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py index a1c3aaa91b3..718391f3a13 100644 --- a/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py +++ b/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py @@ -56,14 +56,14 @@ def test_stored_payment_method_request_delete( ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel_USD.slug ) mocked_stored_payment_method_request_delete.assert_called_once_with( request_delete_data=StoredPaymentMethodRequestDeleteData( payment_method_id=expected_id, user=user_api_client.user, channel=channel_USD, - ) + ), ) @@ -104,7 +104,7 @@ def test_stored_payment_method_request_delete_app_returned_failure_event( assert error["message"] == expected_error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel_USD.slug ) mocked_stored_payment_method_request_delete.assert_called_once_with( request_delete_data=StoredPaymentMethodRequestDeleteData( @@ -212,7 +212,7 @@ def test_stored_payment_method_request_delete_not_app_or_plugin_subscribed_to_ev ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel_USD.slug ) assert not mocked_stored_payment_method_request_delete.called diff --git a/saleor/graphql/plugins/resolvers.py b/saleor/graphql/plugins/resolvers.py index d96ded2c721..de72daf219c 100644 --- a/saleor/graphql/plugins/resolvers.py +++ b/saleor/graphql/plugins/resolvers.py @@ -43,7 +43,7 @@ def aggregate_plugins_configuration( plugins_per_channel: dict[str, list[BasePlugin]] = defaultdict(list) global_plugins: dict[str, BasePlugin] = {} - for plugin in manager.all_plugins: + for plugin in manager.get_all_plugins(): hide_private_configuration_fields(plugin.configuration, plugin.CONFIG_STRUCTURE) if plugin.HIDDEN is True: continue diff --git a/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py b/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py index c527059cccb..7a4824d644c 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py @@ -691,7 +691,12 @@ def save_variants(cls, variants_data_with_errors_list): models.ProductVariantChannelListing.objects.bulk_create(listings_to_create) models.ProductVariantChannelListing.objects.bulk_update( listings_to_update, - fields=["price_amount", "cost_price_amount", "preorder_quantity_threshold"], + fields=[ + "price_amount", + "discounted_price_amount", + "cost_price_amount", + "preorder_quantity_threshold", + ], ) warehouse_models.Stock.objects.filter(id__in=stocks_to_remove).delete() models.ProductVariantChannelListing.objects.filter( diff --git a/saleor/graphql/product/dataloaders/products.py b/saleor/graphql/product/dataloaders/products.py index d599332f5f1..6053cd93ead 100644 --- a/saleor/graphql/product/dataloaders/products.py +++ b/saleor/graphql/product/dataloaders/products.py @@ -4,6 +4,7 @@ from django.db.models import F +from ....core.db.connection import allow_writer_in_context from ....product import ProductMediaTypes from ....product.models import ( Category, @@ -19,7 +20,10 @@ VariantChannelListingPromotionRule, VariantMedia, ) -from ...core.dataloaders import BaseThumbnailBySizeAndFormatLoader, DataLoader +from ...core.dataloaders import ( + BaseThumbnailBySizeAndFormatLoader, + DataLoader, +) ProductIdAndChannelSlug = tuple[int, str] VariantIdAndChannelSlug = tuple[int, str] @@ -556,6 +560,7 @@ class ProductTypeByProductIdLoader(DataLoader): context_key = "producttype_by_product_id" def batch_load(self, keys): + @allow_writer_in_context(self.context) def with_products(products): product_ids = {p.id for p in products} product_types_map = ( diff --git a/saleor/graphql/product/mutations/digital_contents.py b/saleor/graphql/product/mutations/digital_contents.py index f4e14145797..b114de20329 100644 --- a/saleor/graphql/product/mutations/digital_contents.py +++ b/saleor/graphql/product/mutations/digital_contents.py @@ -1,6 +1,7 @@ import graphene from django.core.exceptions import ValidationError +from ....core.db.connection import allow_writer from ....core.exceptions import PermissionDenied from ....permission.enums import ProductPermissions from ....product import models @@ -172,6 +173,7 @@ class Meta: permissions = (ProductPermissions.MANAGE_PRODUCTS,) @classmethod + @allow_writer() def mutate( # type: ignore[override] cls, root, info: ResolveInfo, /, *, variant_id: str ): diff --git a/saleor/graphql/product/mutations/utils.py b/saleor/graphql/product/mutations/utils.py index ba9ac603080..9f1f15a443f 100644 --- a/saleor/graphql/product/mutations/utils.py +++ b/saleor/graphql/product/mutations/utils.py @@ -9,8 +9,7 @@ def clean_tax_code(cleaned_input: dict, manager: PluginsManager): This function provides backwards compatibility for the `taxCode` input field. If the `taxClass` is not provided but the `taxCode` is, try to find a tax class with given - tax code and assign it to the product type. If no matching tax class is found, - create one with the given tax code. + tax code and assign it to the product type. """ tax_code = cleaned_input.get("tax_code") if tax_code and "tax_class" not in cleaned_input: @@ -19,8 +18,4 @@ def clean_tax_code(cleaned_input: dict, manager: PluginsManager): | Q(metadata__contains={"avatax.code": tax_code}) | Q(metadata__contains={"vatlayer.code": tax_code}) ).first() - if not tax_class: - tax_class = TaxClass.objects.create(name=tax_code) - manager.assign_tax_code_to_object_meta(tax_class, tax_code) - tax_class.save(update_fields=["metadata"]) cleaned_input["tax_class"] = tax_class diff --git a/saleor/graphql/product/tests/benchmark/test_collection.py b/saleor/graphql/product/tests/benchmark/test_collection.py index de81b3ba626..752b1bf0a91 100644 --- a/saleor/graphql/product/tests/benchmark/test_collection.py +++ b/saleor/graphql/product/tests/benchmark/test_collection.py @@ -418,7 +418,7 @@ def test_collections_for_federation_query_count( ], } - with django_assert_num_queries(3): + with django_assert_num_queries(2): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -436,7 +436,7 @@ def test_collections_for_federation_query_count( ], } - with django_assert_num_queries(3): + with django_assert_num_queries(2): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 3 diff --git a/saleor/graphql/product/tests/benchmark/test_product.py b/saleor/graphql/product/tests/benchmark/test_product.py index 6b61e6c3de4..1e5c13b09de 100644 --- a/saleor/graphql/product/tests/benchmark/test_product.py +++ b/saleor/graphql/product/tests/benchmark/test_product.py @@ -743,7 +743,7 @@ def test_products_for_federation_query_count( ], } - with django_assert_num_queries(6): + with django_assert_num_queries(5): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -765,7 +765,7 @@ def test_products_for_federation_query_count( ], } - with django_assert_num_queries(6): + with django_assert_num_queries(5): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 2 diff --git a/saleor/graphql/product/tests/benchmark/test_variant.py b/saleor/graphql/product/tests/benchmark/test_variant.py index c904290478e..fa5f8d2775e 100644 --- a/saleor/graphql/product/tests/benchmark/test_variant.py +++ b/saleor/graphql/product/tests/benchmark/test_variant.py @@ -384,7 +384,7 @@ def test_products_variants_for_federation_query_count( ], } - with django_assert_num_queries(5): + with django_assert_num_queries(4): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -400,7 +400,7 @@ def test_products_variants_for_federation_query_count( ], } - with django_assert_num_queries(5): + with django_assert_num_queries(4): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 4 diff --git a/saleor/graphql/product/tests/deprecated/test_utils.py b/saleor/graphql/product/tests/deprecated/test_utils.py index a17f1d7b5ca..158f70681a2 100644 --- a/saleor/graphql/product/tests/deprecated/test_utils.py +++ b/saleor/graphql/product/tests/deprecated/test_utils.py @@ -82,6 +82,4 @@ def test_clean_tax_code_when_tax_class_does_not_exists(): clean_tax_code(data, manager) # then - assert TaxClass.objects.count() == 1 - tax_class = TaxClass.objects.first() - assert data["tax_class"] == tax_class + assert data["tax_class"] is None diff --git a/saleor/graphql/product/tests/mutations/test_product_create.py b/saleor/graphql/product/tests/mutations/test_product_create.py index 747862a9edf..c1884af71c8 100644 --- a/saleor/graphql/product/tests/mutations/test_product_create.py +++ b/saleor/graphql/product/tests/mutations/test_product_create.py @@ -119,7 +119,7 @@ def test_create_product( monkeypatch.setattr( PluginsManager, "get_tax_code_from_object_meta", - lambda self, x: TaxType(description="", code=product_tax_rate), + lambda self, x, channel_slug: TaxType(description="", code=product_tax_rate), ) # Default attribute defined in product_type fixture diff --git a/saleor/graphql/product/tests/mutations/test_product_update.py b/saleor/graphql/product/tests/mutations/test_product_update.py index dd0b477de68..7a348dd1b32 100644 --- a/saleor/graphql/product/tests/mutations/test_product_update.py +++ b/saleor/graphql/product/tests/mutations/test_product_update.py @@ -132,7 +132,7 @@ def test_update_product( monkeypatch.setattr( PluginsManager, "get_tax_code_from_object_meta", - lambda self, x: TaxType(description="", code=product_tax_rate), + lambda self, x, channel_slug: TaxType(description="", code=product_tax_rate), ) attribute_id = graphene.Node.to_global_id("Attribute", color_attribute.pk) diff --git a/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py b/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py index 70c5f60b006..9decc564a61 100644 --- a/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py +++ b/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py @@ -410,9 +410,7 @@ def test_product_variant_bulk_update_channel_listings_input( variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk) ProductChannelListing.objects.create(product=product, channel=channel_PLN) - existing_variant_listing = variant.channel_listings.exclude( - channel=channel_PLN - ).last() + existing_variant_listing = variant.channel_listings.get() assert variant.channel_listings.count() == 1 product_id = graphene.Node.to_global_id("Product", product.pk) @@ -452,8 +450,21 @@ def test_product_variant_bulk_update_channel_listings_input( ) get_graphql_content(response, ignore_errors=True) - existing_variant_listing.refresh_from_db() # then + existing_variant_listing.refresh_from_db() + assert ( + existing_variant_listing.price_amount == new_price_for_existing_variant_listing + ) + assert ( + existing_variant_listing.discounted_price_amount + == new_price_for_existing_variant_listing + ) + new_variant_listing = variant.channel_listings.get(channel=channel_PLN) + assert new_variant_listing.price_amount == not_existing_variant_listing_price + assert ( + new_variant_listing.discounted_price_amount + == not_existing_variant_listing_price + ) # only promotions with created channel will be marked as dirty second_promotion_rule.refresh_from_db() diff --git a/saleor/graphql/product/types/products.py b/saleor/graphql/product/types/products.py index 6322671a2ba..67319708e39 100644 --- a/saleor/graphql/product/types/products.py +++ b/saleor/graphql/product/types/products.py @@ -1052,7 +1052,9 @@ def resolve_description_json(root: ChannelContext[models.Product], _info): def resolve_tax_type(root: ChannelContext[models.Product], info): def with_tax_class(data): tax_class, manager = data - tax_data = manager.get_tax_code_from_object_meta(tax_class) + tax_data = manager.get_tax_code_from_object_meta( + tax_class, channel_slug=root.channel_slug + ) return TaxType(tax_code=tax_data.code, description=tax_data.description) if root.node.tax_class_id: diff --git a/saleor/graphql/schema.graphql b/saleor/graphql/schema.graphql index 28142631a1c..633ca308253 100644 --- a/saleor/graphql/schema.graphql +++ b/saleor/graphql/schema.graphql @@ -1282,7 +1282,7 @@ type Query { ): ExportFileCountableConnection """List of all tax rates available from tax gateway.""" - taxTypes: [TaxType!] @doc(category: "Taxes") + taxTypes: [TaxType!] @doc(category: "Taxes") @deprecated(reason: "This field will be removed in Saleor 4.0. Use `taxClasses` field instead.") """ Look up a checkout by id. @@ -3147,7 +3147,7 @@ type App implements Node & ObjectWithMetadata @doc(category: "Apps") { """Version number of the app.""" version: String - """JWT token used to authenticate by thridparty app.""" + """JWT token used to authenticate by third-party app.""" accessToken: String """ @@ -4204,6 +4204,13 @@ type ShippingMethodTranslation implements Node @doc(category: "Shipping") { Rich text format. For reference see https://editorjs.io/ """ description: JSONString + + """ + Represents the shipping method fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: ShippingMethodTranslatableContent } type LanguageDisplay { @@ -4997,6 +5004,44 @@ enum LanguageCodeEnum { ZU_ZA } +""" +Represents shipping method's original translatable fields and related translations. +""" +type ShippingMethodTranslatableContent implements Node @doc(category: "Shipping") { + """The ID of the shipping method translatable content.""" + id: ID! + + """ + The ID of the shipping method to translate. + + Added in Saleor 3.14. + """ + shippingMethodId: ID! + + """Shipping method name to translate.""" + name: String! + + """ + Shipping method description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString + + """Returns translated shipping method fields for the given language code.""" + translation( + """A language code to return the translation for shipping method.""" + languageCode: LanguageCodeEnum! + ): ShippingMethodTranslation + + """ + Shipping method are the methods you'll use to get customer's orders to them. They are directly exposed to the customers. + + Requires one of the following permissions: MANAGE_SHIPPING. + """ + shippingMethod: ShippingMethodType @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") +} + """Represents shipping method channel listing.""" type ShippingMethodChannelListing implements Node @doc(category: "Shipping") { """The ID of shipping method channel listing.""" @@ -5615,7 +5660,7 @@ type OrderSettings { deleteExpiredOrdersAfter: Day! """ - Determine if it is possible to place unpdaid order by calling `checkoutComplete` mutation. + Determine if it is possible to place unpaid order by calling `checkoutComplete` mutation. Added in Saleor 3.15. @@ -6602,7 +6647,7 @@ type AttributeValue implements Node @doc(category: "Attributes") { richText: JSONString """ - Represents the text of the attribute value, plain text without formating. + Represents the text of the attribute value, plain text without formatting. """ plainText: String @@ -6643,6 +6688,103 @@ type AttributeValueTranslation implements Node @doc(category: "Attributes") { """Translated plain text attribute value .""" plainText: String + + """ + Represents the attribute value fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: AttributeValueTranslatableContent +} + +""" +Represents attribute value's original translatable fields and related translations. +""" +type AttributeValueTranslatableContent implements Node @doc(category: "Attributes") { + """The ID of the attribute value translatable content.""" + id: ID! + + """ + The ID of the attribute value to translate. + + Added in Saleor 3.14. + """ + attributeValueId: ID! + + """Name of the attribute value to translate.""" + name: String! + + """ + Attribute value. + + Rich text format. For reference see https://editorjs.io/ + """ + richText: JSONString + + """Attribute plain text value.""" + plainText: String + + """Returns translated attribute value fields for the given language code.""" + translation( + """A language code to return the translation for attribute value.""" + languageCode: LanguageCodeEnum! + ): AttributeValueTranslation + + """Represents a value of an attribute.""" + attributeValue: AttributeValue @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + + """ + Associated attribute that can be translated. + + Added in Saleor 3.9. + """ + attribute: AttributeTranslatableContent +} + +""" +Represents attribute's original translatable fields and related translations. +""" +type AttributeTranslatableContent implements Node @doc(category: "Attributes") { + """The ID of the attribute translatable content.""" + id: ID! + + """ + The ID of the attribute to translate. + + Added in Saleor 3.14. + """ + attributeId: ID! + + """Name of the attribute to translate.""" + name: String! + + """Returns translated attribute fields for the given language code.""" + translation( + """A language code to return the translation for attribute.""" + languageCode: LanguageCodeEnum! + ): AttributeTranslation + + """Custom attribute of a product.""" + attribute: Attribute @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") +} + +"""Represents attribute translations.""" +type AttributeTranslation implements Node @doc(category: "Attributes") { + """The ID of the attribute translation.""" + id: ID! + + """Translation language.""" + language: LanguageDisplay! + + """Translated attribute name.""" + name: String! + + """ + Represents the attribute fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: AttributeTranslatableContent } type File { @@ -6682,18 +6824,6 @@ input AttributeValueFilterInput @doc(category: "Attributes") { slugs: [String!] } -"""Represents attribute translations.""" -type AttributeTranslation implements Node @doc(category: "Attributes") { - """The ID of the attribute translation.""" - id: ID! - - """Translation language.""" - language: LanguageDisplay! - - """Translated attribute name.""" - name: String! -} - type ProductTypeCountableConnection @doc(category: "Products") { """Pagination data for this connection.""" pageInfo: PageInfo! @@ -7468,6 +7598,60 @@ type CategoryTranslation implements Node @doc(category: "Products") { Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") + + """ + Represents the category fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: CategoryTranslatableContent +} + +""" +Represents category original translatable fields and related translations. +""" +type CategoryTranslatableContent implements Node @doc(category: "Products") { + """The ID of the category translatable content.""" + id: ID! + + """ + The ID of the category to translate. + + Added in Saleor 3.14. + """ + categoryId: ID! + + """SEO title to translate.""" + seoTitle: String + + """SEO description to translate.""" + seoDescription: String + + """Name of the category translatable content.""" + name: String! + + """ + Category description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString + + """ + Description of the category. + + Rich text format. For reference see https://editorjs.io/ + """ + descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") + + """Returns translated category fields for the given language code.""" + translation( + """A language code to return the translation for category.""" + languageCode: LanguageCodeEnum! + ): CategoryTranslation + + """Represents a single category of products.""" + category: Category @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") } """Represents a version of a product such as different size or color.""" @@ -7929,6 +8113,43 @@ type ProductVariantTranslation implements Node @doc(category: "Products") { """Translated product variant name.""" name: String! + + """ + Represents the product variant fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: ProductVariantTranslatableContent +} + +""" +Represents product variant's original translatable fields and related translations. +""" +type ProductVariantTranslatableContent implements Node @doc(category: "Products") { + """The ID of the product variant translatable content.""" + id: ID! + + """ + The ID of the product variant to translate. + + Added in Saleor 3.14. + """ + productVariantId: ID! + + """Name of the product variant to translate.""" + name: String! + + """Returns translated product variant fields for the given language code.""" + translation( + """A language code to return the translation for product variant.""" + languageCode: LanguageCodeEnum! + ): ProductVariantTranslation + + """Represents a version of a product such as different size or color.""" + productVariant: ProductVariant @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + + """List of product variant attribute values that can be translated.""" + attributeValues: [AttributeValueTranslatableContent!]! } """Represents digital content associated with a product variant.""" @@ -8363,217 +8584,13 @@ type CollectionTranslation implements Node @doc(category: "Products") { Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") -} - -"""Represents collection channel listing.""" -type CollectionChannelListing implements Node @doc(category: "Products") { - """The ID of the collection channel listing.""" - id: ID! - publicationDate: Date @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `publishedAt` field to fetch the publication date.") - - """ - The collection publication date. - - Added in Saleor 3.3. - """ - publishedAt: DateTime - - """Indicates if the collection is published in the channel.""" - isPublished: Boolean! - - """The channel to which the collection belongs.""" - channel: Channel! -} - -"""Represents product translations.""" -type ProductTranslation implements Node @doc(category: "Products") { - """The ID of the product translation.""" - id: ID! - - """Translation language.""" - language: LanguageDisplay! - - """Translated SEO title.""" - seoTitle: String - - """Translated SEO description.""" - seoDescription: String - - """Translated product name.""" - name: String - - """ - Translated description of the product. - - Rich text format. For reference see https://editorjs.io/ - """ - description: JSONString - - """ - Translated description of the product. - - Rich text format. For reference see https://editorjs.io/ - """ - descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") -} - -type WarehouseCountableConnection @doc(category: "Products") { - """Pagination data for this connection.""" - pageInfo: PageInfo! - edges: [WarehouseCountableEdge!]! - - """A total count of items in the collection.""" - totalCount: Int -} - -type WarehouseCountableEdge @doc(category: "Products") { - """The item at the end of the edge.""" - node: Warehouse! - - """A cursor for use in pagination.""" - cursor: String! -} - -input WarehouseFilterInput @doc(category: "Products") { - clickAndCollectOption: WarehouseClickAndCollectOptionEnum - metadata: [MetadataFilter!] - search: String - ids: [ID!] - isPrivate: Boolean - channels: [ID!] - slugs: [String!] -} - -input WarehouseSortingInput @doc(category: "Products") { - """Specifies the direction in which to sort warehouses.""" - direction: OrderDirection! - - """Sort warehouses by the selected field.""" - field: WarehouseSortField! -} - -enum WarehouseSortField @doc(category: "Products") { - """Sort warehouses by name.""" - NAME -} - -type TranslatableItemConnection { - """Pagination data for this connection.""" - pageInfo: PageInfo! - edges: [TranslatableItemEdge!]! - - """A total count of items in the collection.""" - totalCount: Int -} - -type TranslatableItemEdge { - """The item at the end of the edge.""" - node: TranslatableItem! - - """A cursor for use in pagination.""" - cursor: String! -} - -union TranslatableItem = ProductTranslatableContent | CollectionTranslatableContent | CategoryTranslatableContent | AttributeTranslatableContent | AttributeValueTranslatableContent | ProductVariantTranslatableContent | PageTranslatableContent | ShippingMethodTranslatableContent | VoucherTranslatableContent | MenuItemTranslatableContent | PromotionTranslatableContent | PromotionRuleTranslatableContent | SaleTranslatableContent - -""" -Represents product's original translatable fields and related translations. -""" -type ProductTranslatableContent implements Node @doc(category: "Products") { - """The ID of the product translatable content.""" - id: ID! - - """SEO title to translate.""" - seoTitle: String - - """SEO description to translate.""" - seoDescription: String - - """Product's name to translate.""" - name: String! - - """ - Product's description to translate. - - Rich text format. For reference see https://editorjs.io/ - """ - description: JSONString - - """ - Description of the product. - - Rich text format. For reference see https://editorjs.io/ - """ - descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") - - """Returns translated product fields for the given language code.""" - translation( - """A language code to return the translation for product.""" - languageCode: LanguageCodeEnum! - ): ProductTranslation - - """Represents an individual item for sale in the storefront.""" - product: Product @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") - - """List of product attribute values that can be translated.""" - attributeValues: [AttributeValueTranslatableContent!]! -} - -""" -Represents attribute value's original translatable fields and related translations. -""" -type AttributeValueTranslatableContent implements Node @doc(category: "Attributes") { - """The ID of the attribute value translatable content.""" - id: ID! - - """Name of the attribute value to translate.""" - name: String! - - """ - Attribute value. - - Rich text format. For reference see https://editorjs.io/ - """ - richText: JSONString - - """Attribute plain text value.""" - plainText: String - - """Returns translated attribute value fields for the given language code.""" - translation( - """A language code to return the translation for attribute value.""" - languageCode: LanguageCodeEnum! - ): AttributeValueTranslation - - """Represents a value of an attribute.""" - attributeValue: AttributeValue @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") - - """ - Associated attribute that can be translated. - - Added in Saleor 3.9. - """ - attribute: AttributeTranslatableContent -} - -""" -Represents attribute's original translatable fields and related translations. -""" -type AttributeTranslatableContent implements Node @doc(category: "Attributes") { - """The ID of the attribute.""" - id: ID! - - """Name of the attribute to translate.""" - name: String! - - """Returns translated attribute fields for the given language code.""" - translation( - """A language code to return the translation for attribute.""" - languageCode: LanguageCodeEnum! - ): AttributeTranslation - - """Custom attribute of a product.""" - attribute: Attribute @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + + """ + Represents the collection fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: CollectionTranslatableContent } """ @@ -8583,6 +8600,13 @@ type CollectionTranslatableContent implements Node @doc(category: "Products") { """The ID of the collection translatable content.""" id: ID! + """ + The ID of the collection to translate. + + Added in Saleor 3.14. + """ + collectionId: ID! + """SEO title to translate.""" seoTitle: String @@ -8616,69 +8640,174 @@ type CollectionTranslatableContent implements Node @doc(category: "Products") { collection: Collection @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") } +"""Represents collection channel listing.""" +type CollectionChannelListing implements Node @doc(category: "Products") { + """The ID of the collection channel listing.""" + id: ID! + publicationDate: Date @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `publishedAt` field to fetch the publication date.") + + """ + The collection publication date. + + Added in Saleor 3.3. + """ + publishedAt: DateTime + + """Indicates if the collection is published in the channel.""" + isPublished: Boolean! + + """The channel to which the collection belongs.""" + channel: Channel! +} + +"""Represents product translations.""" +type ProductTranslation implements Node @doc(category: "Products") { + """The ID of the product translation.""" + id: ID! + + """Translation language.""" + language: LanguageDisplay! + + """Translated SEO title.""" + seoTitle: String + + """Translated SEO description.""" + seoDescription: String + + """Translated product name.""" + name: String + + """ + Translated description of the product. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString + + """ + Translated description of the product. + + Rich text format. For reference see https://editorjs.io/ + """ + descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") + + """ + Represents the product fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: ProductTranslatableContent +} + """ -Represents category original translatable fields and related translations. +Represents product's original translatable fields and related translations. """ -type CategoryTranslatableContent implements Node @doc(category: "Products") { - """The ID of the category translatable content.""" +type ProductTranslatableContent implements Node @doc(category: "Products") { + """The ID of the product translatable content.""" id: ID! + """ + The ID of the product to translate. + + Added in Saleor 3.14. + """ + productId: ID! + """SEO title to translate.""" seoTitle: String """SEO description to translate.""" seoDescription: String - """Name of the category translatable content.""" + """Product's name to translate.""" name: String! """ - Category description to translate. + Product's description to translate. Rich text format. For reference see https://editorjs.io/ """ description: JSONString """ - Description of the category. + Description of the product. Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") - """Returns translated category fields for the given language code.""" + """Returns translated product fields for the given language code.""" translation( - """A language code to return the translation for category.""" + """A language code to return the translation for product.""" languageCode: LanguageCodeEnum! - ): CategoryTranslation + ): ProductTranslation - """Represents a single category of products.""" - category: Category @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + """Represents an individual item for sale in the storefront.""" + product: Product @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + + """List of product attribute values that can be translated.""" + attributeValues: [AttributeValueTranslatableContent!]! } -""" -Represents product variant's original translatable fields and related translations. -""" -type ProductVariantTranslatableContent implements Node @doc(category: "Products") { - """The ID of the product variant translatable content.""" - id: ID! +type WarehouseCountableConnection @doc(category: "Products") { + """Pagination data for this connection.""" + pageInfo: PageInfo! + edges: [WarehouseCountableEdge!]! - """Name of the product variant to translate.""" - name: String! + """A total count of items in the collection.""" + totalCount: Int +} - """Returns translated product variant fields for the given language code.""" - translation( - """A language code to return the translation for product variant.""" - languageCode: LanguageCodeEnum! - ): ProductVariantTranslation +type WarehouseCountableEdge @doc(category: "Products") { + """The item at the end of the edge.""" + node: Warehouse! - """Represents a version of a product such as different size or color.""" - productVariant: ProductVariant @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + """A cursor for use in pagination.""" + cursor: String! +} - """List of product variant attribute values that can be translated.""" - attributeValues: [AttributeValueTranslatableContent!]! +input WarehouseFilterInput @doc(category: "Products") { + clickAndCollectOption: WarehouseClickAndCollectOptionEnum + metadata: [MetadataFilter!] + search: String + ids: [ID!] + isPrivate: Boolean + channels: [ID!] + slugs: [String!] +} + +input WarehouseSortingInput @doc(category: "Products") { + """Specifies the direction in which to sort warehouses.""" + direction: OrderDirection! + + """Sort warehouses by the selected field.""" + field: WarehouseSortField! +} + +enum WarehouseSortField @doc(category: "Products") { + """Sort warehouses by name.""" + NAME +} + +type TranslatableItemConnection { + """Pagination data for this connection.""" + pageInfo: PageInfo! + edges: [TranslatableItemEdge!]! + + """A total count of items in the collection.""" + totalCount: Int +} + +type TranslatableItemEdge { + """The item at the end of the edge.""" + node: TranslatableItem! + + """A cursor for use in pagination.""" + cursor: String! } +union TranslatableItem = ProductTranslatableContent | CollectionTranslatableContent | CategoryTranslatableContent | AttributeTranslatableContent | AttributeValueTranslatableContent | ProductVariantTranslatableContent | PageTranslatableContent | ShippingMethodTranslatableContent | VoucherTranslatableContent | MenuItemTranslatableContent | PromotionTranslatableContent | PromotionRuleTranslatableContent | SaleTranslatableContent + """ Represents page's original translatable fields and related translations. """ @@ -8686,6 +8815,13 @@ type PageTranslatableContent implements Node @doc(category: "Pages") { """The ID of the page translatable content.""" id: ID! + """ + The ID of the page to translate. + + Added in Saleor 3.14. + """ + pageId: ID! + """SEO title to translate.""" seoTitle: String @@ -8754,6 +8890,13 @@ type PageTranslation implements Node @doc(category: "Pages") { Rich text format. For reference see https://editorjs.io/ """ contentJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `content` field instead.") + + """ + Represents the page fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: PageTranslatableContent } """ @@ -8943,42 +9086,18 @@ type PageType implements Node & ObjectWithMetadata @doc(category: "Pages") { } """ -Represents shipping method's original translatable fields and related translations. +Represents voucher's original translatable fields and related translations. """ -type ShippingMethodTranslatableContent implements Node @doc(category: "Shipping") { - """The ID of the shipping method translatable content.""" +type VoucherTranslatableContent implements Node @doc(category: "Discounts") { + """The ID of the voucher translatable content.""" id: ID! - """Shipping method name to translate.""" - name: String! - - """ - Shipping method description to translate. - - Rich text format. For reference see https://editorjs.io/ - """ - description: JSONString - - """Returns translated shipping method fields for the given language code.""" - translation( - """A language code to return the translation for shipping method.""" - languageCode: LanguageCodeEnum! - ): ShippingMethodTranslation - """ - Shipping method are the methods you'll use to get customer's orders to them. They are directly exposed to the customers. + The ID of the voucher to translate. - Requires one of the following permissions: MANAGE_SHIPPING. + Added in Saleor 3.14. """ - shippingMethod: ShippingMethodType @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") -} - -""" -Represents voucher's original translatable fields and related translations. -""" -type VoucherTranslatableContent implements Node @doc(category: "Discounts") { - """The ID of the voucher translatable content.""" - id: ID! + voucherId: ID! """Voucher name to translate.""" name: String @@ -9007,6 +9126,13 @@ type VoucherTranslation implements Node @doc(category: "Discounts") { """Translated voucher name.""" name: String + + """ + Represents the voucher fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: VoucherTranslatableContent } """ @@ -9354,6 +9480,13 @@ type MenuItemTranslatableContent implements Node @doc(category: "Menu") { """The ID of the menu item translatable content.""" id: ID! + """ + The ID of the menu item to translate. + + Added in Saleor 3.14. + """ + menuItemId: ID! + """Name of the menu item to translate.""" name: String! @@ -9379,6 +9512,13 @@ type MenuItemTranslation implements Node @doc(category: "Menu") { """Translated menu item name.""" name: String! + + """ + Represents the menu item fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: MenuItemTranslatableContent } """ @@ -9528,6 +9668,9 @@ type PromotionTranslatableContent implements Node @doc(category: "Discounts") { """ID of the promotion translatable content.""" id: ID! + """ID of the promotion to translate.""" + promotionId: ID! + """Name of the promotion.""" name: String! @@ -9566,6 +9709,13 @@ type PromotionTranslation implements Node @doc(category: "Discounts") { Rich text format. For reference see https://editorjs.io/ """ description: JSONString + + """ + Represents the promotion fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: PromotionTranslatableContent } """ @@ -9577,6 +9727,13 @@ type PromotionRuleTranslatableContent implements Node @doc(category: "Discounts" """ID of the promotion rule translatable content.""" id: ID! + """ + ID of the promotion rule to translate. + + Added in Saleor 3.14. + """ + promotionRuleId: ID! + """Name of the promotion rule.""" name: String @@ -9615,6 +9772,13 @@ type PromotionRuleTranslation implements Node @doc(category: "Discounts") { Rich text format. For reference see https://editorjs.io/ """ description: JSONString + + """ + Represents the promotion rule fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: PromotionRuleTranslatableContent } """ @@ -9626,6 +9790,13 @@ type SaleTranslatableContent implements Node @doc(category: "Discounts") { """The ID of the sale translatable content.""" id: ID! + """ + The ID of the sale to translate. + + Added in Saleor 3.14. + """ + saleId: ID! + """Name of the sale to translate.""" name: String! @@ -9657,6 +9828,13 @@ type SaleTranslation implements Node @doc(category: "Discounts") { """Translated name of sale.""" name: String + + """ + Represents the sale fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: SaleTranslatableContent } """ @@ -10188,7 +10366,7 @@ type Shop implements ObjectWithMetadata { enableAccountConfirmationByEmail: Boolean """ - Determines if user can login without confirmation when `enableAccountConfrimation` is enabled. + Determines if user can login without confirmation when `enableAccountConfirmation` is enabled. Added in Saleor 3.15. @@ -15015,7 +15193,7 @@ type AddressValidationData @doc(category: "Users") { Many fields in the JSON refer to address fields by one-letter abbreviations. These are defined as follows: - `N`: Name - - `O`: Organisation + - `O`: Organization - `A`: Street Address Line(s) - `D`: Dependent locality (may be an inner-city district or a suburb) - `C`: City or Locality @@ -15033,7 +15211,7 @@ type AddressValidationData @doc(category: "Users") { Many fields in the JSON refer to address fields by one-letter abbreviations. These are defined as follows: - `N`: Name - - `O`: Organisation + - `O`: Organization - `A`: Street Address Line(s) - `D`: Dependent locality (may be an inner-city district or a suburb) - `C`: City or Locality @@ -27552,7 +27730,7 @@ input ExternalNotificationTriggerInput { ids: [ID!]! """ - Additional payload that will be merged with the one based on the bussines object ID. + Additional payload that will be merged with the one based on the business object ID. """ extraPayload: JSONString @@ -28864,7 +29042,7 @@ type CheckoutError @doc(category: "Checkout") { """The error code.""" code: CheckoutErrorCode! - """List of varint IDs which causes the error.""" + """List of variant IDs which causes the error.""" variants: [ID!] """List of line Ids which cause the error.""" @@ -29567,7 +29745,7 @@ input OrderSettingsInput @doc(category: "Orders") { automaticallyConfirmAllNewOrders: Boolean """ - When enabled, all non-shippable gift card orders will be fulfilled automatically. By defualt set to True. + When enabled, all non-shippable gift card orders will be fulfilled automatically. By default set to True. """ automaticallyFulfillNonShippableGiftCard: Boolean @@ -29926,7 +30104,7 @@ input AttributeValueCreateInput @doc(category: "Attributes") { richText: JSONString """ - Represents the text of the attribute value, plain text without formating. + Represents the text of the attribute value, plain text without formatting. DEPRECATED: this field will be removed in Saleor 4.0.The plain text attribute hasn't got predefined value, so can be specified only from instance that supports the given attribute. """ @@ -30055,7 +30233,7 @@ input AttributeValueUpdateInput @doc(category: "Attributes") { richText: JSONString """ - Represents the text of the attribute value, plain text without formating. + Represents the text of the attribute value, plain text without formatting. DEPRECATED: this field will be removed in Saleor 4.0.The plain text attribute hasn't got predefined value, so can be specified only from instance that supports the given attribute. """ @@ -31889,7 +32067,7 @@ type PermissionGroupError @doc(category: "Users") { """List of user IDs which causes the error.""" users: [ID!] - """List of chnnels IDs which causes the error.""" + """List of channels IDs which causes the error.""" channels: [ID!] } diff --git a/saleor/graphql/shop/types.py b/saleor/graphql/shop/types.py index 3a8cea914af..bf91ae960ff 100644 --- a/saleor/graphql/shop/types.py +++ b/saleor/graphql/shop/types.py @@ -324,7 +324,7 @@ class Shop(graphene.ObjectType): graphene.Boolean, description=( "Determines if user can login without confirmation when " - "`enableAccountConfrimation` is enabled." + ADDED_IN_315 + "`enableAccountConfirmation` is enabled." + ADDED_IN_315 ), permissions=[SitePermissions.MANAGE_SETTINGS], ) diff --git a/saleor/graphql/translations/tests/test_translations.py b/saleor/graphql/translations/tests/test_translations.py index 59bbce2f7d6..4195bcdffed 100644 --- a/saleor/graphql/translations/tests/test_translations.py +++ b/saleor/graphql/translations/tests/test_translations.py @@ -3206,8 +3206,10 @@ def test_translations_query_inline_fragment( __typename ...on ProductTranslatableContent{ id + productId name translation(languageCode: $languageCode){ + id name } } @@ -3223,7 +3225,9 @@ def test_translation_query_product( product_translation_fr, ): product_id = graphene.Node.to_global_id("Product", product.id) - + translation_id = graphene.Node.to_global_id( + "ProductTranslation", product_translation_fr.id + ) variables = { "id": product_id, "kind": TranslatableKinds.PRODUCT.name, @@ -3236,7 +3240,9 @@ def test_translation_query_product( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["productId"] == product_id assert data["name"] == product.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == product_translation_fr.name @@ -3248,8 +3254,10 @@ def test_translation_query_product( __typename ...on CollectionTranslatableContent{ id + collectionId name translation(languageCode: $languageCode){ + id name } } @@ -3268,6 +3276,9 @@ def test_translation_query_collection( channel_listing = published_collection.channel_listings.get() channel_listing.save() collection_id = graphene.Node.to_global_id("Collection", published_collection.id) + translation_id = graphene.Node.to_global_id( + "CollectionTranslation", collection_translation_fr.id + ) variables = { "id": collection_id, @@ -3281,7 +3292,9 @@ def test_translation_query_collection( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["collectionId"] == collection_id assert data["name"] == published_collection.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == collection_translation_fr.name @@ -3293,9 +3306,15 @@ def test_translation_query_collection( __typename ...on CategoryTranslatableContent{ id + categoryId name translation(languageCode: $languageCode){ + id name + translatableContent { + id + name + } } } } @@ -3307,12 +3326,15 @@ def test_translation_query_category( staff_api_client, category, category_translation_fr, permission_manage_translations ): category_id = graphene.Node.to_global_id("Category", category.id) - + translation_id = graphene.Node.to_global_id( + "CategoryTranslation", category_translation_fr.id + ) variables = { "id": category_id, "kind": TranslatableKinds.CATEGORY.name, "languageCode": LanguageCodeEnum.FR.name, } + response = staff_api_client.post_graphql( QUERY_TRANSLATION_CATEGORY, variables, @@ -3320,8 +3342,10 @@ def test_translation_query_category( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["categoryId"] == category_id assert data["name"] == category.name assert data["translation"]["name"] == category_translation_fr.name + assert data["translation"]["id"] == translation_id QUERY_TRANSLATION_ATTRIBUTE = """ @@ -3332,8 +3356,10 @@ def test_translation_query_category( __typename ...on AttributeTranslatableContent{ id + attributeId name translation(languageCode: $languageCode){ + id name } } @@ -3347,6 +3373,9 @@ def test_translation_query_attribute( ): attribute = translated_attribute.attribute attribute_id = graphene.Node.to_global_id("Attribute", attribute.id) + translation_id = graphene.Node.to_global_id( + "AttributeTranslation", translated_attribute.id + ) variables = { "id": attribute_id, @@ -3360,7 +3389,9 @@ def test_translation_query_attribute( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["attributeId"] == attribute_id assert data["name"] == attribute.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == translated_attribute.name @@ -3372,8 +3403,10 @@ def test_translation_query_attribute( __typename ...on AttributeValueTranslatableContent{ id + attributeValueId name translation(languageCode: $languageCode){ + id name } } @@ -3391,6 +3424,9 @@ def test_translation_query_attribute_value( attribute_value_id = graphene.Node.to_global_id( "AttributeValue", pink_attribute_value.id ) + translation_id = graphene.Node.to_global_id( + "AttributeValueTranslation", translated_attribute_value.id + ) variables = { "id": attribute_value_id, @@ -3404,7 +3440,9 @@ def test_translation_query_attribute_value( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["attributeValueId"] == attribute_value_id assert data["name"] == pink_attribute_value.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == translated_attribute_value.name @@ -3416,8 +3454,10 @@ def test_translation_query_attribute_value( __typename ...on ProductVariantTranslatableContent{ id + productVariantId name translation(languageCode: $languageCode){ + id name } } @@ -3434,6 +3474,9 @@ def test_translation_query_variant( variant_translation_fr, ): variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + translation_id = graphene.Node.to_global_id( + "ProductVariantTranslation", variant_translation_fr.id + ) variables = { "id": variant_id, "kind": TranslatableKinds.VARIANT.name, @@ -3446,7 +3489,9 @@ def test_translation_query_variant( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["productVariantId"] == variant_id assert data["name"] == variant.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == variant_translation_fr.name @@ -3458,8 +3503,10 @@ def test_translation_query_variant( __typename ...on PageTranslatableContent{ id + pageId title translation(languageCode: $languageCode){ + id title } } @@ -3487,6 +3534,9 @@ def test_translation_query_page( page.save() page_id = graphene.Node.to_global_id("Page", page.id) + translation_id = graphene.Node.to_global_id( + "PageTranslation", page_translation_fr.id + ) perms = list(Permission.objects.filter(codename__in=perm_codenames)) variables = { @@ -3499,7 +3549,9 @@ def test_translation_query_page( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["pageId"] == page_id assert data["title"] == page.title + assert data["translation"]["id"] == translation_id assert data["translation"]["title"] == page_translation_fr.title @@ -3511,9 +3563,11 @@ def test_translation_query_page( __typename ...on ShippingMethodTranslatableContent{ id + shippingMethodId name description translation(languageCode: $languageCode){ + id name } } @@ -3539,6 +3593,9 @@ def test_translation_query_shipping_method( shipping_method_id = graphene.Node.to_global_id( "ShippingMethodType", shipping_method.id ) + translation_id = graphene.Node.to_global_id( + "ShippingMethodTranslation", shipping_method_translation_fr.id + ) perms = list(Permission.objects.filter(codename__in=perm_codenames)) variables = { @@ -3551,8 +3608,10 @@ def test_translation_query_shipping_method( ) content = get_graphql_content(response, ignore_errors=True) data = content["data"]["translation"] + assert data["shippingMethodId"] == shipping_method_id assert data["name"] == shipping_method.name assert data["description"] == shipping_method.description + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == shipping_method_translation_fr.name @@ -3564,8 +3623,10 @@ def test_translation_query_shipping_method( __typename ...on SaleTranslatableContent{ id + saleId name translation(languageCode: $languageCode){ + id name } } @@ -3585,6 +3646,9 @@ def test_translation_query_sale( promotion = promotion_converted_from_sale promotion_translation = promotion.translations.first() sale_id = graphene.Node.to_global_id("Sale", promotion.old_sale_id) + translation_id = graphene.Node.to_global_id( + "SaleTranslation", promotion_converted_from_sale_translation_fr.id + ) variables = { "id": sale_id, @@ -3602,7 +3666,9 @@ def test_translation_query_sale( # then content = get_graphql_content(response, ignore_errors=True) data = content["data"]["translation"] + assert data["saleId"] == sale_id assert data["name"] == promotion.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == promotion_translation.name @@ -3614,8 +3680,10 @@ def test_translation_query_sale( __typename ...on VoucherTranslatableContent{ id + voucherId name translation(languageCode: $languageCode){ + id name } } @@ -3635,6 +3703,9 @@ def test_translation_query_voucher( staff_api_client, voucher, voucher_translation_fr, perm_codenames, return_voucher ): voucher_id = graphene.Node.to_global_id("Voucher", voucher.id) + translation_id = graphene.Node.to_global_id( + "VoucherTranslation", voucher_translation_fr.id + ) perms = list(Permission.objects.filter(codename__in=perm_codenames)) variables = { @@ -3647,7 +3718,9 @@ def test_translation_query_voucher( ) content = get_graphql_content(response, ignore_errors=True) data = content["data"]["translation"] + assert data["voucherId"] == voucher_id assert data["name"] == voucher.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == voucher_translation_fr.name @@ -3659,8 +3732,10 @@ def test_translation_query_voucher( __typename ...on MenuItemTranslatableContent{ id + menuItemId name translation(languageCode: $languageCode){ + id name } } @@ -3676,6 +3751,9 @@ def test_translation_query_menu_item( permission_manage_translations, ): menu_item_id = graphene.Node.to_global_id("MenuItem", menu_item.id) + translation_id = graphene.Node.to_global_id( + "MenuItemTranslation", menu_item_translation_fr.id + ) variables = { "id": menu_item_id, @@ -3689,7 +3767,9 @@ def test_translation_query_menu_item( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["menuItemId"] == menu_item_id assert data["name"] == menu_item.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == menu_item_translation_fr.name diff --git a/saleor/graphql/translations/types.py b/saleor/graphql/translations/types.py index df4281852ac..1c780cc7f04 100644 --- a/saleor/graphql/translations/types.py +++ b/saleor/graphql/translations/types.py @@ -19,11 +19,12 @@ from ...product import models as product_models from ...shipping import models as shipping_models from ...site import models as site_models -from ..attribute.dataloaders import AttributesByAttributeId +from ..attribute.dataloaders import AttributesByAttributeId, AttributeValueByIdLoader from ..channel import ChannelContext from ..core.context import get_database_connection_name from ..core.descriptions import ( ADDED_IN_39, + ADDED_IN_314, ADDED_IN_317, DEPRECATED_IN_3X_FIELD, DEPRECATED_IN_3X_TYPE, @@ -34,15 +35,27 @@ from ..core.tracing import traced_resolver from ..core.types import LanguageDisplay, ModelObjectType, NonNullList from ..core.utils import str_to_enum +from ..discount.dataloaders import ( + PromotionByIdLoader, + PromotionRuleByIdLoader, + VoucherByIdLoader, +) +from ..menu.dataloaders import MenuItemByIdLoader from ..page.dataloaders import ( + PageByIdLoader, SelectedAttributesAllByPageIdLoader, SelectedAttributesVisibleInStorefrontPageIdLoader, ) from ..product.dataloaders import ( + CategoryByIdLoader, + CollectionByIdLoader, + ProductByIdLoader, + ProductVariantByIdLoader, SelectedAttributesAllByProductIdLoader, SelectedAttributesByProductVariantIdLoader, SelectedAttributesVisibleInStorefrontByProductIdLoader, ) +from ..shipping.dataloaders import ShippingMethodByIdLoader from ..utils import get_user_or_app_from_context from .fields import TranslationField @@ -100,27 +113,52 @@ class AttributeValueTranslation( description="Translated rich-text attribute value." + RICH_CONTENT ) plain_text = graphene.String(description="Translated plain text attribute value .") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.AttributeValueTranslatableContent", + description="Represents the attribute value fields to translate." + + ADDED_IN_314, + ) class Meta: model = attribute_models.AttributeValueTranslation interfaces = [graphene.relay.Node] description = "Represents attribute value translations." + @staticmethod + def resolve_translatable_content( + root: attribute_models.AttributeValueTranslation, info + ): + return AttributeValueByIdLoader(info.context).load(root.attribute_value_id) + class AttributeTranslation(BaseTranslationType[attribute_models.AttributeTranslation]): id = graphene.GlobalID( required=True, description="The ID of the attribute translation." ) name = graphene.String(required=True, description="Translated attribute name.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.AttributeTranslatableContent", + description="Represents the attribute fields to translate." + ADDED_IN_314, + ) class Meta: model = attribute_models.AttributeTranslation interfaces = [graphene.relay.Node] description = "Represents attribute translations." + @staticmethod + def resolve_translatable_content(root: attribute_models.AttributeTranslation, info): + return AttributesByAttributeId(info.context).load(root.attribute_id) + class AttributeTranslatableContent(ModelObjectType[attribute_models.Attribute]): - id = graphene.GlobalID(required=True, description="The ID of the attribute.") + id = graphene.GlobalID( + required=True, description="The ID of the attribute translatable content." + ) + attribute_id = graphene.ID( + required=True, + description="The ID of the attribute to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the attribute to translate." ) @@ -145,6 +183,10 @@ class Meta: def resolve_attribute(root: attribute_models.Attribute, _info): return root + @staticmethod + def resolve_attribute_id(root: attribute_models.Attribute, _info): + return graphene.Node.to_global_id("Attribute", root.id) + class AttributeValueTranslatableContent( ModelObjectType[attribute_models.AttributeValue] @@ -152,6 +194,10 @@ class AttributeValueTranslatableContent( id = graphene.GlobalID( required=True, description="The ID of the attribute value translatable content." ) + attribute_value_id = graphene.ID( + required=True, + description="The ID of the attribute value to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the attribute value to translate.", @@ -189,6 +235,10 @@ def resolve_attribute_value(root: attribute_models.AttributeValue, _info): def resolve_attribute(root: attribute_models.AttributeValue, info): return AttributesByAttributeId(info.context).load(root.attribute_id) + @staticmethod + def resolve_attribute_value_id(root: attribute_models.AttributeValue, _info): + return graphene.Node.to_global_id("AttributeValue", root.id) + class ProductVariantTranslation( BaseTranslationType[product_models.ProductVariantTranslation] @@ -199,17 +249,32 @@ class ProductVariantTranslation( name = graphene.String( required=True, description="Translated product variant name." ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.ProductVariantTranslatableContent", + description="Represents the product variant fields to translate." + + ADDED_IN_314, + ) class Meta: model = product_models.ProductVariantTranslation interfaces = [graphene.relay.Node] description = "Represents product variant translations." + @staticmethod + def resolve_translatable_content( + root: product_models.ProductVariantTranslation, info + ): + return ProductVariantByIdLoader(info.context).load(root.product_variant_id) + class ProductVariantTranslatableContent(ModelObjectType[product_models.ProductVariant]): id = graphene.GlobalID( required=True, description="The ID of the product variant translatable content." ) + product_variant_id = graphene.ID( + required=True, + description="The ID of the product variant to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the product variant to translate.", @@ -252,6 +317,10 @@ def resolve_attribute_values(root: product_models.ProductVariant, info): .then(get_translatable_attribute_values) ) + @staticmethod + def resolve_product_variant_id(root: product_models.ProductVariant, _info): + return graphene.Node.to_global_id("ProductVariant", root.id) + class ProductTranslation(BaseTranslationType[product_models.ProductTranslation]): id = graphene.GlobalID( @@ -269,6 +338,10 @@ class ProductTranslation(BaseTranslationType[product_models.ProductTranslation]) f"{DEPRECATED_IN_3X_FIELD} Use the `description` field instead." ), ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.ProductTranslatableContent", + description="Represents the product fields to translate." + ADDED_IN_314, + ) class Meta: model = product_models.ProductTranslation @@ -280,11 +353,19 @@ def resolve_description_json(root: product_models.ProductTranslation, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_translatable_content(root: product_models.ProductTranslation, info): + return ProductByIdLoader(info.context).load(root.product_id) + class ProductTranslatableContent(ModelObjectType[product_models.Product]): id = graphene.GlobalID( required=True, description="The ID of the product translatable content." ) + product_id = graphene.ID( + required=True, + description="The ID of the product to translate." + ADDED_IN_314, + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") name = graphene.String(required=True, description="Product's name to translate.") @@ -348,6 +429,10 @@ def resolve_attribute_values(root: product_models.Product, info): .then(get_translatable_attribute_values) ) + @staticmethod + def resolve_product_id(root: product_models.Product, _info): + return graphene.Node.to_global_id("Product", root.id) + class CollectionTranslation(BaseTranslationType[product_models.CollectionTranslation]): id = graphene.GlobalID( @@ -365,6 +450,10 @@ class CollectionTranslation(BaseTranslationType[product_models.CollectionTransla f"{DEPRECATED_IN_3X_FIELD} Use the `description` field instead." ), ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.CollectionTranslatableContent", + description="Represents the collection fields to translate." + ADDED_IN_314, + ) class Meta: model = product_models.CollectionTranslation @@ -376,11 +465,19 @@ def resolve_description_json(root: product_models.CollectionTranslation, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_translatable_content(root: product_models.CollectionTranslation, info): + return CollectionByIdLoader(info.context).load(root.collection_id) + class CollectionTranslatableContent(ModelObjectType[product_models.Collection]): id = graphene.GlobalID( required=True, description="The ID of the collection translatable content." ) + collection_id = graphene.ID( + required=True, + description="The ID of the collection to translate." + ADDED_IN_314, + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") name = graphene.String(required=True, description="Collection's name to translate.") @@ -429,6 +526,10 @@ def resolve_description_json(root: product_models.Collection, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_collection_id(root: product_models.Collection, _info): + return graphene.Node.to_global_id("Collection", root.id) + class CategoryTranslation(BaseTranslationType[product_models.CategoryTranslation]): id = graphene.GlobalID( @@ -446,6 +547,10 @@ class CategoryTranslation(BaseTranslationType[product_models.CategoryTranslation f"{DEPRECATED_IN_3X_FIELD} Use the `description` field instead." ), ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.CategoryTranslatableContent", + description="Represents the category fields to translate." + ADDED_IN_314, + ) class Meta: model = product_models.CategoryTranslation @@ -457,11 +562,19 @@ def resolve_description_json(root: product_models.CategoryTranslation, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_translatable_content(root: product_models.CategoryTranslation, info): + return CategoryByIdLoader(info.context).load(root.category_id) + class CategoryTranslatableContent(ModelObjectType[product_models.Category]): id = graphene.GlobalID( required=True, description="The ID of the category translatable content." ) + category_id = graphene.ID( + required=True, + description="The ID of the category to translate." + ADDED_IN_314, + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") name = graphene.String( @@ -501,6 +614,10 @@ def resolve_description_json(root: product_models.Category, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_category_id(root: product_models.Category, _info): + return graphene.Node.to_global_id("Category", root.id) + class PageTranslation(BaseTranslationType[page_models.PageTranslation]): id = graphene.GlobalID(required=True, description="The ID of the page translation.") @@ -512,6 +629,10 @@ class PageTranslation(BaseTranslationType[page_models.PageTranslation]): description="Translated description of the page." + RICH_CONTENT, deprecation_reason=f"{DEPRECATED_IN_3X_FIELD} Use the `content` field instead.", ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.PageTranslatableContent", + description="Represents the page fields to translate." + ADDED_IN_314, + ) class Meta: model = page_models.PageTranslation @@ -523,11 +644,18 @@ def resolve_content_json(root: page_models.PageTranslation, _info): content = root.content return content if content is not None else {} + @staticmethod + def resolve_translatable_content(root: page_models.PageTranslation, info): + return PageByIdLoader(info.context).load(root.page_id) + class PageTranslatableContent(ModelObjectType[page_models.Page]): id = graphene.GlobalID( required=True, description="The ID of the page translatable content." ) + page_id = graphene.ID( + required=True, description="The ID of the page to translate." + ADDED_IN_314 + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") title = graphene.String(required=True, description="Page title to translate.") @@ -594,23 +722,39 @@ def resolve_attribute_values(root: page_models.Page, info): .then(get_translatable_attribute_values) ) + @staticmethod + def resolve_page_id(root: page_models.Page, _info): + return graphene.Node.to_global_id("Page", root.id) + class VoucherTranslation(BaseTranslationType[discount_models.VoucherTranslation]): id = graphene.GlobalID( required=True, description="The ID of the voucher translation." ) name = graphene.String(description="Translated voucher name.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.VoucherTranslatableContent", + description="Represents the voucher fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.VoucherTranslation interfaces = [graphene.relay.Node] description = "Represents voucher translations." + @staticmethod + def resolve_translatable_content(root: discount_models.VoucherTranslation, info): + return VoucherByIdLoader(info.context).load(root.voucher_id) + class VoucherTranslatableContent(ModelObjectType[discount_models.Voucher]): id = graphene.GlobalID( required=True, description="The ID of the voucher translatable content." ) + voucher_id = graphene.ID( + required=True, + description="The ID of the voucher to translate." + ADDED_IN_314, + ) name = graphene.String(description="Voucher name to translate.") translation = TranslationField(VoucherTranslation, type_name="voucher") voucher = PermissionsField( @@ -638,10 +782,18 @@ class Meta: def resolve_voucher(root: discount_models.Voucher, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_voucher_id(root: discount_models.Voucher, _info): + return graphene.Node.to_global_id("Voucher", root.id) + class SaleTranslation(BaseTranslationType[discount_models.PromotionTranslation]): id = graphene.GlobalID(required=True, description="The ID of the sale translation.") name = graphene.String(description="Translated name of sale.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.SaleTranslatableContent", + description="Represents the sale fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.PromotionTranslation @@ -652,11 +804,19 @@ class Meta: + " Use `PromotionTranslation` instead." ) + @staticmethod + def resolve_translatable_content(root: discount_models.PromotionTranslation, info): + return PromotionByIdLoader(info.context).load(root.promotion_id) + class SaleTranslatableContent(ModelObjectType[discount_models.Promotion]): id = graphene.GlobalID( required=True, description="The ID of the sale translatable content." ) + sale_id = graphene.ID( + required=True, + description="The ID of the sale to translate." + ADDED_IN_314, + ) name = graphene.String(required=True, description="Name of the sale to translate.") translation = TranslationField(SaleTranslation, type_name="sale") sale = PermissionsField( @@ -684,6 +844,10 @@ class Meta: def resolve_sale(root: discount_models.Promotion, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_sale_id(root: discount_models.Promotion, _info): + return graphene.Node.to_global_id("Sale", root.old_sale_id) + class ShopTranslation(BaseTranslationType[site_models.SiteSettingsTranslation]): id = graphene.GlobalID(required=True, description="The ID of the shop translation.") @@ -705,17 +869,29 @@ class MenuItemTranslation(BaseTranslationType[menu_models.MenuItemTranslation]): required=True, description="The ID of the menu item translation." ) name = graphene.String(required=True, description="Translated menu item name.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.MenuItemTranslatableContent", + description="Represents the menu item fields to translate." + ADDED_IN_314, + ) class Meta: model = menu_models.MenuItemTranslation interfaces = [graphene.relay.Node] description = "Represents menu item translations." + @staticmethod + def resolve_translatable_content(root: menu_models.MenuItemTranslation, info): + return MenuItemByIdLoader(info.context).load(root.menu_item_id) + class MenuItemTranslatableContent(ModelObjectType[menu_models.MenuItem]): id = graphene.GlobalID( required=True, description="The ID of the menu item translatable content." ) + menu_item_id = graphene.ID( + required=True, + description="The ID of the menu item to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the menu item to translate." ) @@ -743,6 +919,10 @@ class Meta: def resolve_menu_item(root: menu_models.MenuItem, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_menu_item_id(root: menu_models.MenuItem, _info): + return graphene.Node.to_global_id("MenuItem", root.id) + class ShippingMethodTranslation( BaseTranslationType[shipping_models.ShippingMethodTranslation] @@ -756,12 +936,23 @@ class ShippingMethodTranslation( description = JSONString( description="Translated description of the shipping method." + RICH_CONTENT ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.ShippingMethodTranslatableContent", + description="Represents the shipping method fields to translate." + + ADDED_IN_314, + ) class Meta: model = shipping_models.ShippingMethodTranslation interfaces = [graphene.relay.Node] description = "Represents shipping method translations." + @staticmethod + def resolve_translatable_content( + root: shipping_models.ShippingMethodTranslation, info + ): + return ShippingMethodByIdLoader(info.context).load(root.shipping_method_id) + class ShippingMethodTranslatableContent( ModelObjectType[shipping_models.ShippingMethod] @@ -769,6 +960,10 @@ class ShippingMethodTranslatableContent( id = graphene.GlobalID( required=True, description="The ID of the shipping method translatable content." ) + shipping_method_id = graphene.ID( + required=True, + description="The ID of the shipping method to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Shipping method name to translate." ) @@ -804,6 +999,10 @@ class Meta: def resolve_shipping_method(root: shipping_models.ShippingMethod, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_shipping_method_id(root: shipping_models.ShippingMethod, _info): + return graphene.Node.to_global_id("ShippingMethodType", root.id) + class PromotionTranslation(BaseTranslationType[discount_models.PromotionTranslation]): id = graphene.GlobalID( @@ -813,17 +1012,28 @@ class PromotionTranslation(BaseTranslationType[discount_models.PromotionTranslat description = JSONString( description="Translated description of the promotion." + RICH_CONTENT ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.PromotionTranslatableContent", + description="Represents the promotion fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.Promotion interfaces = [graphene.relay.Node] description = "Represents promotion translations." + ADDED_IN_317 + @staticmethod + def resolve_translatable_content(root: discount_models.PromotionTranslation, info): + return PromotionByIdLoader(info.context).load(root.promotion_id) + class PromotionTranslatableContent(ModelObjectType[discount_models.Promotion]): id = graphene.GlobalID( required=True, description="ID of the promotion translatable content." ) + promotion_id = graphene.ID( + required=True, description="ID of the promotion to translate." + ) name = graphene.String(required=True, description="Name of the promotion.") description = JSONString(description="Description of the promotion." + RICH_CONTENT) translation = TranslationField(PromotionTranslation, type_name="promotion") @@ -836,6 +1046,10 @@ class Meta: "and related translations." + ADDED_IN_317 ) + @staticmethod + def resolve_promotion_id(root: discount_models.Promotion, _info): + return graphene.Node.to_global_id("Promotion", root.id) + class PromotionRuleTranslation( BaseTranslationType[discount_models.PromotionTranslation] @@ -847,17 +1061,31 @@ class PromotionRuleTranslation( description = JSONString( description="Translated description of the promotion rule." + RICH_CONTENT ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.PromotionRuleTranslatableContent", + description="Represents the promotion rule fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.PromotionRule interfaces = [graphene.relay.Node] description = "Represents promotion rule translations." + ADDED_IN_317 + @staticmethod + def resolve_translatable_content( + root: discount_models.PromotionRuleTranslation, info + ): + return PromotionRuleByIdLoader(info.context).load(root.promotion_rule_id) + class PromotionRuleTranslatableContent(ModelObjectType[discount_models.Promotion]): id = graphene.GlobalID( required=True, description="ID of the promotion rule translatable content." ) + promotion_rule_id = graphene.ID( + required=True, + description="ID of the promotion rule to translate." + ADDED_IN_314, + ) name = graphene.String(description="Name of the promotion rule.") description = JSONString( description="Description of the promotion rule." + RICH_CONTENT @@ -871,3 +1099,7 @@ class Meta: "Represents promotion rule's original translatable fields " "and related translations." + ADDED_IN_317 ) + + @staticmethod + def resolve_promotion_rule_id(root: discount_models.PromotionRule, _info): + return graphene.Node.to_global_id("PromotionRule", root.id) diff --git a/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py b/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py index a4a3209698b..6aa21c31e2e 100644 --- a/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py +++ b/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py @@ -71,7 +71,7 @@ def test_stocks_bulk_update_queries_count( ] # test number of queries when single object is updated - with django_assert_num_queries(12): + with django_assert_num_queries(11): staff_api_client.user.user_permissions.add(permission_manage_products) response = staff_api_client.post_graphql( STOCKS_BULK_UPDATE_MUTATION, {"stocks": stocks_input} @@ -107,7 +107,7 @@ def test_stocks_bulk_update_queries_count( ] # Test number of queries when multiple objects are updated - with django_assert_num_queries(12): + with django_assert_num_queries(11): staff_api_client.user.user_permissions.add(permission_manage_products) response = staff_api_client.post_graphql( STOCKS_BULK_UPDATE_MUTATION, {"stocks": stocks_input} diff --git a/saleor/order/base_calculations.py b/saleor/order/base_calculations.py index 219b77f2c84..e5848c04eb8 100644 --- a/saleor/order/base_calculations.py +++ b/saleor/order/base_calculations.py @@ -109,6 +109,13 @@ def propagate_order_discount_on_order_prices( currency=currency, price_to_discount=subtotal, ) + elif order_discount.type == DiscountType.ORDER_PROMOTION: + subtotal = apply_discount_to_value( + value=order_discount.value, + value_type=order_discount.value_type, + currency=currency, + price_to_discount=subtotal, + ) elif order_discount.type == DiscountType.MANUAL: if order_discount.value_type == DiscountValueType.PERCENTAGE: subtotal = apply_discount_to_value( @@ -123,7 +130,7 @@ def propagate_order_discount_on_order_prices( currency=currency, price_to_discount=shipping_price, ) - else: + elif order_discount.value_type == DiscountValueType.FIXED: temporary_undiscounted_total = subtotal + shipping_price if temporary_undiscounted_total.amount > 0: temporary_total = apply_discount_to_value( diff --git a/saleor/order/calculations.py b/saleor/order/calculations.py index 8d649bd324d..4ac30e918f5 100644 --- a/saleor/order/calculations.py +++ b/saleor/order/calculations.py @@ -7,9 +7,11 @@ from django.db.models import prefetch_related_objects from prices import Money, TaxedMoney +from ..core.db.connection import allow_writer from ..core.prices import quantize_price from ..core.taxes import TaxData, TaxEmptyData, TaxError, zero_taxed_money from ..discount import DiscountType +from ..discount.utils import create_or_update_discount_objects_from_promotion_for_order from ..payment.model_helpers import get_subtotal from ..plugins import PLUGIN_IDENTIFIER_PREFIX from ..plugins.manager import PluginsManager @@ -24,6 +26,7 @@ ) from . import ORDER_EDITABLE_STATUS from .base_calculations import apply_order_discounts, base_order_line_total +from .fetch import DraftOrderLineInfo, fetch_draft_order_lines_info from .interface import OrderTaxedPricesData from .models import Order, OrderLine @@ -48,53 +51,62 @@ def fetch_order_prices_if_expired( if not force_update and not order.should_refresh_prices: return order, lines - if lines is None: - lines = list( - order.lines.using(database_connection_name).select_related( - "variant__product__product_type" - ) - ) - else: - prefetch_related_objects(lines, "variant__product__product_type") + # handle promotions + lines_info: list[DraftOrderLineInfo] = fetch_draft_order_lines_info(order, lines) + create_or_update_discount_objects_from_promotion_for_order( + order, lines_info, database_connection_name + ) + lines = [line_info.line for line_info in lines_info] + _update_order_discount_for_voucher(order) - order.should_refresh_prices = False + _clear_prefetched_discounts(order, lines) + prefetch_related_objects([order], "discounts") - _update_order_discount_for_voucher(order) + # handle taxes _recalculate_prices( - order, manager, lines, database_connection_name=database_connection_name + order, + manager, + lines, + database_connection_name=database_connection_name, ) - order.subtotal = get_subtotal(lines, order.currency) + order.should_refresh_prices = False with transaction.atomic(savepoint=False): - order.save( - update_fields=[ - "subtotal_net_amount", - "subtotal_gross_amount", - "total_net_amount", - "total_gross_amount", - "undiscounted_total_net_amount", - "undiscounted_total_gross_amount", - "shipping_price_net_amount", - "shipping_price_gross_amount", - "shipping_tax_rate", - "should_refresh_prices", - "tax_error", - ] - ) - order.lines.bulk_update( - lines, - [ - "unit_price_net_amount", - "unit_price_gross_amount", - "undiscounted_unit_price_net_amount", - "undiscounted_unit_price_gross_amount", - "total_price_net_amount", - "total_price_gross_amount", - "undiscounted_total_price_net_amount", - "undiscounted_total_price_gross_amount", - "tax_rate", - ], - ) + with allow_writer(): + order.save( + update_fields=[ + "subtotal_net_amount", + "subtotal_gross_amount", + "total_net_amount", + "total_gross_amount", + "undiscounted_total_net_amount", + "undiscounted_total_gross_amount", + "shipping_price_net_amount", + "shipping_price_gross_amount", + "shipping_tax_rate", + "should_refresh_prices", + "tax_error", + ] + ) + order.lines.bulk_update( + lines, + [ + "unit_price_net_amount", + "unit_price_gross_amount", + "undiscounted_unit_price_net_amount", + "undiscounted_unit_price_gross_amount", + "total_price_net_amount", + "total_price_gross_amount", + "undiscounted_total_price_net_amount", + "undiscounted_total_price_gross_amount", + "tax_rate", + "unit_discount_amount", + "unit_discount_reason", + "unit_discount_type", + "unit_discount_value", + "base_unit_price_amount", + ], + ) return order, lines @@ -102,7 +114,8 @@ def fetch_order_prices_if_expired( def _update_order_discount_for_voucher(order: Order): """Create or delete OrderDiscount instances.""" if not order.voucher_id: - order.discounts.filter(type=DiscountType.VOUCHER).delete() + with allow_writer(): + order.discounts.filter(type=DiscountType.VOUCHER).delete() elif ( order.voucher_id @@ -122,14 +135,14 @@ def _update_order_discount_for_voucher(order: Order): voucher_code=order.voucher_code, ) - # Prefetch has to be cleared and refreshed to avoid returning cached discounts - if ( - hasattr(order, "_prefetched_objects_cache") - and "discounts" in order._prefetched_objects_cache - ): - del order._prefetched_objects_cache["discounts"] - prefetch_related_objects([order], "discounts") +def _clear_prefetched_discounts(order, lines): + if hasattr(order, "_prefetched_objects_cache"): + order._prefetched_objects_cache.pop("discounts", None) + + for line in lines: + if hasattr(line, "_prefetched_objects_cache"): + line._prefetched_objects_cache.pop("discounts", None) def _recalculate_prices( @@ -332,6 +345,7 @@ def _recalculate_with_plugins( order.undiscounted_total = undiscounted_subtotal + TaxedMoney( net=order.base_shipping_price, gross=order.base_shipping_price ) + order.subtotal = get_subtotal(lines, order.currency) order.total = manager.calculate_order_total(order, lines, plugin_ids=plugin_ids) @@ -381,12 +395,14 @@ def _apply_tax_data( order_line.tax_rate = normalize_tax_rate_for_db(tax_line.tax_rate) subtotal += line_total_price + order.subtotal = subtotal order.total = shipping_price + subtotal def _remove_tax(order, lines): order.total_gross_amount = order.total_net_amount order.undiscounted_total_gross_amount = order.undiscounted_total_net_amount + order.subtotal_gross_amount = order.subtotal_net_amount order.shipping_price_gross_amount = order.shipping_price_net_amount order.shipping_tax_rate = Decimal("0.00") diff --git a/saleor/order/fetch.py b/saleor/order/fetch.py index 9d1b364e4a1..5bab6e8ab90 100644 --- a/saleor/order/fetch.py +++ b/saleor/order/fetch.py @@ -1,14 +1,21 @@ from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional, cast from uuid import UUID -if TYPE_CHECKING: - from ..channel.models import Channel - from ..discount.models import OrderLineDiscount - from ..payment.models import Payment - from ..product.models import DigitalContent, ProductVariant - from .models import Order, OrderLine +from django.db.models import prefetch_related_objects + +from ..channel.models import Channel +from ..discount import DiscountType +from ..discount.interface import VariantPromotionRuleInfo, fetch_variant_rules_info +from ..discount.models import OrderLineDiscount, Voucher +from ..payment.models import Payment +from ..product.models import ( + DigitalContent, + ProductVariant, + ProductVariantChannelListing, +) +from .models import Order, OrderLine @dataclass @@ -63,3 +70,78 @@ def fetch_order_lines(order: "Order") -> list[OrderLineInfo]: ) return lines_info + + +@dataclass +class DraftOrderLineInfo: + line: "OrderLine" + variant: "ProductVariant" + channel_listing: "ProductVariantChannelListing" + discounts: list["OrderLineDiscount"] + rules_info: list["VariantPromotionRuleInfo"] + channel: "Channel" + voucher: Optional["Voucher"] = None + + def get_promotion_discounts(self) -> list["OrderLineDiscount"]: + return [ + discount + for discount in self.discounts + if discount.type in [DiscountType.PROMOTION, DiscountType.ORDER_PROMOTION] + ] + + def get_catalogue_discounts(self) -> list["OrderLineDiscount"]: + return [ + discount + for discount in self.discounts + if discount.type == DiscountType.PROMOTION + ] + + +def fetch_draft_order_lines_info( + order: "Order", lines: Optional[Iterable["OrderLine"]] = None +) -> list[DraftOrderLineInfo]: + prefetch_related_fields = [ + "discounts__promotion_rule__promotion", + "variant__channel_listings__variantlistingpromotionrule__promotion_rule__promotion__translations", + "variant__channel_listings__variantlistingpromotionrule__promotion_rule__translations", + ] + if lines is None: + lines = list(order.lines.prefetch_related(*prefetch_related_fields)) + else: + prefetch_related_objects(lines, *prefetch_related_fields) + + lines_info = [] + channel = order.channel + for line in lines: + variant = cast(ProductVariant, line.variant) + variant_channel_listing = get_prefetched_variant_listing(variant, channel.id) + if not variant_channel_listing: + continue + + rules_info = ( + fetch_variant_rules_info(variant_channel_listing, order.language_code) + if not line.is_gift + else [] + ) + lines_info.append( + DraftOrderLineInfo( + line=line, + variant=variant, + channel_listing=variant_channel_listing, + discounts=list(line.discounts.all()), + rules_info=rules_info, + channel=channel, + ) + ) + return lines_info + + +def get_prefetched_variant_listing( + variant: Optional[ProductVariant], channel_id: int +) -> Optional[ProductVariantChannelListing]: + if not variant: + return None + for channel_listing in variant.channel_listings.all(): + if channel_listing.channel_id == channel_id: + return channel_listing + return None diff --git a/saleor/order/migrations/0172_update_order_cc_addresses.py b/saleor/order/migrations/0172_update_order_cc_addresses.py new file mode 100644 index 00000000000..c686fdfd5fa --- /dev/null +++ b/saleor/order/migrations/0172_update_order_cc_addresses.py @@ -0,0 +1,52 @@ +from django.db import migrations +from django.db.models import Exists, OuterRef +from django.forms.models import model_to_dict + +# The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak +ADDRESS_UPDATE_BATCH_SIZE = 250 + + +def queryset_in_batches(queryset): + """Slice a queryset into batches. + + Input queryset should be sorted be pk. + """ + start_pk = 0 + + while True: + qs = queryset.filter(pk__gt=start_pk)[:ADDRESS_UPDATE_BATCH_SIZE] + pks = list(qs.values_list("pk", flat=True)) + if not pks: + break + yield pks + start_pk = pks[-1] + + +def update_order_addresses(apps, schema_editor): + Order = apps.get_model("order", "Order") + Warehouse = apps.get_model("warehouse", "Warehouse") + Address = apps.get_model("account", "Address") + queryset = Order.objects.filter( + Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), + ).order_by("pk") + + for order_ids in queryset_in_batches(queryset): + orders = Order.objects.filter(id__in=order_ids) + addresses = [] + for order in orders: + if cc_address := order.shipping_address: + order_address = Address(**model_to_dict(cc_address, exclude=["id"])) + order.shipping_address = order_address + addresses.append(order_address) + Address.objects.bulk_create(addresses, ignore_conflicts=True) + Order.objects.bulk_update(orders, ["shipping_address"]) + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0171_order_order_user_email_user_id_idx"), + ] + + operations = [ + migrations.RunPython(update_order_addresses, migrations.RunPython.noop), + ] diff --git a/saleor/order/migrations/0176_merge_20240325_1315.py b/saleor/order/migrations/0176_merge_20240325_1315.py new file mode 100644 index 00000000000..7bb6b948017 --- /dev/null +++ b/saleor/order/migrations/0176_merge_20240325_1315.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:15 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0172_update_order_cc_addresses"), + ("order", "0175_merge_20231122_1040"), + ] + + operations = [] diff --git a/saleor/order/migrations/0177_merge_20240325_1329.py b/saleor/order/migrations/0177_merge_20240325_1329.py new file mode 100644 index 00000000000..2fd49329770 --- /dev/null +++ b/saleor/order/migrations/0177_merge_20240325_1329.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:29 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0176_merge_20231122_1346"), + ("order", "0176_merge_20240325_1315"), + ] + + operations = [] diff --git a/saleor/order/migrations/0180_merge_20240325_1333.py b/saleor/order/migrations/0180_merge_20240325_1333.py new file mode 100644 index 00000000000..675bc15cab5 --- /dev/null +++ b/saleor/order/migrations/0180_merge_20240325_1333.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:33 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0177_merge_20240325_1329"), + ("order", "0179_merge_20231122_1348"), + ] + + operations = [] diff --git a/saleor/order/migrations/0182_merge_20240325_1338.py b/saleor/order/migrations/0182_merge_20240325_1338.py new file mode 100644 index 00000000000..d11d3a3d82d --- /dev/null +++ b/saleor/order/migrations/0182_merge_20240325_1338.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:38 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0180_merge_20240325_1333"), + ("order", "0181_order_subtotal_as_a_field"), + ] + + operations = [] diff --git a/saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py b/saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py new file mode 100644 index 00000000000..0b8a3c21901 --- /dev/null +++ b/saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:42 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0182_merge_20240325_1338"), + ("order", "0183_order_tax_error"), + ] + + operations = [] diff --git a/saleor/order/models.py b/saleor/order/models.py index 79ce62b9a01..469dc016593 100644 --- a/saleor/order/models.py +++ b/saleor/order/models.py @@ -498,9 +498,6 @@ def can_mark_as_paid(self, payments=None): def total_balance(self): return self.total_charged - self.total.gross - def get_total_weight(self, _lines=None): - return self.weight - class OrderLineQueryset(models.QuerySet["OrderLine"]): def digital(self): diff --git a/saleor/plugins/avatax/tests/deprecated/__init__.py b/saleor/order/tests/benchmark/__init__.py similarity index 100% rename from saleor/plugins/avatax/tests/deprecated/__init__.py rename to saleor/order/tests/benchmark/__init__.py diff --git a/saleor/order/tests/benchmark/test_fetch_order_prices.py b/saleor/order/tests/benchmark/test_fetch_order_prices.py new file mode 100644 index 00000000000..3a2a866ad80 --- /dev/null +++ b/saleor/order/tests/benchmark/test_fetch_order_prices.py @@ -0,0 +1,144 @@ +from decimal import Decimal + +import pytest +from prices import Money, TaxedMoney + +from ....discount import RewardValueType +from ....discount.models import OrderDiscount, OrderLineDiscount, PromotionRule +from ....product.models import Product +from ....product.utils.variant_prices import update_discounted_prices_for_promotion +from ....product.utils.variants import fetch_variants_for_promotion_rules +from ....tax import TaxCalculationStrategy +from ....warehouse.models import Stock +from ... import OrderStatus +from ...calculations import fetch_order_prices_if_expired +from ...models import OrderLine + + +@pytest.fixture +def order_with_lines(order_with_lines): + order_with_lines.status = OrderStatus.UNCONFIRMED + return order_with_lines + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_fetch_order_prices_catalogue_discount( + order_with_lines_and_catalogue_promotion, + plugins_manager, + django_assert_num_queries, + count_queries, +): + # given + OrderLineDiscount.objects.all().delete() + order = order_with_lines_and_catalogue_promotion + channel = order.channel + + tc = channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + + # when + with django_assert_num_queries(38): + fetch_order_prices_if_expired(order, plugins_manager, None, True) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_fetch_order_prices_multiple_catalogue_discounts( + order_with_lines, + catalogue_promotion_without_rules, + plugins_manager, + product_variant_list, + django_assert_num_queries, + count_queries, +): + # given + Stock.objects.update(quantity=100) + order = order_with_lines + channel = order.channel + + variants = product_variant_list + variants.extend(line.variant for line in order.lines.all()) + variant_global_ids = [variant.get_global_id() for variant in variants] + + # create many rules + promotion = catalogue_promotion_without_rules + rules = [] + catalogue_predicate = {"variantPredicate": {"ids": variant_global_ids}} + for idx in range(5): + reward_value = 2 + idx + rules.append( + PromotionRule( + name=f"Catalogue rule fixed {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(reward_value), + ) + ) + for idx in range(5): + reward_value = idx * 10 + 25 + rules.append( + PromotionRule( + name=f"Catalogue rule percentage {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=Decimal(reward_value), + ) + ) + rules = PromotionRule.objects.bulk_create(rules) + for rule in rules: + rule.channels.add(channel) + + fetch_variants_for_promotion_rules(PromotionRule.objects.all()) + + # update prices + update_discounted_prices_for_promotion(Product.objects.all()) + + # create lines + new_order_lines = [] + for idx, variant in enumerate(product_variant_list): + base_price = variant.channel_listings.first().discounted_price + currency = base_price.currency + gross = Money(amount=base_price.amount * Decimal(1.23), currency=currency) + quantity = 4 + idx + unit_price = TaxedMoney(net=base_price, gross=gross) + new_order_lines.append( + OrderLine( + order=order, + product_name=str(variant.product), + variant_name=str(variant), + product_sku=variant.sku, + product_variant_id=variant.get_global_id(), + is_shipping_required=variant.is_shipping_required(), + is_gift_card=variant.is_gift_card(), + quantity=quantity, + variant=variant, + unit_price=unit_price, + total_price=unit_price * quantity, + tax_rate=Decimal("0.23"), + ) + ) + OrderLine.objects.bulk_create(new_order_lines) + + tc = channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + + # when + with django_assert_num_queries(38): + fetch_order_prices_if_expired(order, plugins_manager, None, True) + + # then + assert OrderLineDiscount.objects.count() == 7 + assert not OrderDiscount.objects.exists() diff --git a/saleor/order/tests/test_discount_calculations.py b/saleor/order/tests/test_apply_order_discount.py similarity index 94% rename from saleor/order/tests/test_discount_calculations.py rename to saleor/order/tests/test_apply_order_discount.py index 66767e5d094..f22c5dc56f8 100644 --- a/saleor/order/tests/test_discount_calculations.py +++ b/saleor/order/tests/test_apply_order_discount.py @@ -1,7 +1,7 @@ from decimal import Decimal import pytest -from prices import Money, TaxedMoney +from prices import Money from ...core.prices import quantize_price from ...core.taxes import zero_money @@ -9,51 +9,7 @@ from ...order.base_calculations import ( apply_order_discounts, apply_subtotal_discount_to_order_lines, - base_order_line_total, - base_order_total, ) -from ...order.interface import OrderTaxedPricesData - - -def test_base_order_total(order_with_lines): - # given - order = order_with_lines - lines = order.lines.all() - shipping_price = order.shipping_price.net - subtotal = zero_money(order.currency) - for line in lines: - subtotal += line.base_unit_price * line.quantity - undiscounted_total = subtotal + shipping_price - - # when - order_total = base_order_total(order, lines) - - # then - assert order_total == undiscounted_total - - -def test_base_order_line_total(order_with_lines): - # given - line = order_with_lines.lines.all().first() - - # when - order_total = base_order_line_total(line) - - # then - base_line_unit_price = line.base_unit_price - quantity = line.quantity - expected_price_with_discount = ( - TaxedMoney(base_line_unit_price, base_line_unit_price) * quantity - ) - base_line_undiscounted_unit_price = line.undiscounted_base_unit_price - expected_undiscounted_price = ( - TaxedMoney(base_line_undiscounted_unit_price, base_line_undiscounted_unit_price) - * quantity - ) - assert order_total == OrderTaxedPricesData( - price_with_discounts=expected_price_with_discount, - undiscounted_price=expected_undiscounted_price, - ) def test_apply_order_discounts_voucher_entire_order(order_with_lines, voucher): diff --git a/saleor/order/tests/test_fetch.py b/saleor/order/tests/test_fetch.py new file mode 100644 index 00000000000..684d97115d1 --- /dev/null +++ b/saleor/order/tests/test_fetch.py @@ -0,0 +1,73 @@ +from decimal import Decimal + +import pytest + +from ...product.models import ProductVariantChannelListing +from ..fetch import fetch_draft_order_lines_info + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_fetch_draft_order_lines_info( + draft_order_and_promotions, django_assert_num_queries, count_queries +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + channel = order.channel + lines = order.lines.all() + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + manual_discount = line_1.discounts.create( + value=Decimal(1), + amount_value=Decimal(1), + currency=channel.currency_code, + name="Manual line discount", + ) + rule_translation = "Rule translation" + rule_catalogue.translations.create( + language_code=order.language_code, name=rule_translation + ) + promotion_translation = "Promotion" + rule_catalogue.promotion.translations.create( + language_code=order.language_code, name=promotion_translation + ) + + # when + with django_assert_num_queries(9): + lines_info = fetch_draft_order_lines_info(order) + + # then + line_info_1 = [line_info for line_info in lines_info if line_info.line == line_1][0] + line_info_2 = [line_info for line_info in lines_info if line_info.line == line_2][0] + + variant_1 = line_1.variant + assert line_info_1.variant == variant_1 + assert ( + line_info_1.channel_listing + == ProductVariantChannelListing.objects.filter( + channel=channel, variant=variant_1 + ).first() + ) + assert line_info_1.discounts == [manual_discount] + assert line_info_1.channel == channel + assert line_info_1.voucher is None + assert line_info_1.rules_info == [] + + variant_2 = line_2.variant + assert line_info_2.variant == variant_2 + assert ( + line_info_2.channel_listing + == ProductVariantChannelListing.objects.filter( + channel=channel, variant=variant_2 + ).first() + ) + assert line_info_2.discounts == [] + assert line_info_2.channel == channel + assert line_info_2.voucher is None + rule_info_2 = line_info_2.rules_info[0] + assert rule_info_2.rule == rule_catalogue + assert rule_info_2.variant_listing_promotion_rule + assert rule_info_2.promotion == rule_catalogue.promotion + assert rule_info_2.promotion_translation.name == promotion_translation + assert rule_info_2.rule_translation.name == rule_translation diff --git a/saleor/order/tests/test_fetch_order_prices.py b/saleor/order/tests/test_fetch_order_prices.py new file mode 100644 index 00000000000..5ccf22d535b --- /dev/null +++ b/saleor/order/tests/test_fetch_order_prices.py @@ -0,0 +1,1394 @@ +from decimal import Decimal + +import before_after +import graphene +import pytest + +from ...core.prices import quantize_price +from ...discount import DiscountType, DiscountValueType +from ...discount.models import ( + OrderDiscount, + OrderLineDiscount, + PromotionRule, +) +from ...tax import TaxCalculationStrategy +from ...tests.utils import round_down, round_up +from .. import OrderStatus, calculations + + +@pytest.fixture +def order_with_lines(order_with_lines): + order_with_lines.status = OrderStatus.UNCONFIRMED + return order_with_lines + + +@pytest.mark.parametrize("create_new_discounts", [True, False]) +def test_fetch_order_prices_catalogue_discount_flat_rates( + order_with_lines_and_catalogue_promotion, + plugins_manager, + create_new_discounts, +): + # given + if create_new_discounts: + OrderLineDiscount.objects.all().delete() + + order = order_with_lines_and_catalogue_promotion + channel = order.channel + rule = PromotionRule.objects.get() + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + reward_value = rule.reward_value + + tc = channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + discount = line_1.discounts.get() + reward_amount = reward_value * line_1.quantity + assert discount.amount_value == reward_amount + assert discount.value == reward_value + assert discount.value_type == DiscountValueType.FIXED + assert discount.type == DiscountType.PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=channel) + variant_1_unit_price = variant_1_listing.discounted_price_amount + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + assert variant_1_undiscounted_unit_price - variant_1_unit_price == reward_value + + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert ( + line_1.base_unit_price_amount + == variant_1_undiscounted_unit_price - reward_value + ) + assert ( + line_1.unit_price_net_amount == variant_1_undiscounted_unit_price - reward_value + ) + assert line_1.unit_price_gross_amount == line_1.unit_price_net_amount * tax_rate + assert ( + line_1.total_price_net_amount == line_1.unit_price_net_amount * line_1.quantity + ) + assert line_1.total_price_gross_amount == line_1.total_price_net_amount * tax_rate + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.unit_price_net_amount == variant_2_undiscounted_unit_price + assert ( + line_2.unit_price_gross_amount == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.total_price_net_amount == line_2.undiscounted_total_price_net_amount + assert ( + line_2.total_price_gross_amount == line_2.undiscounted_total_price_gross_amount + ) + + shipping_net_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_net_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == order.undiscounted_total_net_amount - reward_amount + assert order.total_gross_amount == order.total_net_amount * tax_rate + assert order.subtotal_net_amount == order.total_net_amount - shipping_net_price + assert order.subtotal_gross_amount == order.subtotal_net_amount * tax_rate + + assert line_1.unit_discount_amount == reward_value + assert line_1.unit_discount_reason == f"Promotion: {promotion_id}" + assert line_1.unit_discount_type == DiscountValueType.FIXED + assert line_1.unit_discount_value == reward_value + + +@pytest.mark.parametrize("create_new_discounts", [True, False]) +def test_fetch_order_prices_order_discount_flat_rates( + order_with_lines_and_order_promotion, + plugins_manager, + create_new_discounts, +): + # given + if create_new_discounts: + OrderDiscount.objects.all().delete() + + order = order_with_lines_and_order_promotion + currency = order.currency + rule = PromotionRule.objects.get() + reward_amount = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert not OrderLineDiscount.objects.exists() + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + discount = OrderDiscount.objects.get() + + line_1_base_total = line_1.quantity * line_1.base_unit_price_amount + line_2_base_total = line_2.quantity * line_2.base_unit_price_amount + base_total = line_1_base_total + line_2_base_total + line_1_order_discount_portion = reward_amount * line_1_base_total / base_total + line_2_order_discount_portion = reward_amount - line_1_order_discount_portion + + assert discount.order == order + assert discount.amount_value == reward_amount + assert discount.value == reward_amount + assert discount.value_type == DiscountValueType.FIXED + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount - line_1_order_discount_portion, + currency, + ) + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == round_down( + line_1_total_net_amount * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.unit_price_net_amount == line_1_total_net_amount / line_1.quantity + assert line_1.unit_price_gross_amount == quantize_price( + line_1.unit_price_net_amount * tax_rate, currency + ) + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount - line_2_order_discount_portion, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == round_up( + line_2_total_net_amount * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == round_down( + line_2.unit_price_net_amount * tax_rate + ) + + shipping_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert ( + order.total_net_amount + == line_1_total_net_amount + line_2_total_net_amount + shipping_price + ) + assert order.total_gross_amount == order.total_net_amount * tax_rate + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == order.subtotal_net_amount * tax_rate + + +@pytest.mark.parametrize("create_new_discounts", [True, False]) +def test_fetch_order_prices_gift_discount_flat_rates( + order_with_lines_and_gift_promotion, + plugins_manager, + create_new_discounts, +): + # given + if create_new_discounts: + OrderLineDiscount.objects.all().delete() + + order = order_with_lines_and_gift_promotion + rule = PromotionRule.objects.get() + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert len(lines) == 3 + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + gift_line = [line for line in lines if line.is_gift][0] + assert not line_1.discounts.exists() + assert not line_2.discounts.exists() + discount = OrderLineDiscount.objects.get() + + variant_gift = gift_line.variant + variant_gift_listing = variant_gift.channel_listings.get(channel=order.channel) + variant_gift_undiscounted_unit_price = variant_gift_listing.price_amount + + assert discount.line == gift_line + assert discount.amount_value == variant_gift_undiscounted_unit_price + assert discount.value == variant_gift_undiscounted_unit_price + assert discount.value_type == DiscountValueType.FIXED + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + assert gift_line.unit_discount_amount == variant_gift_undiscounted_unit_price + assert gift_line.unit_discount_reason == f"Promotion: {promotion_id}" + assert gift_line.unit_discount_type == DiscountValueType.FIXED + assert gift_line.unit_discount_value == variant_gift_undiscounted_unit_price + assert gift_line.undiscounted_total_price_net_amount == Decimal(0) + assert gift_line.undiscounted_total_price_gross_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_net_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_gross_amount == Decimal(0) + assert gift_line.total_price_net_amount == Decimal(0) + assert gift_line.total_price_gross_amount == Decimal(0) + assert gift_line.base_unit_price_amount == Decimal(0) + assert gift_line.unit_price_net_amount == Decimal(0) + assert gift_line.unit_price_gross_amount == Decimal(0) + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.total_price_net_amount == line_1.undiscounted_total_price_net_amount + assert ( + line_1.total_price_gross_amount == line_1.undiscounted_total_price_gross_amount + ) + assert line_1.base_unit_price_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_net_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_gross_amount == line_1.undiscounted_unit_price_gross_amount + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.total_price_net_amount == line_2.undiscounted_total_price_net_amount + assert ( + line_2.total_price_gross_amount == line_2.undiscounted_total_price_gross_amount + ) + assert line_2.base_unit_price_amount == line_2.undiscounted_unit_price_net_amount + assert line_2.unit_price_net_amount == line_2.undiscounted_unit_price_net_amount + assert line_2.unit_price_gross_amount == line_2.undiscounted_unit_price_gross_amount + + shipping_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == order.undiscounted_total_net_amount + assert order.total_gross_amount == order.undiscounted_total_gross_amount + assert ( + order.subtotal_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + ) + assert order.subtotal_gross_amount == order.subtotal_net_amount * tax_rate + + +def test_fetch_order_prices_catalogue_and_order_discounts_flat_rates( + draft_order_and_promotions, + plugins_manager, +): + # given + order, rule_catalogue, rule_total, _ = draft_order_and_promotions + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + order_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_total.promotion_id + ) + rule_catalogue_reward = rule_catalogue.reward_value + rule_total_reward = rule_total.reward_value + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + catalogue_discount = OrderLineDiscount.objects.get() + order_discount = OrderDiscount.objects.get() + + line_1_base_total = line_1.quantity * line_1.base_unit_price_amount + line_2_base_total = line_2.quantity * line_2.base_unit_price_amount + base_total = line_1_base_total + line_2_base_total + line_1_order_discount_portion = rule_total_reward * line_1_base_total / base_total + line_2_order_discount_portion = rule_total_reward - line_1_order_discount_portion + + assert order_discount.order == order + assert order_discount.amount_value == rule_total_reward + assert order_discount.value == rule_total_reward + assert order_discount.value_type == DiscountValueType.FIXED + assert order_discount.type == DiscountType.ORDER_PROMOTION + assert order_discount.reason == f"Promotion: {order_promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount - line_1_order_discount_portion, + currency, + ) + assert not line_1.discounts.exists() + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == round_up( + line_1_total_net_amount * tax_rate + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == round_up( + line_1.unit_price_net_amount * tax_rate + ) + + assert catalogue_discount.line == line_2 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_2.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {catalogue_promotion_id}" + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount + - line_2_order_discount_portion + - catalogue_discount.amount_value, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert ( + line_2.base_unit_price_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == round_down( + line_2_total_net_amount * tax_rate + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + shipping_price = order.shipping_price_net_amount + total_net_amount = quantize_price( + order.undiscounted_total_net_amount + - order_discount.amount_value + - catalogue_discount.amount_value, + currency, + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_catalogue_and_gift_discounts_flat_rates( + draft_order_and_promotions, + plugins_manager, +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + rule_total.reward_value = Decimal(0) + rule_total.save(update_fields=["reward_value"]) + + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + gift_promotion_id = graphene.Node.to_global_id("Promotion", rule_gift.promotion_id) + rule_catalogue_reward = rule_catalogue.reward_value + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert len(lines) == 3 + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + gift_line = [line for line in lines if line.is_gift][0] + + assert OrderLineDiscount.objects.count() == 2 + gift_discount = gift_line.discounts.get() + catalogue_discount = line_2.discounts.get() + + variant_gift = gift_line.variant + variant_gift_listing = variant_gift.channel_listings.get(channel=order.channel) + variant_gift_undiscounted_unit_price = variant_gift_listing.price_amount + + assert gift_discount.line == gift_line + assert gift_discount.amount_value == variant_gift_undiscounted_unit_price + assert gift_discount.value == variant_gift_undiscounted_unit_price + assert gift_discount.value_type == DiscountValueType.FIXED + assert gift_discount.type == DiscountType.ORDER_PROMOTION + assert gift_discount.reason == f"Promotion: {gift_promotion_id}" + + assert gift_line.unit_discount_amount == variant_gift_undiscounted_unit_price + assert gift_line.unit_discount_reason == f"Promotion: {gift_promotion_id}" + assert gift_line.unit_discount_type == DiscountValueType.FIXED + assert gift_line.unit_discount_value == variant_gift_undiscounted_unit_price + assert gift_line.undiscounted_total_price_net_amount == Decimal(0) + assert gift_line.undiscounted_total_price_gross_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_net_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_gross_amount == Decimal(0) + assert gift_line.total_price_net_amount == Decimal(0) + assert gift_line.total_price_gross_amount == Decimal(0) + assert gift_line.base_unit_price_amount == Decimal(0) + assert gift_line.unit_price_net_amount == Decimal(0) + assert gift_line.unit_price_gross_amount == Decimal(0) + + assert not line_1.discounts.exists() + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = line_1.undiscounted_total_price_net_amount + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.total_price_net_amount == line_1.undiscounted_total_price_net_amount + assert ( + line_1.total_price_gross_amount == line_1.undiscounted_total_price_gross_amount + ) + assert line_1.base_unit_price_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_net_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_gross_amount == line_1.undiscounted_unit_price_gross_amount + + assert catalogue_discount.line == line_2 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_2.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {catalogue_promotion_id}" + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount - catalogue_discount.amount_value, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert ( + line_2.base_unit_price_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert ( + line_2.unit_price_net_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + shipping_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + total_net_amount = quantize_price( + order.undiscounted_total_net_amount - catalogue_discount.amount_value, + currency, + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_catalogue_and_order_discounts_exceed_total_flat_rates( + draft_order_and_promotions, + plugins_manager, +): + # given + order, rule_catalogue, rule_total, _ = draft_order_and_promotions + rule_total.reward_value = Decimal(100000) + rule_total.save(update_fields=["reward_value"]) + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + order_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_total.promotion_id + ) + rule_catalogue_reward = rule_catalogue.reward_value + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + catalogue_discount = OrderLineDiscount.objects.get() + order_discount = OrderDiscount.objects.get() + + shipping_price = order.shipping_price_net_amount + rule_total_reward = quantize_price( + order.undiscounted_total_net_amount + - shipping_price + - rule_catalogue_reward * line_2.quantity, + currency, + ) + assert order_discount.order == order + assert order_discount.amount_value == rule_total_reward + assert order_discount.value == rule_total.reward_value + assert order_discount.value_type == DiscountValueType.FIXED + assert order_discount.type == DiscountType.ORDER_PROMOTION + assert order_discount.reason == f"Promotion: {order_promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + assert not line_1.discounts.exists() + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == Decimal(0) + assert line_1.total_price_gross_amount == Decimal(0) + assert line_1.unit_price_net_amount == Decimal(0) + assert line_1.unit_price_gross_amount == Decimal(0) + + assert catalogue_discount.line == line_2 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_2.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {catalogue_promotion_id}" + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert ( + line_2.base_unit_price_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.total_price_net_amount == Decimal(0) + assert line_2.total_price_gross_amount == Decimal(0) + assert line_2.unit_price_net_amount == Decimal(0) + assert line_2.unit_price_gross_amount == Decimal(0) + + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == shipping_price + assert order.total_gross_amount == shipping_price * tax_rate + assert order.subtotal_net_amount == Decimal(0) + assert order.subtotal_gross_amount == Decimal(0) + + +def test_fetch_order_prices_manual_discount_and_order_discount_flat_rates( + order_with_lines_and_order_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_order_promotion + assert OrderDiscount.objects.exists() + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + discount_value = Decimal("50") + manual_discount = order.discounts.create( + value_type=DiscountValueType.PERCENTAGE, + value=discount_value, + name="Manual order discount", + type=DiscountType.MANUAL, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert not OrderLineDiscount.objects.exists() + assert OrderDiscount.objects.count() == 1 + manual_discount.refresh_from_db() + + assert manual_discount.order == order + assert manual_discount.amount_value == Decimal( + order.undiscounted_total_net_amount / 2 + ) + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.PERCENTAGE + assert manual_discount.type == DiscountType.MANUAL + assert not manual_discount.reason + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == quantize_price( + line_1_total_net_amount * tax_rate, currency + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == quantize_price( + line_1.unit_price_net_amount * tax_rate, currency + ) + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + undiscounted_shipping_price = order.base_shipping_price_amount + total_net_amount = quantize_price( + order.undiscounted_total_net_amount * discount_value / 100, currency + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + undiscounted_shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_manual_discount_and_gift_discount_flat_rates( + order_with_lines_and_gift_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_gift_promotion + assert OrderLineDiscount.objects.exists() + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + discount_value = Decimal("50") + manual_discount = order.discounts.create( + value_type=DiscountValueType.PERCENTAGE, + value=discount_value, + name="Manual order discount", + type=DiscountType.MANUAL, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert not OrderLineDiscount.objects.exists() + assert OrderDiscount.objects.count() == 1 + assert len(lines) == 2 + manual_discount.refresh_from_db() + + assert manual_discount.order == order + assert manual_discount.amount_value == Decimal( + order.undiscounted_total_net_amount / 2 + ) + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.PERCENTAGE + assert manual_discount.type == DiscountType.MANUAL + assert not manual_discount.reason + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + assert not [line for line in lines if line.is_gift] + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == quantize_price( + line_1_total_net_amount * tax_rate, currency + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == quantize_price( + line_1.unit_price_net_amount * tax_rate, currency + ) + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + undiscounted_shipping_price = order.base_shipping_price_amount + total_net_amount = quantize_price( + order.undiscounted_total_net_amount * discount_value / 100, currency + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + undiscounted_shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + assert ( + order.shipping_price_net_amount + == undiscounted_shipping_price * discount_value / 100 + ) + assert order.shipping_price_gross_amount == quantize_price( + order.shipping_price_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_manual_discount_and_catalogue_discount_flat_rates( + order_with_lines_and_catalogue_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_catalogue_promotion + currency = order.currency + rule = PromotionRule.objects.get() + rule_catalogue_reward = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + manual_discount_value = Decimal("50") + manual_discount = order.discounts.create( + value_type=DiscountValueType.PERCENTAGE, + value=manual_discount_value, + name="Manual order discount", + type=DiscountType.MANUAL, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + catalogue_discount = OrderLineDiscount.objects.get() + assert OrderDiscount.objects.count() == 1 + + manual_discount.refresh_from_db() + manual_discount_amount = Decimal( + (order.undiscounted_total_net_amount - catalogue_discount.amount_value) + * manual_discount_value + / 100 + ) + assert manual_discount.order == order + assert manual_discount.amount_value == manual_discount_amount + assert manual_discount.value == manual_discount_value + assert manual_discount.value_type == DiscountValueType.PERCENTAGE + assert manual_discount.type == DiscountType.MANUAL + assert not manual_discount.reason + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + assert catalogue_discount.line == line_1 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_1.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + (variant_1_undiscounted_unit_price - rule_catalogue_reward) + * line_1.quantity + * manual_discount_value + / 100, + currency, + ) + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert ( + line_1.base_unit_price_amount + == variant_1_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == quantize_price( + line_1_total_net_amount * tax_rate, currency + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == round_up( + line_1.unit_price_net_amount * tax_rate + ) + assert line_1.unit_discount_amount == rule_catalogue_reward + assert line_1.unit_discount_reason == f"Promotion: {promotion_id}" + assert line_1.unit_discount_value == rule_catalogue_reward + assert line_1.unit_discount_type == DiscountValueType.FIXED + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount * manual_discount_value / 100, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + undiscounted_shipping_price = order.base_shipping_price_amount + total_net_amount = quantize_price( + (order.undiscounted_total_net_amount - catalogue_discount.amount_value) + * manual_discount_value + / 100, + currency, + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + undiscounted_shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == round_up(total_net_amount * tax_rate) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + assert ( + order.shipping_price_net_amount + == undiscounted_shipping_price * manual_discount_value / 100 + ) + assert order.shipping_price_gross_amount == quantize_price( + order.shipping_price_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_manual_line_discount_and_catalogue_discount_flat_rates( + order_with_lines_and_catalogue_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_catalogue_promotion + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + + line_1 = order.lines.get(quantity=3) + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + + manual_discount_value = Decimal("5") + manual_discount_value_type = DiscountValueType.FIXED + manual_discount_reason = "Manual line discount" + manual_discount = line_1.discounts.create( + value_type=manual_discount_value_type, + value=manual_discount_value, + name="Manual order line discount", + type=DiscountType.MANUAL, + reason=manual_discount_reason, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + manual_discount.refresh_from_db() + + line_1 = [line for line in lines if line.quantity == 3][0] + + assert line_1.base_unit_price_amount == variant_1_listing.price_amount + assert manual_discount.line == line_1 + assert manual_discount.value == manual_discount_value + assert manual_discount.value_type == manual_discount_value_type + assert manual_discount.type == DiscountType.MANUAL + assert manual_discount.reason == manual_discount_reason + + # TODO https://github.com/saleor/saleor/issues/15517 + # line_1_total_net_amount = quantize_price( + # (variant_1_listing.price_amount - manual_discount_value) * line_1.quantity, + # order.currency + # ) + # assert line_1.total_price_net_amount == line_1_total_net_amount + + +def test_fetch_order_prices_catalogue_discount_race_condition( + order_with_lines_and_catalogue_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_catalogue_promotion + OrderLineDiscount.objects.all().delete() + + # when + def call_before_creating_catalogue_line_discount(*args, **kwargs): + calculations.fetch_order_prices_if_expired(order, plugins_manager, None, True) + + with before_after.before( + "saleor.discount.utils.prepare_line_discount_objects_for_catalogue_promotions", + call_before_creating_catalogue_line_discount, + ): + calculations.fetch_order_prices_if_expired(order, plugins_manager, None, True) + + # then + assert OrderLineDiscount.objects.count() == 1 diff --git a/saleor/order/tests/test_order.py b/saleor/order/tests/test_order.py index 89c6ad9f3d6..820b01a7c60 100644 --- a/saleor/order/tests/test_order.py +++ b/saleor/order/tests/test_order.py @@ -658,12 +658,12 @@ def test_get_order_weight_non_existing_product( app=None, manager=anonymous_plugins, ) - old_weight = order.get_total_weight() + old_weight = order.weight product.delete() order.refresh_from_db() - new_weight = order.get_total_weight() + new_weight = order.weight assert old_weight == new_weight @@ -678,7 +678,9 @@ def test_get_voucher_discount_for_order_voucher_validation( quantity = order_with_lines.get_total_quantity() customer_email = order_with_lines.get_customer_email() - validate_voucher_in_order(order_with_lines) + validate_voucher_in_order( + order_with_lines, order_with_lines.lines.all(), order_with_lines.channel + ) mock_validate_voucher.assert_called_once_with( voucher, @@ -699,7 +701,9 @@ def test_validate_voucher_in_order_without_voucher( assert not order_with_lines.voucher - validate_voucher_in_order(order_with_lines) + validate_voucher_in_order( + order_with_lines, order_with_lines.lines.all(), order_with_lines.channel + ) mock_validate_voucher.assert_not_called() @@ -745,6 +749,7 @@ def test_value_voucher_order_discount( billing_address=address_usa, channel=channel_USD, ) + order.lines = Mock(all=Mock(return_value=[])) discount = get_voucher_discount_for_order(order) assert discount == Money(expected_value, "USD") @@ -784,6 +789,7 @@ def test_shipping_voucher_order_discount( voucher=voucher, channel=channel_USD, ) + order.lines = Mock(all=Mock(return_value=[])) discount = get_voucher_discount_for_order(order) assert discount == Money(expected_value, "USD") @@ -840,6 +846,7 @@ def test_shipping_voucher_checkout_discount_not_applicable_returns_zero( voucher=voucher, channel=channel_USD, ) + order.lines = Mock(all=Mock(return_value=[])) with pytest.raises(NotApplicable): get_voucher_discount_for_order(order) diff --git a/saleor/order/utils.py b/saleor/order/utils.py index ed0317155c5..f159dc5b1ec 100644 --- a/saleor/order/utils.py +++ b/saleor/order/utils.py @@ -611,7 +611,8 @@ def get_all_shipping_methods_for_order( if not order.is_shipping_required(): return [] - if not order.shipping_address: + shipping_address = order.shipping_address + if not shipping_address: return [] all_methods = [] @@ -622,7 +623,8 @@ def get_all_shipping_methods_for_order( order, channel_id=order.channel_id, price=order.subtotal.gross, - country_code=order.shipping_address.country.code, + shipping_address=shipping_address, + country_code=shipping_address.country.code, ) .prefetch_related("channel_listings") ) @@ -662,6 +664,10 @@ def is_shipping_required(lines: Iterable["OrderLine"]): return any(line.is_shipping_required for line in lines) +def get_total_quantity(lines: Iterable["OrderLine"]): + return sum([line.quantity for line in lines]) + + def get_valid_collection_points_for_order( lines: Iterable["OrderLine"], channel_id: int, @@ -739,7 +745,7 @@ def get_voucher_discount_for_order(order: Order) -> Money: """ if not order.voucher: return zero_money(order.currency) - validate_voucher_in_order(order) + validate_voucher_in_order(order, order.lines.all(), order.channel) subtotal = order.subtotal if order.voucher.type == VoucherType.ENTIRE_ORDER: return order.voucher.get_discount_amount_for(subtotal.gross, order.channel) @@ -835,6 +841,9 @@ def update_discount_for_order_line( value: Optional[Decimal], ): """Update discount fields for order line. Apply discount to the price.""" + # TODO: Move price calculation to fetch_order_prices_if_expired function. + # Here we should only create order line discount object + # https://github.com/saleor/saleor/issues/15517 current_value = order_line.unit_discount_value current_value_type = order_line.unit_discount_type value = value or current_value @@ -887,6 +896,52 @@ def update_discount_for_order_line( # from db order_line.save(update_fields=fields_to_update) + _update_manual_order_line_discount_object( + value, value_type, reason, order_line, order.currency + ) + + +def _update_manual_order_line_discount_object( + value, value_type, reason, order_line, currency +): + discount_to_update = None + discount_to_delete_ids = [] + discounts = order_line.discounts.all() + for discount in discounts: + if discount.type == DiscountType.MANUAL and not discount_to_update: + discount_to_update = discount + else: + discount_to_delete_ids.append(discount.pk) + + if discount_to_delete_ids: + OrderLineDiscount.objects.filter(id__in=discount_to_delete_ids).delete() + + amount_value = quantize_price( + order_line.unit_discount.amount * order_line.quantity, currency + ) + if not discount_to_update: + order_line.discounts.create( + type=DiscountType.MANUAL, + value_type=value_type, + value=value, + amount_value=amount_value, + currency=currency, + reason=reason, + ) + else: + update_fields = [] + if discount_to_update.value_type != value_type: + discount_to_update.value_type = value_type + update_fields.append("value_type") + if discount_to_update.value != value: + discount_to_update.value = value + discount_to_update.amount_value = amount_value + update_fields.extend(["value", "amount_value"]) + if discount_to_update.reason != reason: + discount_to_update.reason = reason + update_fields.append("reason") + discount_to_update.save(update_fields=update_fields) + def remove_discount_from_order_line(order_line: OrderLine, order: "Order"): """Drop discount applied to order line. Restore undiscounted price.""" @@ -916,6 +971,7 @@ def remove_discount_from_order_line(order_line: OrderLine, order: "Order"): "tax_rate", ] ) + order_line.discounts.all().delete() def update_order_charge_status(order: Order, granted_refund_amount: Decimal): diff --git a/saleor/payment/gateway.py b/saleor/payment/gateway.py index 9a834d9e9f4..32ec48b58f6 100644 --- a/saleor/payment/gateway.py +++ b/saleor/payment/gateway.py @@ -487,7 +487,7 @@ def list_payment_sources( gateway: str, customer_id: str, manager: "PluginsManager", - channel_slug: str, + channel_slug: Optional[str], ) -> list["CustomerSource"]: return manager.list_payment_sources(gateway, customer_id, channel_slug=channel_slug) diff --git a/saleor/payment/gateways/adyen/tests/conftest.py b/saleor/payment/gateways/adyen/tests/conftest.py index c94aa53ab8d..9e1cd43dbe5 100644 --- a/saleor/payment/gateways/adyen/tests/conftest.py +++ b/saleor/payment/gateways/adyen/tests/conftest.py @@ -64,6 +64,7 @@ def fun( ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/payment/gateways/authorize_net/tests/conftest.py b/saleor/payment/gateways/authorize_net/tests/conftest.py index c93faddc8d3..831481cb27c 100644 --- a/saleor/payment/gateways/authorize_net/tests/conftest.py +++ b/saleor/payment/gateways/authorize_net/tests/conftest.py @@ -71,4 +71,5 @@ def authorize_net_plugin(_, settings, channel_USD, authorize_net_gateway_config) ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] diff --git a/saleor/payment/gateways/np_atobarai/tests/conftest.py b/saleor/payment/gateways/np_atobarai/tests/conftest.py index c342200c4db..6e31fe500ed 100644 --- a/saleor/payment/gateways/np_atobarai/tests/conftest.py +++ b/saleor/payment/gateways/np_atobarai/tests/conftest.py @@ -57,6 +57,7 @@ def fun( ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/payment/gateways/stripe/tests/conftest.py b/saleor/payment/gateways/stripe/tests/conftest.py index 12097220b22..8937ff0cc21 100644 --- a/saleor/payment/gateways/stripe/tests/conftest.py +++ b/saleor/payment/gateways/stripe/tests/conftest.py @@ -110,6 +110,7 @@ def fun( ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/payment/utils.py b/saleor/payment/utils.py index b54f90dbca3..621342ea74a 100644 --- a/saleor/payment/utils.py +++ b/saleor/payment/utils.py @@ -24,6 +24,7 @@ from ..checkout.fetch import fetch_checkout_info, fetch_checkout_lines from ..checkout.models import Checkout from ..checkout.payment_utils import update_refundable_for_checkout +from ..core.db.connection import allow_writer from ..core.prices import quantize_price from ..core.tracing import traced_atomic_transaction from ..graphql.core.utils import str_to_enum @@ -388,7 +389,8 @@ def create_payment( "metadata": {} if metadata is None else metadata, } - payment, _ = Payment.objects.get_or_create(defaults=defaults, **data) + with allow_writer(): + payment, _ = Payment.objects.get_or_create(defaults=defaults, **data) return payment @@ -1008,6 +1010,7 @@ def get_failed_type_based_on_event(event: TransactionEvent): return event.type +@allow_writer() def create_failed_transaction_event( event: TransactionEvent, cause: str, @@ -1485,7 +1488,8 @@ def create_manual_adjustment_events( user=user, ) if events_to_create: - return TransactionEvent.objects.bulk_create(events_to_create) + with allow_writer(): + return TransactionEvent.objects.bulk_create(events_to_create) return [] @@ -1512,6 +1516,7 @@ def get_transaction_item_params( } +@allow_writer() def create_transaction_for_order( order: "Order", user: Optional["User"], @@ -1540,6 +1545,7 @@ def create_transaction_for_order( return transaction_item +@allow_writer() def handle_transaction_initialize_session( source_object: Union[Checkout, Order], payment_gateway_data: PaymentGatewayData, diff --git a/saleor/permission/enums.py b/saleor/permission/enums.py index 9f49e27b5de..fac96c66cae 100644 --- a/saleor/permission/enums.py +++ b/saleor/permission/enums.py @@ -160,12 +160,16 @@ def get_permissions( codenames = get_permissions_codename() else: codenames = split_permission_codename(permissions) - return get_permissions_from_codenames(codenames) + return get_permissions_from_codenames(codenames, database_connection_name) -def get_permissions_from_codenames(permission_codenames: list[str]) -> QuerySet: +def get_permissions_from_codenames( + permission_codenames: list[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +) -> QuerySet: return ( - Permission.objects.filter(codename__in=permission_codenames) + Permission.objects.using(database_connection_name) + .filter(codename__in=permission_codenames) .prefetch_related("content_type") .order_by("codename") ) diff --git a/saleor/plugins/admin_email/plugin.py b/saleor/plugins/admin_email/plugin.py index a13fa41194d..9b3aef56081 100644 --- a/saleor/plugins/admin_email/plugin.py +++ b/saleor/plugins/admin_email/plugin.py @@ -1,4 +1,5 @@ import logging +from copy import deepcopy from dataclasses import asdict from typing import Union @@ -143,7 +144,7 @@ class AdminEmailPlugin(BasePlugin): "label": "CSV export failed template", }, } - CONFIG_STRUCTURE.update(DEFAULT_EMAIL_CONFIG_STRUCTURE) + CONFIG_STRUCTURE.update(deepcopy(DEFAULT_EMAIL_CONFIG_STRUCTURE)) CONFIG_STRUCTURE["host"]["help_text"] += ( " Leave it blank if you want to use system environment - EMAIL_HOST." ) diff --git a/saleor/plugins/admin_email/tests/conftest.py b/saleor/plugins/admin_email/tests/conftest.py index dc619b27f39..b373c6d8cdf 100644 --- a/saleor/plugins/admin_email/tests/conftest.py +++ b/saleor/plugins/admin_email/tests/conftest.py @@ -123,6 +123,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.global_plugins[0] return fun diff --git a/saleor/plugins/admin_email/tests/test_plugin.py b/saleor/plugins/admin_email/tests/test_plugin.py index f88910fac09..8d35df6035b 100644 --- a/saleor/plugins/admin_email/tests/test_plugin.py +++ b/saleor/plugins/admin_email/tests/test_plugin.py @@ -8,7 +8,11 @@ from ....core.notify_events import NotifyEventType from ....graphql.tests.utils import get_graphql_content -from ...email_common import DEFAULT_EMAIL_VALUE, get_email_template +from ...email_common import ( + DEFAULT_EMAIL_CONFIG_STRUCTURE, + DEFAULT_EMAIL_VALUE, + get_email_template, +) from ...manager import get_plugins_manager from ...models import PluginConfiguration from ..constants import ( @@ -24,7 +28,7 @@ send_staff_order_confirmation, send_staff_reset_password, ) -from ..plugin import get_admin_event_map +from ..plugin import AdminEmailPlugin, get_admin_event_map def test_event_map(): @@ -281,6 +285,7 @@ def test_plugin_manager_doesnt_load_email_templates_from_db( ): settings.PLUGINS = ["saleor.plugins.admin_email.plugin.AdminEmailPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() plugin = manager.all_plugins[0] email_config_item = None @@ -292,3 +297,30 @@ def test_plugin_manager_doesnt_load_email_templates_from_db( # email template from DB but returns default email value. assert email_config_item assert email_config_item["value"] == DEFAULT_EMAIL_VALUE + + +def test_plugin_dont_change_default_help_text_config_value(): + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["host"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["host"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["port"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["port"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["username"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["username"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["password"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["password"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["use_tls"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["use_tls"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["use_ssl"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["use_ssl"]["help_text"] + ) diff --git a/saleor/plugins/avatax/plugin.py b/saleor/plugins/avatax/plugin.py index 31c067f12bd..c6456613f9b 100644 --- a/saleor/plugins/avatax/plugin.py +++ b/saleor/plugins/avatax/plugin.py @@ -822,31 +822,10 @@ def _get_shipping_tax_rate( ) return base_rate - def assign_tax_code_to_object_meta( + def get_tax_code_from_object_meta( self, - obj: "TaxClass", - tax_code: Optional[str], + obj: Union["Product", "ProductType", "TaxClass"], previous_value: Any, - ): - if not self.active: - return previous_value - - if tax_code is None and obj.pk: - obj.delete_value_from_metadata(META_CODE_KEY) - obj.delete_value_from_metadata(META_DESCRIPTION_KEY) - return previous_value - - codes = get_cached_tax_codes_or_fetch(self.config) - if tax_code not in codes: - return previous_value - - tax_description = codes.get(tax_code) - tax_item = {META_CODE_KEY: tax_code, META_DESCRIPTION_KEY: tax_description} - obj.store_value_in_metadata(items=tax_item) - return previous_value - - def get_tax_code_from_object_meta( - self, obj: Union["Product", "ProductType", "TaxClass"], previous_value: Any ) -> TaxType: if not self.active: return previous_value @@ -867,11 +846,6 @@ def get_tax_code_from_object_meta( description=tax_description, ) - def show_taxes_on_storefront(self, previous_value: bool) -> bool: - if not self.active: - return previous_value - return False - @classmethod def validate_authentication(cls, plugin_configuration: "PluginConfiguration"): conf = { diff --git a/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml new file mode 100644 index 00000000000..a4a6bf4fe7c --- /dev/null +++ b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml @@ -0,0 +1,81 @@ +interactions: +- request: + body: '{"createTransactionModel": {"companyCode": "DEFAULT", "type": "SalesInvoice", + "lines": [{"quantity": 3, "amount": "30.000", "taxCode": "O9999999", "taxIncluded": + true, "itemCode": "SKU_AA", "discounted": false, "description": "Test product"}, + {"quantity": 2, "amount": "40.000", "taxCode": "O9999999", "taxIncluded": true, + "itemCode": "SKU_B", "discounted": false, "description": "Test product 2"}, + {"quantity": 1, "amount": "0.000", "taxCode": "O9999999", "taxIncluded": true, + "itemCode": "SKU_A", "discounted": false, "description": "Test product"}, {"quantity": + 1, "amount": "10.000", "taxCode": "FR000000", "taxIncluded": true, "itemCode": + "Shipping", "discounted": false, "description": null}], "code": "368cf947-7561-44f7-b39a-32a8dcc8482e", + "date": "2024-02-26", "customerCode": 0, "discount": null, "addresses": {"shipFrom": + {"line1": "Teczowa 7", "line2": "", "city": "Wroclaw", "region": "", "country": + "PL", "postalCode": "53-601"}, "shipTo": {"line1": "T\u0119czowa 7", "line2": + "", "city": "WROC\u0141AW", "region": "", "country": "PL", "postalCode": "53-601"}}, + "commit": false, "currencyCode": "USD", "email": "test@example.com"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br + Authorization: + - Basic Og== + Connection: + - keep-alive + Content-Length: + - '1145' + User-Agent: + - Saleor/3.20 + method: POST + uri: https://sandbox-rest.avatax.com/api/v2/transactions/createoradjust + response: + body: + string: '{"id":85050296660452,"code":"368cf947-7561-44f7-b39a-32a8dcc8482e","companyId":7799660,"date":"2024-02-26","status":"Saved","type":"SalesInvoice","batchCode":"","currencyCode":"USD","exchangeRateCurrencyCode":"USD","customerUsageType":"","entityUseCode":"","customerVendorCode":"0","customerCode":"0","exemptNo":"","reconciled":false,"locationCode":"","reportingLocationCode":"","purchaseOrderNo":"","referenceCode":"","salespersonCode":"","taxOverrideType":"None","taxOverrideAmount":0.0,"taxOverrideReason":"","totalAmount":65.04,"totalExempt":0.0,"totalDiscount":0.0,"totalTax":14.96,"totalTaxable":65.04,"totalTaxCalculated":14.96,"adjustmentReason":"NotAdjusted","adjustmentDescription":"","locked":false,"region":"","country":"PL","version":1,"softwareVersion":"24.2.0.0","originAddressId":85050296660454,"destinationAddressId":85050296660453,"exchangeRateEffectiveDate":"2024-02-26","exchangeRate":1.0,"description":"","email":"test@example.com","businessIdentificationNo":"","modifiedDate":"2024-02-26T18:58:42.8269135Z","modifiedUserId":6479978,"taxDate":"2024-02-26","lines":[{"id":85050296660458,"transactionId":85050296660452,"lineNumber":"1","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_AA","lineAmount":24.3900,"quantity":3.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":5.61,"taxableAmount":24.39,"taxCalculated":5.61,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660482,"transactionLineId":85050296660458,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":5.6100,"taxableAmount":24.3900,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":5.6100,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":24.3900,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":24.39,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":5.61,"reportingTaxCalculated":5.61,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660463,"documentLineId":85050296660458,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660464,"documentLineId":85050296660458,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296660459,"transactionId":85050296660452,"lineNumber":"2","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product 2","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_B","lineAmount":32.5200,"quantity":2.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":7.48,"taxableAmount":32.52,"taxCalculated":7.48,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660503,"transactionLineId":85050296660459,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":7.4800,"taxableAmount":32.5200,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":7.4800,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":32.5200,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":32.52,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":7.48,"reportingTaxCalculated":7.48,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660484,"documentLineId":85050296660459,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660485,"documentLineId":85050296660459,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296660460,"transactionId":85050296660452,"lineNumber":"3","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":false,"isSSTP":false,"itemCode":"SKU_A","lineAmount":0.0,"quantity":1.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":0.0,"taxableAmount":0.0,"taxCalculated":0.0,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660524,"transactionLineId":85050296660460,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":0.0000,"taxableAmount":0.0000,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":0.0000,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":0.0000,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":0.0,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":0.0,"reportingTaxCalculated":0.0,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660505,"documentLineId":85050296660460,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660506,"documentLineId":85050296660460,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296660461,"transactionId":85050296660452,"lineNumber":"4","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"Shipping","lineAmount":8.1300,"quantity":1.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":1.87,"taxableAmount":8.13,"taxCalculated":1.87,"taxCode":"FR000000","taxCodeId":8550,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660545,"transactionLineId":85050296660461,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":1.8700,"taxableAmount":8.1300,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":1.8700,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":8.1300,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":8.13,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":1.87,"reportingTaxCalculated":1.87,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660526,"documentLineId":85050296660461,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660527,"documentLineId":85050296660461,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230D","vatNumberTypeId":0}],"addresses":[{"id":85050296660453,"transactionId":85050296660452,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"WROCLAW","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102},{"id":85050296660454,"transactionId":85050296660452,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"Wroclaw","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102}],"locationTypes":[{"documentLocationTypeId":85050296660456,"documentId":85050296660452,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLocationTypeId":85050296660457,"documentId":85050296660452,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"summary":[{"country":"PL","region":"PL","jurisType":"Country","jurisCode":"PL","jurisName":"POLAND","taxAuthorityType":45,"stateAssignedNo":"","taxType":"Output","taxSubType":"O","taxName":"Standard + Rate","rateType":"Standard","taxable":65.04,"rate":0.230000,"tax":14.96,"taxCalculated":14.96,"nonTaxable":0.00,"exemption":0.00}]}' + headers: + Connection: + - keep-alive + Content-Type: + - application/json; charset=utf-8 + Date: + - Mon, 26 Feb 2024 18:58:42 GMT + Location: + - /api/v2/companies/7799660/transactions/85050296660452 + ServerDuration: + - '00:00:00.0815874' + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + api-supported-versions: + - '2.0' + cache-control: + - private, no-cache, no-store + referrer-policy: + - same-origin + strict-transport-security: + - max-age=31536000; includeSubdomains + x-avalara-uid: + - df8331af-ab28-4015-9bee-4d945f48634c + x-correlation-id: + - df8331af-ab28-4015-9bee-4d945f48634c + x-frame-options: + - sameorigin + x-permitted-cross-domain-policies: + - none + x-xss-protection: + - 1; mode=block + status: + code: 201 + message: Created +version: 1 diff --git a/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml new file mode 100644 index 00000000000..0ff21af4c2f --- /dev/null +++ b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml @@ -0,0 +1,77 @@ +interactions: +- request: + body: '{"createTransactionModel": {"companyCode": "DEFAULT", "type": "SalesInvoice", + "lines": [{"quantity": 3, "amount": "30.000", "taxCode": "O9999999", "taxIncluded": + true, "itemCode": "SKU_AA", "discounted": true, "description": "Test product"}, + {"quantity": 2, "amount": "40.000", "taxCode": "O9999999", "taxIncluded": true, + "itemCode": "SKU_B", "discounted": true, "description": "Test product 2"}, {"quantity": + 1, "amount": "10.000", "taxCode": "FR000000", "taxIncluded": true, "itemCode": + "Shipping", "discounted": false, "description": null}], "code": "a01b5217-42b3-48dc-889f-954154d1c3f6", + "date": "2024-02-26", "customerCode": 0, "discount": "25.000", "addresses": + {"shipFrom": {"line1": "Teczowa 7", "line2": "", "city": "Wroclaw", "region": + "", "country": "PL", "postalCode": "53-601"}, "shipTo": {"line1": "T\u0119czowa + 7", "line2": "", "city": "WROC\u0141AW", "region": "", "country": "PL", "postalCode": + "53-601"}}, "commit": false, "currencyCode": "USD", "email": "test@example.com"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br + Authorization: + - Basic Og== + Connection: + - keep-alive + Content-Length: + - '994' + User-Agent: + - Saleor/3.20 + method: POST + uri: https://sandbox-rest.avatax.com/api/v2/transactions/createoradjust + response: + body: + string: '{"id":85050296358456,"code":"a01b5217-42b3-48dc-889f-954154d1c3f6","companyId":7799660,"date":"2024-02-26","status":"Saved","type":"SalesInvoice","batchCode":"","currencyCode":"USD","exchangeRateCurrencyCode":"USD","customerUsageType":"","entityUseCode":"","customerVendorCode":"0","customerCode":"0","exemptNo":"","reconciled":false,"locationCode":"","reportingLocationCode":"","purchaseOrderNo":"","referenceCode":"","salespersonCode":"","taxOverrideType":"None","taxOverrideAmount":0.0,"taxOverrideReason":"","totalAmount":69.71,"totalExempt":0.0,"totalDiscount":25.0,"totalTax":10.29,"totalTaxable":44.71,"totalTaxCalculated":10.29,"adjustmentReason":"NotAdjusted","adjustmentDescription":"","locked":false,"region":"","country":"PL","version":1,"softwareVersion":"24.2.0.0","originAddressId":85050296358458,"destinationAddressId":85050296358457,"exchangeRateEffectiveDate":"2024-02-26","exchangeRate":1.0,"description":"","email":"test@example.com","businessIdentificationNo":"","modifiedDate":"2024-02-26T18:49:50.8851355Z","modifiedUserId":6479978,"taxDate":"2024-02-26","lines":[{"id":85050296358462,"transactionId":85050296358456,"lineNumber":"1","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product","destinationAddressId":85050296358457,"originAddressId":85050296358458,"discountAmount":10.71,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_AA","lineAmount":26.3900,"quantity":3.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":3.61,"taxableAmount":15.68,"taxCalculated":3.61,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296358485,"transactionLineId":85050296358462,"transactionId":85050296358456,"addressId":85050296358457,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":3.6100,"taxableAmount":15.6800,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":3.6100,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":15.6800,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":15.68,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":3.61,"reportingTaxCalculated":3.61,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296358466,"documentLineId":85050296358462,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296358467,"documentLineId":85050296358462,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296358463,"transactionId":85050296358456,"lineNumber":"2","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product 2","destinationAddressId":85050296358457,"originAddressId":85050296358458,"discountAmount":14.29,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_B","lineAmount":35.1900,"quantity":2.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":4.81,"taxableAmount":20.9,"taxCalculated":4.81,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296358506,"transactionLineId":85050296358463,"transactionId":85050296358456,"addressId":85050296358457,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":4.8100,"taxableAmount":20.9000,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":4.8100,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":20.9000,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":20.9,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":4.81,"reportingTaxCalculated":4.81,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296358487,"documentLineId":85050296358463,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296358488,"documentLineId":85050296358463,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296358464,"transactionId":85050296358456,"lineNumber":"3","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"","destinationAddressId":85050296358457,"originAddressId":85050296358458,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"Shipping","lineAmount":8.1300,"quantity":1.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":1.87,"taxableAmount":8.13,"taxCalculated":1.87,"taxCode":"FR000000","taxCodeId":8550,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296358527,"transactionLineId":85050296358464,"transactionId":85050296358456,"addressId":85050296358457,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":1.8700,"taxableAmount":8.1300,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":1.8700,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":8.1300,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":8.13,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":1.87,"reportingTaxCalculated":1.87,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296358508,"documentLineId":85050296358464,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296358509,"documentLineId":85050296358464,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230D","vatNumberTypeId":0}],"addresses":[{"id":85050296358457,"transactionId":85050296358456,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"WROCLAW","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102},{"id":85050296358458,"transactionId":85050296358456,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"Wroclaw","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102}],"locationTypes":[{"documentLocationTypeId":85050296358460,"documentId":85050296358456,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLocationTypeId":85050296358461,"documentId":85050296358456,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"summary":[{"country":"PL","region":"PL","jurisType":"Country","jurisCode":"PL","jurisName":"POLAND","taxAuthorityType":45,"stateAssignedNo":"","taxType":"Output","taxSubType":"O","taxName":"Standard + Rate","rateType":"Standard","taxable":44.71,"rate":0.230000,"tax":10.29,"taxCalculated":10.29,"nonTaxable":0.00,"exemption":0.00}]}' + headers: + Connection: + - keep-alive + Content-Type: + - application/json; charset=utf-8 + Date: + - Mon, 26 Feb 2024 18:49:50 GMT + Location: + - /api/v2/companies/7799660/transactions/85050296358456 + ServerDuration: + - '00:00:00.0712472' + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + api-supported-versions: + - '2.0' + cache-control: + - private, no-cache, no-store + referrer-policy: + - same-origin + strict-transport-security: + - max-age=31536000; includeSubdomains + x-avalara-uid: + - 33c0538c-1f94-4fad-b391-ed49bb1ad281 + x-correlation-id: + - 33c0538c-1f94-4fad-b391-ed49bb1ad281 + x-frame-options: + - sameorigin + x-permitted-cross-domain-policies: + - none + x-xss-protection: + - 1; mode=block + status: + code: 201 + message: Created +version: 1 diff --git a/saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py b/saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py deleted file mode 100644 index 4922a0e8acb..00000000000 --- a/saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py +++ /dev/null @@ -1,259 +0,0 @@ -import graphene -from django.test import override_settings - -from .....graphql.tests.utils import get_graphql_content -from .....tax.models import TaxClass - -PRODUCT_TYPE_CREATE_MUTATION_TAX_CODE = """ - mutation createProductType($name: String, $taxCode: String) { - productTypeCreate(input: {name: $name, taxCode: $taxCode}) { - productType { - name - slug - taxClass { - id - name - metadata { - key - value - } - } - } - errors { - field - message - code - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_type_create_tax_code_creates_new_tax_class( - staff_api_client, - permission_manage_product_types_and_attributes, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = {"name": "New product type", "taxCode": tax_code} - TaxClass.objects.all().delete() - - # when - response = staff_api_client.post_graphql( - PRODUCT_TYPE_CREATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_product_types_and_attributes], - ) - - # then - content = get_graphql_content(response) - tax_class = TaxClass.objects.first() - assert tax_class - assert tax_class.name == tax_code - assert tax_class.metadata - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productTypeCreate"]["productType"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) - - -PRODUCT_TYPE_UPDATE_MUTATION_TAX_CODE = """ - mutation updateProductType($id: ID!, $taxCode: String) { - productTypeUpdate(id: $id, input: {taxCode: $taxCode}) { - productType { - name - slug - taxClass { - id - name - metadata { - key - value - } - } - } - errors { - field - message - code - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_type_update_tax_code_creates_new_tax_class( - staff_api_client, - product_type, - permission_manage_product_types_and_attributes, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = { - "id": graphene.Node.to_global_id("ProductType", product_type.pk), - "taxCode": tax_code, - } - - # when - response = staff_api_client.post_graphql( - PRODUCT_TYPE_UPDATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_product_types_and_attributes], - ) - - # then - content = get_graphql_content(response) - product_type.refresh_from_db() - tax_class = product_type.tax_class - assert tax_class - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productTypeUpdate"]["productType"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) - - -PRODUCT_CREATE_MUTATION_TAX_CODE = """ - mutation ProductCreate($name: String!, $productTypeId: ID!, $taxCode: String) { - productCreate( - input: {name: $name, productType: $productTypeId, taxCode: $taxCode} - ) { - errors { - field - message - } - product { - id - name - taxClass { - id - name - metadata { - key - value - } - } - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_create_tax_code_creates_new_tax_class( - staff_api_client, - product_type, - permission_manage_products, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = { - "name": "New product", - "productTypeId": graphene.Node.to_global_id("ProductType", product_type.pk), - "taxCode": tax_code, - } - TaxClass.objects.all().delete() - - # when - response = staff_api_client.post_graphql( - PRODUCT_CREATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_products], - ) - - # then - content = get_graphql_content(response) - tax_class = TaxClass.objects.first() - assert tax_class - assert tax_class.name == tax_code - assert tax_class.metadata - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productCreate"]["product"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) - - -PRODUCT_UPDATE_MUTATION_TAX_CODE = """ - mutation ProductUpdate($id: ID!, $taxCode: String) { - productUpdate(id: $id, input: {taxCode: $taxCode}) { - errors { - field - message - } - product { - id - name - taxClass { - id - name - metadata { - key - value - } - } - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_update_tax_code_creates_new_tax_class( - staff_api_client, - permission_manage_products, - product, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = { - "id": graphene.Node.to_global_id("Product", product.pk), - "taxCode": tax_code, - } - - # when - response = staff_api_client.post_graphql( - PRODUCT_UPDATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_products], - ) - - # then - content = get_graphql_content(response) - product.refresh_from_db() - tax_class = product.tax_class - assert tax_class - assert tax_class.name == tax_code - assert tax_class.metadata - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productUpdate"]["product"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) diff --git a/saleor/plugins/avatax/tests/test_avatax.py b/saleor/plugins/avatax/tests/test_avatax.py index d81b293880f..eb76f92509a 100644 --- a/saleor/plugins/avatax/tests/test_avatax.py +++ b/saleor/plugins/avatax/tests/test_avatax.py @@ -74,6 +74,11 @@ def order_set_shipping_method(order, shipping_method): ) +def assign_tax_code_to_object_meta(obj: "TaxClass", tax_code: str): + tax_item = {META_CODE_KEY: tax_code} + obj.store_value_in_metadata(items=tax_item) + + @pytest.mark.vcr @pytest.mark.parametrize( ("expected_net", "expected_gross", "prices_entered_with_tax"), @@ -1218,7 +1223,7 @@ def test_calculate_checkout_total_uses_default_calculation( line = checkout_with_item.lines.first() product = line.variant.product product.metadata = {} - manager.assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") + assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") product.save() product.product_type.save() product_with_single_variant.tax_class = tax_class_zero_rates @@ -1286,7 +1291,7 @@ def test_calculate_checkout_total_uses_default_calculation_with_promotion( line = checkout.lines.first() product = line.variant.product product.metadata = {} - manager.assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") + assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") product.save() product.product_type.save() product_with_single_variant.tax_class = tax_class_zero_rates @@ -1358,7 +1363,7 @@ def test_calculate_checkout_total( line = checkout_with_item.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1418,7 +1423,7 @@ def test_calculate_checkout_total_with_order_promotion( line = checkout.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1480,7 +1485,7 @@ def test_calculate_checkout_total_with_gift_promotion( line = checkout.lines.get(is_gift=False) product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1559,7 +1564,7 @@ def test_calculate_checkout_total_with_promotion( line = checkout.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1641,7 +1646,7 @@ def test_calculate_checkout_total_for_JPY( line = checkout.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1715,7 +1720,7 @@ def test_calculate_checkout_total_for_JPY_with_promotion( variant = line.variant product = variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -2058,7 +2063,7 @@ def test_calculate_checkout_total_not_charged_product_and_shipping_with_0_price( variant = line.variant product = variant.product product.metadata = {} - manager.assign_tax_code_to_object_meta(product.product_type.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.product_type.tax_class, "PS081282") product.save() product.product_type.save() @@ -2465,6 +2470,56 @@ def test_calculate_order_total_for_JPY( assert price == TaxedMoney(net=Money("3496", "JPY"), gross=Money("4300", "JPY")) +@pytest.mark.vcr +@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) +def test_calculate_order_total_order_promotion( + order_with_lines_and_order_promotion, + shipping_zone, + site_settings, + address, + plugin_configuration, +): + plugin_configuration() + manager = get_plugins_manager(allow_replica=False) + order = order_with_lines_and_order_promotion + method = shipping_zone.shipping_methods.get() + order.shipping_address = order.billing_address.get_copy() + order_set_shipping_method(order, method) + order.save() + + site_settings.company_address = address + site_settings.save() + + price = manager.calculate_order_total(order, order.lines.all()) + price = quantize_price(price, price.currency) + assert price == TaxedMoney(net=Money("44.71", "USD"), gross=Money("55.00", "USD")) + + +@pytest.mark.vcr +@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) +def test_calculate_order_total_gift_promotion( + order_with_lines_and_gift_promotion, + shipping_zone, + site_settings, + address, + plugin_configuration, +): + plugin_configuration() + manager = get_plugins_manager(allow_replica=False) + order = order_with_lines_and_gift_promotion + method = shipping_zone.shipping_methods.get() + order.shipping_address = order.billing_address.get_copy() + order_set_shipping_method(order, method) + order.save() + + site_settings.company_address = address + site_settings.save() + + price = manager.calculate_order_total(order, order.lines.all()) + price = quantize_price(price, price.currency) + assert price == TaxedMoney(net=Money("65.04", "USD"), gross=Money("80.00", "USD")) + + @pytest.mark.vcr @override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) def test_calculate_order_shipping_entire_order_voucher( @@ -4358,7 +4413,7 @@ def test_get_order_shipping_tax_rate_skip_plugin( def test_get_plugin_configuration(settings, channel_USD): settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] manager = get_plugins_manager(allow_replica=False) - plugin = manager.get_plugin(AvataxPlugin.PLUGIN_ID) + plugin = manager.get_plugin(AvataxPlugin.PLUGIN_ID, channel_slug=channel_USD.slug) configuration_fields = [ configuration_item["name"] for configuration_item in plugin.configuration @@ -4442,12 +4497,6 @@ def test_save_plugin_configuration_cannot_be_enabled_without_config( ) -def test_show_taxes_on_storefront(plugin_configuration): - plugin_configuration() - manager = get_plugins_manager(allow_replica=False) - assert manager.show_taxes_on_storefront() is False - - @patch("saleor.plugins.avatax.plugin.api_post_request_task.delay") @override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) def test_order_confirmed( @@ -4746,11 +4795,13 @@ def test_plugin_uses_configuration_from_db( manager.preprocess_order_creation(checkout_info, lines) -def test_skip_disabled_plugin(settings, plugin_configuration): +def test_skip_disabled_plugin(settings, plugin_configuration, channel_USD): plugin_configuration(username=None, password=None) settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] manager = get_plugins_manager(allow_replica=False) - plugin: AvataxPlugin = manager.get_plugin(AvataxPlugin.PLUGIN_ID) + plugin: AvataxPlugin = manager.get_plugin( + AvataxPlugin.PLUGIN_ID, channel_slug=channel_USD.slug + ) assert ( plugin._skip_plugin( @@ -4760,14 +4811,18 @@ def test_skip_disabled_plugin(settings, plugin_configuration): ) -def test_get_tax_code_from_object_meta(product, settings, plugin_configuration): +def test_get_tax_code_from_object_meta( + product, settings, plugin_configuration, channel_USD +): product.tax_class.store_value_in_metadata( {META_CODE_KEY: "KEY", META_DESCRIPTION_KEY: "DESC"} ) plugin_configuration(username=None, password=None) settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] manager = get_plugins_manager(allow_replica=False) - tax_type = manager.get_tax_code_from_object_meta(product.tax_class) + tax_type = manager.get_tax_code_from_object_meta( + product.tax_class, channel_USD.slug + ) assert isinstance(tax_type, TaxType) assert tax_type.code == "KEY" @@ -5927,108 +5982,6 @@ def test_get_order_lines_data_adds_lines_with_taxes_disabled_for_line( assert len(lines_data) == len(order_with_lines.lines.all()) -def test_assign_tax_code_to_object_meta( - settings, channel_USD, plugin_configuration, product, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = "standard" - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: description}, - ) - - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(product.tax_class, tax_code) - - # then - assert product.tax_class.metadata == { - META_CODE_KEY: tax_code, - META_DESCRIPTION_KEY: description, - } - - -def test_assign_tax_code_to_object_meta_none_as_tax_code( - settings, channel_USD, plugin_configuration, product, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = None - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {"standard": description}, - ) - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(product.tax_class, tax_code) - - # then - assert product.metadata == {} - - -def test_assign_tax_code_to_object_meta_no_obj_id_and_none_as_tax_code( - settings, channel_USD, plugin_configuration, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = None - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {"standard": description}, - ) - - tax_class = TaxClass(name="A new tax class.") - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(tax_class, tax_code) - - # then - assert tax_class.metadata == {} - - -def test_assign_tax_code_to_object_meta_no_obj_id( - settings, channel_USD, plugin_configuration, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = "standard" - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: description}, - ) - tax_class = TaxClass(name="A new product.") - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(tax_class, tax_code) - - # then - assert tax_class.metadata == { - META_CODE_KEY: tax_code, - META_DESCRIPTION_KEY: description, - } - - @patch("saleor.plugins.avatax.plugin.get_checkout_tax_data") def test_calculate_checkout_shipping_validates_checkout( mocked_func, settings, channel_USD, plugin_configuration, checkout_with_item diff --git a/saleor/plugins/base_plugin.py b/saleor/plugins/base_plugin.py index 9fa29eb73cd..0d65d718d32 100644 --- a/saleor/plugins/base_plugin.py +++ b/saleor/plugins/base_plugin.py @@ -244,9 +244,6 @@ def __str__(self): # status is changed. app_status_changed: Callable[["App", None], None] - # Assign tax code dedicated to plugin. - assign_tax_code_to_object_meta: Callable[["TaxClass", Union[str, None], Any], Any] - # Trigger when attribute is created. # # Overwrite this method if you need to trigger specific logic after an attribute is @@ -625,7 +622,11 @@ def __str__(self): # Return tax code from object meta. get_tax_code_from_object_meta: Callable[ - [Union["Product", "ProductType", "TaxClass"], "TaxType"], "TaxType" + [ + Union["Product", "ProductType", "TaxClass"], + "TaxType", + ], + "TaxType", ] # Return list of all tax categories. @@ -1144,12 +1145,6 @@ def __str__(self): # metadata is updated. shipping_zone_metadata_updated: Callable[["ShippingZone", None], None] - # Define if storefront should add info about taxes to the price. - # - # It is used only by the old storefront. The returned value determines if - # storefront should append info to the price about "including/excluding X% VAT". - show_taxes_on_storefront: Callable[[bool], bool] - # Trigger when staff user is created. # # Overwrite this method if you need to trigger specific logic after a staff user is diff --git a/saleor/plugins/manager.py b/saleor/plugins/manager.py index ae05b8c4640..6d550932235 100644 --- a/saleor/plugins/manager.py +++ b/saleor/plugins/manager.py @@ -14,6 +14,7 @@ from ..channel.models import Channel from ..checkout import base_calculations +from ..core.db.connection import allow_writer from ..core.models import EventDelivery from ..core.payments import PaymentInterface from ..core.prices import quantize_price @@ -133,69 +134,82 @@ def _load_plugin( def __init__(self, plugins: list[str], requestor_getter=None, allow_replica=True): with opentracing.global_tracer().start_active_span("PluginsManager.__init__"): + self.plugins = plugins self._allow_replica = allow_replica self.all_plugins = [] self.global_plugins = [] self.plugins_per_channel = defaultdict(list) + self.loaded_all_channels = False + self.loaded_channels: set[str] = set() + self.loaded_global = False + self.requestor_getter = requestor_getter - channel_map = self._get_channel_map() - global_db_configs, channel_db_configs = self._get_db_plugin_configs( - channel_map - ) + def _ensure_channel_plugins_loaded( + self, channel_slug: Optional[str], channel: Optional[Channel] = None + ): + if channel_slug is None and not self.loaded_global: + global_db_config = self._get_db_plugin_configs(None) - for plugin_path in plugins: + for plugin_path in self.plugins: with opentracing.global_tracer().start_active_span(f"{plugin_path}"): PluginClass = import_string(plugin_path) if not getattr(PluginClass, "CONFIGURATION_PER_CHANNEL", False): plugin = self._load_plugin( PluginClass, - global_db_configs, - requestor_getter=requestor_getter, - allow_replica=allow_replica, + global_db_config, + requestor_getter=self.requestor_getter, + allow_replica=self._allow_replica, ) self.global_plugins.append(plugin) self.all_plugins.append(plugin) - else: - for channel in channel_map.values(): - channel_configs = channel_db_configs.get(channel, {}) - plugin = self._load_plugin( - PluginClass, - channel_configs, - channel, - requestor_getter, - allow_replica, - ) - self.plugins_per_channel[channel.slug].append(plugin) - self.all_plugins.append(plugin) - - for channel in channel_map.values(): - self.plugins_per_channel[channel.slug].extend(self.global_plugins) - - def _get_db_plugin_configs(self, channel_map): + self.loaded_global = True + + if channel_slug is not None and channel_slug not in self.loaded_channels: + if channel is None: + channel = ( + Channel.objects.using(self.database) + .filter(slug=channel_slug) + .first() + ) + if not channel: + return + + channel_db_config = self._get_db_plugin_configs(channel) + + for plugin_path in self.plugins: + with opentracing.global_tracer().start_active_span(f"{plugin_path}"): + PluginClass = import_string(plugin_path) + if getattr(PluginClass, "CONFIGURATION_PER_CHANNEL", False): + plugin = self._load_plugin( + PluginClass, + channel_db_config, + channel=channel, + requestor_getter=self.requestor_getter, + allow_replica=self._allow_replica, + ) + self.plugins_per_channel[channel_slug].append(plugin) + self.all_plugins.append(plugin) + + self._ensure_channel_plugins_loaded(None) + self.plugins_per_channel[channel_slug].extend(self.global_plugins) + self.loaded_channels.add(channel_slug) + + def _get_db_plugin_configs(self, channel: Optional[Channel]): with opentracing.global_tracer().start_active_span("_get_db_plugin_configs"): plugin_manager_configs = PluginConfiguration.objects.using( self.database - ).all() - channel_configs: defaultdict[Channel, dict] = defaultdict(dict) - global_configs = {} + ).filter(channel=channel) + configs = {} for db_plugin_config in plugin_manager_configs.iterator(): - channel = channel_map.get(db_plugin_config.channel_id) - if channel is None: - global_configs[db_plugin_config.identifier] = db_plugin_config - else: - db_plugin_config.channel = channel - channel_configs[channel][db_plugin_config.identifier] = ( - db_plugin_config - ) - - return global_configs, channel_configs + configs[db_plugin_config.identifier] = db_plugin_config + return configs def __run_method_on_plugins( self, method_name: str, default_value: Any, *args, - channel_slug: Optional[str] = None, + channel_slug: Optional[str], plugin_ids: Optional[list[str]] = None, **kwargs, ): @@ -248,7 +262,13 @@ def change_user_address( ) -> "Address": default_value = address return self.__run_method_on_plugins( - "change_user_address", default_value, address, address_type, user, save + "change_user_address", + default_value, + address, + address_type, + user, + save, + channel_slug=None, ) def calculate_checkout_total( @@ -614,11 +634,15 @@ def get_order_line_tax_rate( def get_tax_rate_type_choices(self) -> list[TaxType]: default_value: list = [] - return self.__run_method_on_plugins("get_tax_rate_type_choices", default_value) - - def show_taxes_on_storefront(self) -> bool: - default_value = False - return self.__run_method_on_plugins("show_taxes_on_storefront", default_value) + plugins = self.get_all_plugins() + return ( + self.__run_plugin_method_until_first_success( + "get_tax_rate_type_choices", + channel_slug=None, + plugins=plugins, + ) + or default_value + ) def get_taxes_for_checkout( self, checkout_info, lines, app_identifier @@ -655,96 +679,131 @@ def preprocess_order_creation( def customer_created(self, customer: "User"): default_value = None - return self.__run_method_on_plugins("customer_created", default_value, customer) + return self.__run_method_on_plugins( + "customer_created", default_value, customer, channel_slug=None + ) def customer_deleted(self, customer: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "customer_deleted", default_value, customer, webhooks=webhooks + "customer_deleted", + default_value, + customer, + webhooks=webhooks, + channel_slug=None, ) def customer_updated(self, customer: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "customer_updated", default_value, customer, webhooks=webhooks + "customer_updated", + default_value, + customer, + webhooks=webhooks, + channel_slug=None, ) def customer_metadata_updated(self, customer: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "customer_metadata_updated", default_value, customer, webhooks=webhooks + "customer_metadata_updated", + default_value, + customer, + webhooks=webhooks, + channel_slug=None, ) def collection_created(self, collection: "Collection"): default_value = None return self.__run_method_on_plugins( - "collection_created", default_value, collection + "collection_created", default_value, collection, channel_slug=None ) def collection_updated(self, collection: "Collection"): default_value = None return self.__run_method_on_plugins( - "collection_updated", default_value, collection + "collection_updated", default_value, collection, channel_slug=None ) def collection_deleted(self, collection: "Collection", webhooks=None): default_value = None return self.__run_method_on_plugins( - "collection_deleted", default_value, collection, webhooks=webhooks + "collection_deleted", + default_value, + collection, + webhooks=webhooks, + channel_slug=None, ) def collection_metadata_updated(self, collection: "Collection"): default_value = None return self.__run_method_on_plugins( - "collection_metadata_updated", default_value, collection + "collection_metadata_updated", default_value, collection, channel_slug=None ) def product_created(self, product: "Product", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_created", default_value, product, webhooks=webhooks + "product_created", + default_value, + product, + webhooks=webhooks, + channel_slug=None, ) def product_updated(self, product: "Product", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_updated", default_value, product, webhooks=webhooks + "product_updated", + default_value, + product, + webhooks=webhooks, + channel_slug=None, ) def product_deleted(self, product: "Product", variants: list[int], webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_deleted", default_value, product, variants, webhooks=webhooks + "product_deleted", + default_value, + product, + variants, + webhooks=webhooks, + channel_slug=None, ) def product_media_created(self, media: "ProductMedia"): default_value = None return self.__run_method_on_plugins( - "product_media_created", default_value, media + "product_media_created", default_value, media, channel_slug=None ) def product_media_updated(self, media: "ProductMedia"): default_value = None return self.__run_method_on_plugins( - "product_media_updated", default_value, media + "product_media_updated", default_value, media, channel_slug=None ) def product_media_deleted(self, media: "ProductMedia"): default_value = None return self.__run_method_on_plugins( - "product_media_deleted", default_value, media + "product_media_deleted", default_value, media, channel_slug=None ) def product_metadata_updated(self, product: "Product"): default_value = None return self.__run_method_on_plugins( - "product_metadata_updated", default_value, product + "product_metadata_updated", default_value, product, channel_slug=None ) def product_variant_created(self, product_variant: "ProductVariant", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_variant_created", default_value, product_variant, webhooks=webhooks + "product_variant_created", + default_value, + product_variant, + webhooks=webhooks, + channel_slug=None, ) def product_variant_updated( @@ -757,42 +816,62 @@ def product_variant_updated( product_variant, webhooks=webhooks, **kwargs, + channel_slug=None, ) def product_variant_deleted(self, product_variant: "ProductVariant", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_variant_deleted", default_value, product_variant, webhooks=webhooks + "product_variant_deleted", + default_value, + product_variant, + webhooks=webhooks, + channel_slug=None, ) def product_variant_out_of_stock(self, stock: "Stock", webhooks=None): default_value = None self.__run_method_on_plugins( - "product_variant_out_of_stock", default_value, stock, webhooks=webhooks + "product_variant_out_of_stock", + default_value, + stock, + webhooks=webhooks, + channel_slug=None, ) def product_variant_back_in_stock(self, stock: "Stock", webhooks=None): default_value = None self.__run_method_on_plugins( - "product_variant_back_in_stock", default_value, stock, webhooks=webhooks + "product_variant_back_in_stock", + default_value, + stock, + webhooks=webhooks, + channel_slug=None, ) def product_variant_stock_updated(self, stock: "Stock", webhooks=None): default_value = None self.__run_method_on_plugins( - "product_variant_stock_updated", default_value, stock, webhooks=webhooks + "product_variant_stock_updated", + default_value, + stock, + webhooks=webhooks, + channel_slug=None, ) def product_variant_metadata_updated(self, product_variant: "ProductVariant"): default_value = None self.__run_method_on_plugins( - "product_variant_metadata_updated", default_value, product_variant + "product_variant_metadata_updated", + default_value, + product_variant, + channel_slug=None, ) def product_export_completed(self, export: "ExportFile"): default_value = None return self.__run_method_on_plugins( - "product_export_completed", default_value, export + "product_export_completed", default_value, export, channel_slug=None ) def order_created(self, order: "Order"): @@ -804,7 +883,7 @@ def order_created(self, order: "Order"): def event_delivery_retry(self, event_delivery: "EventDelivery"): default_value = None return self.__run_method_on_plugins( - "event_delivery_retry", default_value, event_delivery + "event_delivery_retry", default_value, event_delivery, channel_slug=None ) def order_confirmed(self, order: "Order"): @@ -834,73 +913,100 @@ def draft_order_deleted(self, order: "Order"): def sale_created(self, sale: "Promotion", current_catalogue): default_value = None return self.__run_method_on_plugins( - "sale_created", default_value, sale, current_catalogue + "sale_created", default_value, sale, current_catalogue, channel_slug=None ) def sale_deleted(self, sale: "Promotion", previous_catalogue, webhooks=None): default_value = None return self.__run_method_on_plugins( - "sale_deleted", default_value, sale, previous_catalogue, webhooks=webhooks + "sale_deleted", + default_value, + sale, + previous_catalogue, + webhooks=webhooks, + channel_slug=None, ) def sale_updated(self, sale: "Promotion", previous_catalogue, current_catalogue): default_value = None return self.__run_method_on_plugins( - "sale_updated", default_value, sale, previous_catalogue, current_catalogue + "sale_updated", + default_value, + sale, + previous_catalogue, + current_catalogue, + channel_slug=None, ) def sale_toggle(self, sale: "Promotion", catalogue, webhooks=None): default_value = None return self.__run_method_on_plugins( - "sale_toggle", default_value, sale, catalogue, webhooks=webhooks + "sale_toggle", + default_value, + sale, + catalogue, + webhooks=webhooks, + channel_slug=None, ) def promotion_created(self, promotion: "Promotion"): default_value = None return self.__run_method_on_plugins( - "promotion_created", default_value, promotion + "promotion_created", default_value, promotion, channel_slug=None ) def promotion_updated(self, promotion: "Promotion"): default_value = None return self.__run_method_on_plugins( - "promotion_updated", default_value, promotion + "promotion_updated", default_value, promotion, channel_slug=None ) def promotion_deleted(self, promotion: "Promotion", webhooks=None): default_value = None return self.__run_method_on_plugins( - "promotion_deleted", default_value, promotion, webhooks=webhooks + "promotion_deleted", + default_value, + promotion, + webhooks=webhooks, + channel_slug=None, ) def promotion_started(self, promotion: "Promotion", webhooks=None): default_value = None return self.__run_method_on_plugins( - "promotion_started", default_value, promotion, webhooks=webhooks + "promotion_started", + default_value, + promotion, + webhooks=webhooks, + channel_slug=None, ) def promotion_ended(self, promotion: "Promotion", webhooks=None): default_value = None return self.__run_method_on_plugins( - "promotion_ended", default_value, promotion, webhooks=webhooks + "promotion_ended", + default_value, + promotion, + webhooks=webhooks, + channel_slug=None, ) def promotion_rule_created(self, promotion_rule: "PromotionRule"): default_value = None return self.__run_method_on_plugins( - "promotion_rule_created", default_value, promotion_rule + "promotion_rule_created", default_value, promotion_rule, channel_slug=None ) def promotion_rule_updated(self, promotion_rule: "PromotionRule"): default_value = None return self.__run_method_on_plugins( - "promotion_rule_updated", default_value, promotion_rule + "promotion_rule_updated", default_value, promotion_rule, channel_slug=None ) def promotion_rule_deleted(self, promotion_rule: "PromotionRule"): default_value = None return self.__run_method_on_plugins( - "promotion_rule_deleted", default_value, promotion_rule + "promotion_rule_deleted", default_value, promotion_rule, channel_slug=None ) def invoice_request( @@ -999,12 +1105,17 @@ def order_fulfilled(self, order: "Order"): def order_metadata_updated(self, order: "Order"): default_value = None return self.__run_method_on_plugins( - "order_metadata_updated", default_value, order + "order_metadata_updated", + default_value, + order, + channel_slug=order.channel.slug, ) def order_bulk_created(self, orders: list["Order"]): default_value = None - return self.__run_method_on_plugins("order_bulk_created", default_value, orders) + return self.__run_method_on_plugins( + "order_bulk_created", default_value, orders, channel_slug=None + ) def fulfillment_created( self, fulfillment: "Fulfillment", notify_customer: Optional[bool] = True @@ -1042,7 +1153,10 @@ def fulfillment_approved( def fulfillment_metadata_updated(self, fulfillment: "Fulfillment"): default_value = None return self.__run_method_on_plugins( - "fulfillment_metadata_updated", default_value, fulfillment + "fulfillment_metadata_updated", + default_value, + fulfillment, + channel_slug=fulfillment.order.channel.slug, ) def tracking_number_updated(self, fulfillment: "Fulfillment"): @@ -1084,55 +1198,64 @@ def checkout_fully_paid(self, checkout: "Checkout"): def checkout_metadata_updated(self, checkout: "Checkout"): default_value = None return self.__run_method_on_plugins( - "checkout_metadata_updated", default_value, checkout + "checkout_metadata_updated", + default_value, + checkout, + channel_slug=checkout.channel.slug, ) def page_created(self, page: "Page"): default_value = None - return self.__run_method_on_plugins("page_created", default_value, page) + return self.__run_method_on_plugins( + "page_created", default_value, page, channel_slug=None + ) def page_updated(self, page: "Page"): default_value = None - return self.__run_method_on_plugins("page_updated", default_value, page) + return self.__run_method_on_plugins( + "page_updated", default_value, page, channel_slug=None + ) def page_deleted(self, page: "Page"): default_value = None - return self.__run_method_on_plugins("page_deleted", default_value, page) + return self.__run_method_on_plugins( + "page_deleted", default_value, page, channel_slug=None + ) def page_type_created(self, page_type: "PageType"): default_value = None return self.__run_method_on_plugins( - "page_type_created", default_value, page_type + "page_type_created", default_value, page_type, channel_slug=None ) def page_type_updated(self, page_type: "PageType"): default_value = None return self.__run_method_on_plugins( - "page_type_updated", default_value, page_type + "page_type_updated", default_value, page_type, channel_slug=None ) def page_type_deleted(self, page_type: "PageType"): default_value = None return self.__run_method_on_plugins( - "page_type_deleted", default_value, page_type + "page_type_deleted", default_value, page_type, channel_slug=None ) def permission_group_created(self, group: "Group"): default_value = None return self.__run_method_on_plugins( - "permission_group_created", default_value, group + "permission_group_created", default_value, group, channel_slug=None ) def permission_group_updated(self, group: "Group"): default_value = None return self.__run_method_on_plugins( - "permission_group_updated", default_value, group + "permission_group_updated", default_value, group, channel_slug=None ) def permission_group_deleted(self, group: "Group"): default_value = None return self.__run_method_on_plugins( - "permission_group_deleted", default_value, group + "permission_group_deleted", default_value, group, channel_slug=None ) def transaction_charge_requested( @@ -1211,12 +1334,17 @@ def transaction_process_session( def transaction_item_metadata_updated(self, transaction_item: "TransactionItem"): default_value = None return self.__run_method_on_plugins( - "transaction_item_metadata_updated", default_value, transaction_item + "transaction_item_metadata_updated", + default_value, + transaction_item, + channel_slug=None, ) def account_confirmed(self, user: "User"): default_value = None - return self.__run_method_on_plugins("account_confirmed", default_value, user) + return self.__run_method_on_plugins( + "account_confirmed", default_value, user, channel_slug=None + ) def account_confirmation_requested( self, user: "User", channel_slug: str, token: str, redirect_url: Optional[str] @@ -1229,6 +1357,7 @@ def account_confirmation_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def account_change_email_requested( @@ -1248,6 +1377,7 @@ def account_change_email_requested( token=token, redirect_url=redirect_url, new_email=new_email, + channel_slug=channel_slug, ) def account_email_changed( @@ -1259,6 +1389,7 @@ def account_email_changed( "account_email_changed", default_value, user, + channel_slug=None, ) def account_set_password_requested( @@ -1276,6 +1407,7 @@ def account_set_password_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def account_delete_requested( @@ -1289,132 +1421,195 @@ def account_delete_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def account_deleted(self, user: "User"): default_value = None - return self.__run_method_on_plugins("account_deleted", default_value, user) + return self.__run_method_on_plugins( + "account_deleted", default_value, user, channel_slug=None + ) def address_created(self, address: "Address"): default_value = None - return self.__run_method_on_plugins("address_created", default_value, address) + return self.__run_method_on_plugins( + "address_created", default_value, address, channel_slug=None + ) def address_updated(self, address: "Address"): default_value = None - return self.__run_method_on_plugins("address_updated", default_value, address) + return self.__run_method_on_plugins( + "address_updated", default_value, address, channel_slug=None + ) def address_deleted(self, address: "Address"): default_value = None - return self.__run_method_on_plugins("address_deleted", default_value, address) + return self.__run_method_on_plugins( + "address_deleted", default_value, address, channel_slug=None + ) def app_installed(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_installed", default_value, app) + return self.__run_method_on_plugins( + "app_installed", default_value, app, channel_slug=None + ) def app_updated(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_updated", default_value, app) + return self.__run_method_on_plugins( + "app_updated", default_value, app, channel_slug=None + ) def app_deleted(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_deleted", default_value, app) + return self.__run_method_on_plugins( + "app_deleted", default_value, app, channel_slug=None + ) def app_status_changed(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_status_changed", default_value, app) + return self.__run_method_on_plugins( + "app_status_changed", default_value, app, channel_slug=None + ) def attribute_created(self, attribute: "Attribute"): default_value = None return self.__run_method_on_plugins( - "attribute_created", default_value, attribute + "attribute_created", default_value, attribute, channel_slug=None ) def attribute_updated(self, attribute: "Attribute", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_updated", default_value, attribute, webhooks=webhooks + "attribute_updated", + default_value, + attribute, + webhooks=webhooks, + channel_slug=None, ) def attribute_deleted(self, attribute: "Attribute", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_deleted", default_value, attribute, webhooks=webhooks + "attribute_deleted", + default_value, + attribute, + webhooks=webhooks, + channel_slug=None, ) def attribute_value_created(self, attribute_value: "AttributeValue", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_value_created", default_value, attribute_value, webhooks=webhooks + "attribute_value_created", + default_value, + attribute_value, + webhooks=webhooks, + channel_slug=None, ) def attribute_value_updated(self, attribute_value: "AttributeValue"): default_value = None return self.__run_method_on_plugins( - "attribute_value_updated", default_value, attribute_value + "attribute_value_updated", default_value, attribute_value, channel_slug=None ) def attribute_value_deleted(self, attribute_value: "AttributeValue", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_value_deleted", default_value, attribute_value, webhooks=webhooks + "attribute_value_deleted", + default_value, + attribute_value, + webhooks=webhooks, + channel_slug=None, ) def category_created(self, category: "Category"): default_value = None - return self.__run_method_on_plugins("category_created", default_value, category) + return self.__run_method_on_plugins( + "category_created", default_value, category, channel_slug=None + ) def category_updated(self, category: "Category"): default_value = None - return self.__run_method_on_plugins("category_updated", default_value, category) + return self.__run_method_on_plugins( + "category_updated", default_value, category, channel_slug=None + ) def category_deleted(self, category: "Category", webhooks=None): default_value = None return self.__run_method_on_plugins( - "category_deleted", default_value, category, webhooks=webhooks + "category_deleted", + default_value, + category, + webhooks=webhooks, + channel_slug=None, ) def channel_created(self, channel: "Channel"): default_value = None - return self.__run_method_on_plugins("channel_created", default_value, channel) + return self.__run_method_on_plugins( + "channel_created", default_value, channel, channel_slug=channel.slug + ) def channel_updated(self, channel: "Channel", webhooks=None): default_value = None return self.__run_method_on_plugins( - "channel_updated", default_value, channel, webhooks=webhooks + "channel_updated", + default_value, + channel, + webhooks=webhooks, + channel_slug=channel.slug, ) def channel_deleted(self, channel: "Channel"): default_value = None - return self.__run_method_on_plugins("channel_deleted", default_value, channel) + return self.__run_method_on_plugins( + "channel_deleted", default_value, channel, channel_slug=None + ) def channel_status_changed(self, channel: "Channel"): default_value = None return self.__run_method_on_plugins( - "channel_status_changed", default_value, channel + "channel_status_changed", default_value, channel, channel_slug=channel.slug ) def channel_metadata_updated(self, channel: "Channel"): default_value = None return self.__run_method_on_plugins( - "channel_metadata_updated", default_value, channel + "channel_metadata_updated", + default_value, + channel, + channel_slug=channel.slug, ) def gift_card_created(self, gift_card: "GiftCard", webhooks=None): default_value = None return self.__run_method_on_plugins( - "gift_card_created", default_value, gift_card, webhooks=webhooks + "gift_card_created", + default_value, + gift_card, + webhooks=webhooks, + channel_slug=None, ) def gift_card_updated(self, gift_card: "GiftCard"): default_value = None return self.__run_method_on_plugins( - "gift_card_updated", default_value, gift_card + "gift_card_updated", + default_value, + gift_card, + channel_slug=None, ) def gift_card_deleted(self, gift_card: "GiftCard", webhooks=None): default_value = None return self.__run_method_on_plugins( - "gift_card_deleted", default_value, gift_card, webhooks=webhooks + "gift_card_deleted", + default_value, + gift_card, + webhooks=webhooks, + channel_slug=None, ) def gift_card_sent(self, gift_card: "GiftCard", channel_slug: str, email: str): @@ -1425,112 +1620,184 @@ def gift_card_sent(self, gift_card: "GiftCard", channel_slug: str, email: str): gift_card, channel_slug, email, + channel_slug=channel_slug, ) def gift_card_status_changed(self, gift_card: "GiftCard", webhooks=None): default_value = None return self.__run_method_on_plugins( - "gift_card_status_changed", default_value, gift_card, webhooks=webhooks + "gift_card_status_changed", + default_value, + gift_card, + webhooks=webhooks, + channel_slug=None, ) def gift_card_metadata_updated(self, gift_card: "GiftCard"): default_value = None return self.__run_method_on_plugins( - "gift_card_metadata_updated", default_value, gift_card + "gift_card_metadata_updated", + default_value, + gift_card, + channel_slug=None, ) def gift_card_export_completed(self, export: "ExportFile"): default_value = None return self.__run_method_on_plugins( - "gift_card_export_completed", default_value, export + "gift_card_export_completed", + default_value, + export, + channel_slug=None, ) def menu_created(self, menu: "Menu"): default_value = None - return self.__run_method_on_plugins("menu_created", default_value, menu) + return self.__run_method_on_plugins( + "menu_created", + default_value, + menu, + channel_slug=None, + ) def menu_updated(self, menu: "Menu"): default_value = None - return self.__run_method_on_plugins("menu_updated", default_value, menu) + return self.__run_method_on_plugins( + "menu_updated", + default_value, + menu, + channel_slug=None, + ) def menu_deleted(self, menu: "Menu", webhooks=None): default_value = None return self.__run_method_on_plugins( - "menu_deleted", default_value, menu, webhooks=webhooks + "menu_deleted", + default_value, + menu, + webhooks=webhooks, + channel_slug=None, ) def menu_item_created(self, menu_item: "MenuItem"): default_value = None return self.__run_method_on_plugins( - "menu_item_created", default_value, menu_item + "menu_item_created", + default_value, + menu_item, + channel_slug=None, ) def menu_item_updated(self, menu_item: "MenuItem"): default_value = None return self.__run_method_on_plugins( - "menu_item_updated", default_value, menu_item + "menu_item_updated", + default_value, + menu_item, + channel_slug=None, ) def menu_item_deleted(self, menu_item: "MenuItem", webhooks=None): default_value = None return self.__run_method_on_plugins( - "menu_item_deleted", default_value, menu_item, webhooks=webhooks + "menu_item_deleted", + default_value, + menu_item, + webhooks=webhooks, + channel_slug=None, ) def shipping_price_created(self, shipping_method: "ShippingMethod"): default_value = None return self.__run_method_on_plugins( - "shipping_price_created", default_value, shipping_method + "shipping_price_created", + default_value, + shipping_method, + channel_slug=None, ) def shipping_price_updated(self, shipping_method: "ShippingMethod"): default_value = None return self.__run_method_on_plugins( - "shipping_price_updated", default_value, shipping_method + "shipping_price_updated", + default_value, + shipping_method, + channel_slug=None, ) def shipping_price_deleted(self, shipping_method: "ShippingMethod", webhooks=None): default_value = None return self.__run_method_on_plugins( - "shipping_price_deleted", default_value, shipping_method, webhooks=webhooks + "shipping_price_deleted", + default_value, + shipping_method, + webhooks=webhooks, + channel_slug=None, ) def shipping_zone_created(self, shipping_zone: "ShippingZone"): default_value = None return self.__run_method_on_plugins( - "shipping_zone_created", default_value, shipping_zone + "shipping_zone_created", + default_value, + shipping_zone, + channel_slug=None, ) def shipping_zone_updated(self, shipping_zone: "ShippingZone"): default_value = None return self.__run_method_on_plugins( - "shipping_zone_updated", default_value, shipping_zone + "shipping_zone_updated", + default_value, + shipping_zone, + channel_slug=None, ) def shipping_zone_deleted(self, shipping_zone: "ShippingZone", webhooks=None): default_value = None return self.__run_method_on_plugins( - "shipping_zone_deleted", default_value, shipping_zone, webhooks=webhooks + "shipping_zone_deleted", + default_value, + shipping_zone, + webhooks=webhooks, + channel_slug=None, ) def shipping_zone_metadata_updated(self, shipping_zone: "ShippingZone"): default_value = None return self.__run_method_on_plugins( - "shipping_zone_metadata_updated", default_value, shipping_zone + "shipping_zone_metadata_updated", + default_value, + shipping_zone, + channel_slug=None, ) def staff_created(self, staff_user: "User"): default_value = None - return self.__run_method_on_plugins("staff_created", default_value, staff_user) + return self.__run_method_on_plugins( + "staff_created", + default_value, + staff_user, + channel_slug=None, + ) def staff_updated(self, staff_user: "User"): default_value = None - return self.__run_method_on_plugins("staff_updated", default_value, staff_user) + return self.__run_method_on_plugins( + "staff_updated", + default_value, + staff_user, + channel_slug=None, + ) def staff_deleted(self, staff_user: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "staff_deleted", default_value, staff_user, webhooks=webhooks + "staff_deleted", + default_value, + staff_user, + webhooks=webhooks, + channel_slug=None, ) def staff_set_password_requested( @@ -1544,6 +1811,7 @@ def staff_set_password_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def thumbnail_created( @@ -1552,79 +1820,124 @@ def thumbnail_created( ): default_value = None return self.__run_method_on_plugins( - "thumbnail_created", default_value, thumbnail + "thumbnail_created", + default_value, + thumbnail, + channel_slug=None, ) def warehouse_created(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_created", default_value, warehouse + "warehouse_created", + default_value, + warehouse, + channel_slug=None, ) def warehouse_updated(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_updated", default_value, warehouse + "warehouse_updated", + default_value, + warehouse, + channel_slug=None, ) def warehouse_deleted(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_deleted", default_value, warehouse + "warehouse_deleted", + default_value, + warehouse, + channel_slug=None, ) def warehouse_metadata_updated(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_metadata_updated", default_value, warehouse + "warehouse_metadata_updated", + default_value, + warehouse, + channel_slug=None, ) def voucher_created(self, voucher: "Voucher", code: str): default_value = None return self.__run_method_on_plugins( - "voucher_created", default_value, voucher, code + "voucher_created", + default_value, + voucher, + code, + channel_slug=None, ) def voucher_updated(self, voucher: "Voucher", code: str): default_value = None return self.__run_method_on_plugins( - "voucher_updated", default_value, voucher, code + "voucher_updated", + default_value, + voucher, + code, + channel_slug=None, ) def voucher_deleted(self, voucher: "Voucher", code: str, webhooks=None): default_value = None return self.__run_method_on_plugins( - "voucher_deleted", default_value, voucher, code, webhooks=webhooks + "voucher_deleted", + default_value, + voucher, + code, + webhooks=webhooks, + channel_slug=None, ) def voucher_codes_created(self, voucher_codes: list["VoucherCode"], webhooks=None): default_value = None return self.__run_method_on_plugins( - "voucher_codes_created", default_value, voucher_codes, webhooks=webhooks + "voucher_codes_created", + default_value, + voucher_codes, + webhooks=webhooks, + channel_slug=None, ) def voucher_codes_deleted(self, voucher_codes: list["VoucherCode"], webhooks=None): default_value = None return self.__run_method_on_plugins( - "voucher_codes_deleted", default_value, voucher_codes, webhooks=webhooks + "voucher_codes_deleted", + default_value, + voucher_codes, + webhooks=webhooks, + channel_slug=None, ) def voucher_metadata_updated(self, voucher: "Voucher"): default_value = None return self.__run_method_on_plugins( - "voucher_metadata_updated", default_value, voucher + "voucher_metadata_updated", + default_value, + voucher, + channel_slug=None, ) def voucher_code_export_completed(self, export: "ExportFile"): default_value = None return self.__run_method_on_plugins( - "voucher_code_export_completed", default_value, export + "voucher_code_export_completed", + default_value, + export, + channel_slug=None, ) def shop_metadata_updated(self, shop: "SiteSettings"): default_value = None return self.__run_method_on_plugins( - "shop_metadata_updated", default_value, shop + "shop_metadata_updated", + default_value, + shop, + channel_slug=None, ) def initialize_payment( @@ -1641,6 +1954,7 @@ def initialize_payment( method_name, previous_value=default_value, payment_data=payment_data, + channel_slug=channel_slug, ) def authorize_payment( @@ -1716,7 +2030,7 @@ def list_payment_sources( self, gateway: str, customer_id: str, - channel_slug: str, + channel_slug: Optional[str], ) -> list["CustomerSource"]: default_value: list = [] gtw = self.get_plugin(gateway, channel_slug=channel_slug) @@ -1727,13 +2041,15 @@ def list_payment_sources( raise Exception(f"Payment plugin {gateway} is inaccessible!") def list_stored_payment_methods( - self, list_stored_payment_methods_data: "ListStoredPaymentMethodsRequestData" + self, + list_stored_payment_methods_data: "ListStoredPaymentMethodsRequestData", ) -> list["PaymentMethodData"]: default_value: list = [] return self.__run_method_on_plugins( "list_stored_payment_methods", default_value, list_stored_payment_methods_data, + channel_slug=list_stored_payment_methods_data.channel.slug, ) def stored_payment_method_request_delete( @@ -1748,6 +2064,7 @@ def stored_payment_method_request_delete( "stored_payment_method_request_delete", default_response, request_delete_data, + channel_slug=request_delete_data.channel.slug, ) return response @@ -1765,6 +2082,7 @@ def payment_gateway_initialize_tokenization( "payment_gateway_initialize_tokenization", default_response, request_data, + channel_slug=request_data.channel.slug, ) return response @@ -1782,6 +2100,7 @@ def payment_method_initialize_tokenization( "payment_method_initialize_tokenization", default_response, request_data, + channel_slug=request_data.channel.slug, ) return response @@ -1799,21 +2118,30 @@ def payment_method_process_tokenization( "payment_method_process_tokenization", default_response, request_data, + channel_slug=request_data.channel.slug, ) return response def translation_created(self, translation: "Translation"): default_value = None return self.__run_method_on_plugins( - "translation_created", default_value, translation + "translation_created", default_value, translation, channel_slug=None ) def translation_updated(self, translation: "Translation"): default_value = None return self.__run_method_on_plugins( - "translation_updated", default_value, translation + "translation_updated", default_value, translation, channel_slug=None ) + def get_all_plugins(self, active_only=False): + if not self.loaded_all_channels: + channels = Channel.objects.using(self.database).all() + for channel in channels.iterator(): + self._ensure_channel_plugins_loaded(channel.slug, channel=channel) + self.loaded_all_channels = True + return self.get_plugins(active_only=active_only) + def get_plugins( self, channel_slug: Optional[str] = None, @@ -1821,9 +2149,11 @@ def get_plugins( plugin_ids: Optional[list[str]] = None, ) -> list["BasePlugin"]: """Return list of plugins for a given channel.""" - if channel_slug: + if channel_slug is not None: + self._ensure_channel_plugins_loaded(channel_slug) plugins = self.plugins_per_channel[channel_slug] else: + self._ensure_channel_plugins_loaded(None) plugins = self.all_plugins if active_only: @@ -1843,7 +2173,17 @@ def list_payment_gateways( active_only: bool = True, ) -> list["PaymentGateway"]: channel_slug = checkout_info.channel.slug if checkout_info else channel_slug - plugins = self.get_plugins(channel_slug=channel_slug, active_only=active_only) + + if channel_slug is not None: + plugins = self.get_plugins( + channel_slug=channel_slug, active_only=active_only + ) + else: + # Backwards compatibility for: https://github.com/saleor/saleor/pull/15769/ + # Load all channel plugins and global plugins if channel_slug is None, as + # it was done before the mentioned PR. + plugins = self.get_all_plugins(active_only=active_only) + payment_plugins = [ plugin for plugin in plugins if "process_payment" in type(plugin).__dict__ ] @@ -1938,9 +2278,11 @@ def __run_plugin_method_until_first_success( self, method_name: str, *args, - channel_slug: Optional[str] = None, + channel_slug: Optional[str], + plugins: Optional[list["BasePlugin"]] = None, ): - plugins = self.get_plugins(channel_slug=channel_slug) + if plugins is None: + plugins = self.get_plugins(channel_slug=channel_slug) for plugin in plugins: result = self.__run_method_on_single_plugin( plugin, method_name, None, *args @@ -1971,18 +2313,17 @@ def _get_all_plugin_configs(self): # FIXME these methods should be more generic - def assign_tax_code_to_object_meta(self, obj: "TaxClass", tax_code: Optional[str]): - default_value = None - return self.__run_method_on_plugins( - "assign_tax_code_to_object_meta", default_value, obj, tax_code - ) - def get_tax_code_from_object_meta( - self, obj: Union["Product", "ProductType", "TaxClass"] + self, + obj: Union["Product", "ProductType", "TaxClass"], + channel_slug: Optional[str], ) -> TaxType: default_value = TaxType(code="", description="") return self.__run_method_on_plugins( - "get_tax_code_from_object_meta", default_value, obj + "get_tax_code_from_object_meta", + default_value, + obj, + channel_slug=channel_slug, ) def save_plugin_configuration( @@ -1997,7 +2338,7 @@ def save_plugin_configuration( return None else: channel = None - plugins = self.global_plugins + plugins = self.get_plugins() for plugin in plugins: if plugin.PLUGIN_ID == plugin_id: @@ -2038,6 +2379,10 @@ def webhook_endpoint_without_channel( default_value = HttpResponseNotFound() plugin = self.get_plugin(plugin_id) + if not plugin: + self.get_all_plugins() + plugin = self.get_plugin(plugin_id) + if not plugin: return default_value return self.__run_method_on_single_plugin( @@ -2045,7 +2390,7 @@ def webhook_endpoint_without_channel( ) def webhook( - self, request: SaleorContext, plugin_id: str, channel_slug: Optional[str] = None + self, request: SaleorContext, plugin_id: str, channel_slug: Optional[str] ) -> HttpResponse: split_path = request.path.split(plugin_id, maxsplit=1) path = None @@ -2124,7 +2469,9 @@ def external_refresh( def authenticate_user(self, request: SaleorContext) -> Optional["User"]: """Authenticate user which should be assigned to the request.""" default_value = None - return self.__run_method_on_plugins("authenticate_user", default_value, request) + return self.__run_method_on_plugins( + "authenticate_user", default_value, request, channel_slug=None + ) def external_logout( self, plugin_id: str, data: dict, request: SaleorContext @@ -2164,6 +2511,7 @@ def excluded_shipping_methods_for_order( def excluded_shipping_methods_for_checkout( self, checkout: "Checkout", + channel: "Channel", available_shipping_methods: list["ShippingMethodData"], ) -> list[ExcludedShippingMethod]: return self.__run_method_on_plugins( @@ -2171,7 +2519,7 @@ def excluded_shipping_methods_for_checkout( [], checkout, available_shipping_methods, - channel_slug=checkout.channel.slug, + channel_slug=channel.slug, ) def perform_mutation( @@ -2194,11 +2542,13 @@ def perform_mutation( root=root, info=info, data=data, + channel_slug=None, ) def is_event_active_for_any_plugin( self, event: str, channel_slug: Optional[str] = None ) -> bool: + self._ensure_channel_plugins_loaded(channel_slug) """Check if any plugin supports defined event.""" plugins = ( self.plugins_per_channel[channel_slug] if channel_slug else self.all_plugins @@ -2206,16 +2556,14 @@ def is_event_active_for_any_plugin( only_active_plugins = [plugin for plugin in plugins if plugin.active] return any([plugin.is_event_active(event) for plugin in only_active_plugins]) - def _get_channel_map(self): - return { - channel.pk: channel - for channel in Channel.objects.using(self.database).all().iterator() - } - def get_plugins_manager( allow_replica: bool, requestor_getter: Optional[Callable[[], "Requestor"]] = None, ) -> PluginsManager: with opentracing.global_tracer().start_active_span("get_plugins_manager"): - return PluginsManager(settings.PLUGINS, requestor_getter, allow_replica) + if allow_replica: + return PluginsManager(settings.PLUGINS, requestor_getter, allow_replica) + else: + with allow_writer(): + return PluginsManager(settings.PLUGINS, requestor_getter, allow_replica) diff --git a/saleor/plugins/openid_connect/tests/conftest.py b/saleor/plugins/openid_connect/tests/conftest.py index bf665936d36..a2ca010de2d 100644 --- a/saleor/plugins/openid_connect/tests/conftest.py +++ b/saleor/plugins/openid_connect/tests/conftest.py @@ -86,6 +86,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.all_plugins[0] return fun diff --git a/saleor/plugins/sendgrid/tests/conftest.py b/saleor/plugins/sendgrid/tests/conftest.py index d45a7810cf0..fc8b1b69679 100644 --- a/saleor/plugins/sendgrid/tests/conftest.py +++ b/saleor/plugins/sendgrid/tests/conftest.py @@ -102,6 +102,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/plugins/tests/sample_plugins.py b/saleor/plugins/tests/sample_plugins.py index ba777550bed..52d678edeb9 100644 --- a/saleor/plugins/tests/sample_plugins.py +++ b/saleor/plugins/tests/sample_plugins.py @@ -178,9 +178,6 @@ def calculate_order_line_unit( def get_tax_rate_type_choices(self, previous_value): return [TaxType(code="123", description="abc")] - def show_taxes_on_storefront(self, previous_value: bool) -> bool: - return True - def external_authentication_url( self, data: dict, request: WSGIRequest, previous_value ) -> dict: @@ -480,6 +477,7 @@ class SampleAuthorizationPlugin(BasePlugin): PLUGIN_ID = "saleor.sample.authorization" PLUGIN_NAME = "SampleAuthorization" DEFAULT_ACTIVE = True + CONFIGURATION_PER_CHANNEL = False def authenticate_user(self, request, previous_value) -> Optional[User]: # This function will be mocked in test diff --git a/saleor/plugins/tests/test_manager.py b/saleor/plugins/tests/test_manager.py index 3ab391901a0..2ee6b88e011 100644 --- a/saleor/plugins/tests/test_manager.py +++ b/saleor/plugins/tests/test_manager.py @@ -54,6 +54,7 @@ def test_get_plugins_manager(settings): plugin_path = "saleor.plugins.tests.sample_plugins.PluginSample" settings.PLUGINS = [plugin_path] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() assert isinstance(manager, PluginsManager) assert len(manager.all_plugins) == 1 @@ -66,6 +67,8 @@ def test_manager_with_default_configuration_for_channel_plugins( "saleor.plugins.tests.sample_plugins.PluginSample", ] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() + assert len(manager.global_plugins) == 1 assert isinstance(manager.global_plugins[0], PluginSample) assert {channel_PLN.slug, channel_USD.slug} == set( @@ -92,6 +95,7 @@ def test_manager_with_channel_plugins( "saleor.plugins.tests.sample_plugins.ChannelPluginSample", ] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() assert {channel_PLN.slug, channel_USD.slug} == set( manager.plugins_per_channel.keys() @@ -539,14 +543,6 @@ def sample_none_data(obj): return None -@pytest.mark.parametrize( - ("plugins", "show_taxes"), - [(["saleor.plugins.tests.sample_plugins.PluginSample"], True), ([], False)], -) -def test_manager_show_taxes_on_storefront(plugins, show_taxes): - assert show_taxes == PluginsManager(plugins=plugins).show_taxes_on_storefront() - - @pytest.mark.parametrize( ("plugins", "expected_tax_data"), [ @@ -818,7 +814,7 @@ def test_manager_webhook(rf): plugin_path = "/webhook/paid" request = rf.post(path=f"/plugins/{PluginSample.PLUGIN_ID}{plugin_path}") - response = manager.webhook(request, PluginSample.PLUGIN_ID) + response = manager.webhook(request, PluginSample.PLUGIN_ID, channel_slug=None) assert isinstance(response, JsonResponse) assert response.status_code == 200 assert response.content.decode() == json.dumps({"received": True, "paid": True}) @@ -832,7 +828,7 @@ def test_manager_webhook_plugin_doesnt_have_webhook_support(rf): manager = PluginsManager(plugins=plugins) plugin_path = "/webhook/paid" request = rf.post(path=f"/plugins/{PluginInactive.PLUGIN_ID}{plugin_path}") - response = manager.webhook(request, PluginSample.PLUGIN_ID) + response = manager.webhook(request, PluginSample.PLUGIN_ID, channel_slug=None) assert isinstance(response, HttpResponseNotFound) assert response.status_code == 404 @@ -845,7 +841,7 @@ def test_manager_inncorrect_plugin(rf): manager = PluginsManager(plugins=plugins) plugin_path = "/webhook/paid" request = rf.post(path=f"/plugins/incorrect.plugin.id{plugin_path}") - response = manager.webhook(request, "incorrect.plugin.id") + response = manager.webhook(request, "incorrect.plugin.id", channel_slug=None) assert isinstance(response, HttpResponseNotFound) assert response.status_code == 404 @@ -966,8 +962,7 @@ def test_list_external_authentications_active_only(channel_USD): def test_run_method_on_plugins_default_value(plugins_manager): default_value = "default" value = plugins_manager._PluginsManager__run_method_on_plugins( - method_name="test_method", - default_value=default_value, + method_name="test_method", default_value=default_value, channel_slug=None ) assert value == default_value @@ -980,6 +975,7 @@ def test_run_method_on_plugins_default_value_when_not_existing_method_is_called( value = all_plugins_manager._PluginsManager__run_method_on_plugins( method_name="test_method", default_value=default_value, + channel_slug=channel_USD.slug, ) assert value == default_value @@ -992,6 +988,7 @@ def test_run_method_on_plugins_value_overridden_by_plugin_method( value = all_plugins_manager._PluginsManager__run_method_on_plugins( method_name="get_supported_currencies", default_value="default_value", + channel_slug=channel_USD.slug, ) assert value == expected @@ -1006,6 +1003,7 @@ def test_run_method_on_plugins_only_on_active_ones( all_plugins_manager._PluginsManager__run_method_on_plugins( method_name="test_method_name", default_value="default_value", + channel_slug=channel_USD.slug, ) active_plugins_count = len(ACTIVE_PLUGINS) @@ -1017,7 +1015,11 @@ def test_run_method_on_plugins_only_on_active_ones( assert mocked_method.call_count == active_plugins_count called_plugins_id = [arg.args[0].PLUGIN_ID for arg in mocked_method.call_args_list] - expected_active_plugins_id = [p.PLUGIN_ID for p in ACTIVE_PLUGINS] + expected_active_plugins_id = [ + p.PLUGIN_ID + for p in all_plugins_manager.plugins_per_channel[channel_USD.slug] + if p.active + ] assert called_plugins_id == expected_active_plugins_id @@ -1039,7 +1041,7 @@ def test_run_method_on_plugins_only_for_given_channel( ) pln_plugin = ChannelPluginSample(active=True, channel=channel_PLN, configuration=[]) - plugins_manager.plugins = [usd_plugin_1, usd_plugin_2, pln_plugin] + plugins_manager.all_plugins = [usd_plugin_1, usd_plugin_2, pln_plugin] plugins_manager.plugins_per_channel[channel_USD.slug] = [usd_plugin_1, usd_plugin_2] plugins_manager.plugins_per_channel[channel_PLN.slug] = [pln_plugin] @@ -1125,6 +1127,7 @@ def fake_request_getter(mock): manager = PluginsManager( plugins=plugins, requestor_getter=partial(fake_request_getter, user_mock) ) + manager.get_all_plugins() user_mock.assert_not_called() plugin = manager.all_plugins.pop() @@ -1548,19 +1551,34 @@ def test_plugin_manager_database(allow_replica, expected_connection_name): assert manager.database == expected_connection_name -def test_plugin_manager__get_channel_map( - channel_USD, channel_PLN, channel_JPY, other_channel_USD -): +def test_loaded_all_channels(channel_USD, channel_PLN, django_assert_num_queries): + # given + plugins = [ + "saleor.plugins.tests.sample_plugins.PluginSample", + ] + manager = PluginsManager(plugins=plugins) + + # then + with django_assert_num_queries(4): + plugins = manager.get_all_plugins() + assert plugins + + with django_assert_num_queries(0): + plugins = manager.get_all_plugins() + assert plugins + + +def test_get_plugin_invalid_channel(): # given - manager = PluginsManager(["saleor.plugins.tests.sample_plugins.PluginSample"]) + plugins = [ + "saleor.plugins.tests.sample_plugins.PluginSample", + ] + manager = PluginsManager(plugins=plugins) # when - channel_map = manager._get_channel_map() + plugin = manager.get_plugin( + "saleor.plugins.tests.sample_plugins.PluginSample", channel_slug="invalid" + ) # then - assert channel_map == { - channel_USD.pk: channel_USD, - channel_PLN.pk: channel_PLN, - channel_JPY.pk: channel_JPY, - other_channel_USD.pk: other_channel_USD, - } + assert plugin is None diff --git a/saleor/plugins/user_email/tests/conftest.py b/saleor/plugins/user_email/tests/conftest.py index 1e288faa247..415f9cdded0 100644 --- a/saleor/plugins/user_email/tests/conftest.py +++ b/saleor/plugins/user_email/tests/conftest.py @@ -240,6 +240,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/plugins/user_email/tests/test_plugin.py b/saleor/plugins/user_email/tests/test_plugin.py index c115046a11d..0bb68742d69 100644 --- a/saleor/plugins/user_email/tests/test_plugin.py +++ b/saleor/plugins/user_email/tests/test_plugin.py @@ -302,6 +302,7 @@ def test_plugin_manager_doesnt_load_email_templates_from_db( ): settings.PLUGINS = ["saleor.plugins.user_email.plugin.UserEmailPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() plugin = manager.all_plugins[0] email_config_item = None diff --git a/saleor/plugins/webhook/conftest.py b/saleor/plugins/webhook/conftest.py index 13423ed389b..3ad3b31085d 100644 --- a/saleor/plugins/webhook/conftest.py +++ b/saleor/plugins/webhook/conftest.py @@ -71,6 +71,7 @@ def webhook_plugin(settings): def factory(): settings.PLUGINS = ["saleor.plugins.webhook.plugin.WebhookPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.global_plugins[0] return factory diff --git a/saleor/plugins/webhook/plugin.py b/saleor/plugins/webhook/plugin.py index d3fcff9b704..04b6d9c93ab 100644 --- a/saleor/plugins/webhook/plugin.py +++ b/saleor/plugins/webhook/plugin.py @@ -3119,7 +3119,9 @@ def get_shipping_methods_for_checkout( return methods def get_tax_code_from_object_meta( - self, obj: Union["Product", "ProductType", "TaxClass"], previous_value: Any + self, + obj: Union["Product", "ProductType", "TaxClass"], + previous_value: Any, ): """Get tax code and description for a product or product type. diff --git a/saleor/plugins/webhook/tests/conftest.py b/saleor/plugins/webhook/tests/conftest.py index 91258c140dd..1edbd8341c9 100644 --- a/saleor/plugins/webhook/tests/conftest.py +++ b/saleor/plugins/webhook/tests/conftest.py @@ -18,6 +18,7 @@ def webhook_plugin(settings): def factory() -> WebhookPlugin: settings.PLUGINS = ["saleor.plugins.webhook.plugin.WebhookPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.global_plugins[0] return factory diff --git a/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py b/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py index 0ab0f291ce3..00c6e43bebf 100644 --- a/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py +++ b/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py @@ -942,6 +942,110 @@ def subscription_translation_created_webhook(subscription_webhook): ) +@pytest.fixture +def subscription_product_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PRODUCT, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_product_variant_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PRODUCT_VARIANT, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_collection_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_COLLECTION, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_category_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_CATEGORY, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_attribute_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_ATTRIBUTE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_attribute_value_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_ATTRIBUTE_VALUE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_page_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PAGE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_shipping_method_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_SHIPPING_METHOD, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_promotion_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PROMOTION, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_sale_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_SALE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_promotion_rule_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PROMOTION_RULE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_voucher_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_VOUCHER, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_menu_item_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_MENU_ITEM, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + @pytest.fixture def subscription_translation_updated_webhook(subscription_webhook): return subscription_webhook( @@ -950,6 +1054,102 @@ def subscription_translation_updated_webhook(subscription_webhook): ) +@pytest.fixture +def subscription_product_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PRODUCT, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_product_variant_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PRODUCT_VARIANT, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_collection_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_COLLECTION, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_category_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_CATEGORY, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_attribute_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_ATTRIBUTE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_attribute_value_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_ATTRIBUTE_VALUE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_page_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PAGE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_shipping_method_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_SHIPPING_METHOD, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_promotion_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PROMOTION, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_promotion_rule_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PROMOTION_RULE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_voucher_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_VOUCHER, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_menu_item_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_MENU_ITEM, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + @pytest.fixture def subscription_warehouse_created_webhook(subscription_webhook): return subscription_webhook( diff --git a/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py b/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py index 91c3138406b..2a141a325bd 100644 --- a/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py +++ b/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py @@ -1,4 +1,9 @@ +from enum import Enum + +from graphene.utils.str_converters import to_snake_case + from .....graphql.tests.queries import fragments +from .....graphql.webhook.subscription_types import TRANSLATIONS_TYPES_MAP ACCOUNT_CONFIRMATION_REQUESTED = ( fragments.CUSTOMER_DETAILS @@ -2035,6 +2040,140 @@ } """ +TranslationTypes = Enum( + "TranslationTypes", + { + to_snake_case(k.__name__).upper(): k.__name__ + for k in TRANSLATIONS_TYPES_MAP.keys() + }, +) + + +class TranslationQueryType(Enum): + CREATED = "TranslationCreated" + UPDATED = "TranslationUpdated" + + +def build_translation_query( + type: TranslationTypes, + query_type: TranslationQueryType, + translated_object_id: str, +) -> str: + return ( # noqa: UP031 + """ + subscription { + event { + ... on %s { + translation { + ... on %s { + id + name + translatableContent { + %s + name + } + } + } + } + } + } + """ + ) % (query_type.value, type.value, translated_object_id) + + +TRANSLATION_CREATED_PRODUCT = build_translation_query( + TranslationTypes.PRODUCT_TRANSLATION, + TranslationQueryType.CREATED, + "productId", +) +TRANSLATION_CREATED_PRODUCT_VARIANT = build_translation_query( + TranslationTypes.PRODUCT_VARIANT_TRANSLATION, + TranslationQueryType.CREATED, + "productVariantId", +) +TRANSLATION_CREATED_COLLECTION = build_translation_query( + TranslationTypes.COLLECTION_TRANSLATION, + TranslationQueryType.CREATED, + "collectionId", +) +TRANSLATION_CREATED_CATEGORY = build_translation_query( + TranslationTypes.CATEGORY_TRANSLATION, + TranslationQueryType.CREATED, + "categoryId", +) +TRANSLATION_CREATED_ATTRIBUTE = build_translation_query( + TranslationTypes.ATTRIBUTE_TRANSLATION, + TranslationQueryType.CREATED, + "attributeId", +) +TRANSLATION_CREATED_ATTRIBUTE_VALUE = build_translation_query( + TranslationTypes.ATTRIBUTE_VALUE_TRANSLATION, + TranslationQueryType.CREATED, + "attributeValueId", +) +TRANSLATION_CREATED_SHIPPING_METHOD = build_translation_query( + TranslationTypes.SHIPPING_METHOD_TRANSLATION, + TranslationQueryType.CREATED, + "shippingMethodId", +) +TRANSLATION_CREATED_PROMOTION = build_translation_query( + TranslationTypes.PROMOTION_TRANSLATION, + TranslationQueryType.CREATED, + "promotionId", +) +TRANSLATION_CREATED_PROMOTION_RULE = build_translation_query( + TranslationTypes.PROMOTION_RULE_TRANSLATION, + TranslationQueryType.CREATED, + "promotionRuleId", +) +TRANSLATION_CREATED_VOUCHER = build_translation_query( + TranslationTypes.VOUCHER_TRANSLATION, + TranslationQueryType.CREATED, + "voucherId", +) +TRANSLATION_CREATED_MENU_ITEM = build_translation_query( + TranslationTypes.MENU_ITEM_TRANSLATION, + TranslationQueryType.CREATED, + "menuItemId", +) +TRANSLATION_CREATED_PAGE = """ + subscription { + event { + ... on TranslationCreated { + translation { + ... on PageTranslation { + id + title + translatableContent { + pageId + title + } + } + } + } + } + } +""" +TRANSLATION_CREATED_SALE = """ + subscription { + event { + ... on TranslationCreated { + translation { + ... on SaleTranslation { + __typename + id + name + translatableContent { + saleId + name + } + } + } + } + } + } +""" + TRANSLATION_UPDATED = """ subscription { event { @@ -2087,6 +2226,80 @@ } """ +TRANSLATION_UPDATED_PRODUCT = build_translation_query( + TranslationTypes.PRODUCT_TRANSLATION, + TranslationQueryType.UPDATED, + "productId", +) +TRANSLATION_UPDATED_PRODUCT_VARIANT = build_translation_query( + TranslationTypes.PRODUCT_VARIANT_TRANSLATION, + TranslationQueryType.UPDATED, + "productVariantId", +) +TRANSLATION_UPDATED_COLLECTION = build_translation_query( + TranslationTypes.COLLECTION_TRANSLATION, + TranslationQueryType.UPDATED, + "collectionId", +) +TRANSLATION_UPDATED_CATEGORY = build_translation_query( + TranslationTypes.CATEGORY_TRANSLATION, + TranslationQueryType.UPDATED, + "categoryId", +) +TRANSLATION_UPDATED_ATTRIBUTE = build_translation_query( + TranslationTypes.ATTRIBUTE_TRANSLATION, + TranslationQueryType.UPDATED, + "attributeId", +) +TRANSLATION_UPDATED_ATTRIBUTE_VALUE = build_translation_query( + TranslationTypes.ATTRIBUTE_VALUE_TRANSLATION, + TranslationQueryType.UPDATED, + "attributeValueId", +) +TRANSLATION_UPDATED_SHIPPING_METHOD = build_translation_query( + TranslationTypes.SHIPPING_METHOD_TRANSLATION, + TranslationQueryType.UPDATED, + "shippingMethodId", +) +TRANSLATION_UPDATED_PROMOTION = build_translation_query( + TranslationTypes.PROMOTION_TRANSLATION, + TranslationQueryType.UPDATED, + "promotionId", +) +TRANSLATION_UPDATED_PROMOTION_RULE = build_translation_query( + TranslationTypes.PROMOTION_RULE_TRANSLATION, + TranslationQueryType.UPDATED, + "promotionRuleId", +) +TRANSLATION_UPDATED_VOUCHER = build_translation_query( + TranslationTypes.VOUCHER_TRANSLATION, + TranslationQueryType.UPDATED, + "voucherId", +) +TRANSLATION_UPDATED_MENU_ITEM = build_translation_query( + TranslationTypes.MENU_ITEM_TRANSLATION, + TranslationQueryType.UPDATED, + "menuItemId", +) +TRANSLATION_UPDATED_PAGE = """ + subscription { + event { + ... on TranslationUpdated { + translation { + ... on PageTranslation { + id + title + translatableContent { + pageId + title + } + } + } + } + } + } +""" + TEST_VALID_SUBSCRIPTION = """ subscription{ event{ diff --git a/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py b/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py index 6d265780511..03070f4cd05 100644 --- a/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py +++ b/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py @@ -9,18 +9,31 @@ def test_translation_created_product( - product_translation_fr, subscription_translation_created_webhook + product_translation_fr, subscription_product_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_product_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "ProductTranslation", product_translation_fr.id ) + product = product_translation_fr.product + product_id = graphene.Node.to_global_id("Product", product.id) deliveries = create_deliveries_for_subscriptions( event_type, product_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": product_translation_fr.name, + "translatableContent": { + "productId": product_id, + "name": product.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -28,18 +41,31 @@ def test_translation_created_product( def test_translation_created_product_variant( - variant_translation_fr, subscription_translation_created_webhook + variant_translation_fr, subscription_product_variant_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_product_variant_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "ProductVariantTranslation", variant_translation_fr.id ) + variant = variant_translation_fr.product_variant + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) deliveries = create_deliveries_for_subscriptions( event_type, variant_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": variant_translation_fr.name, + "translatableContent": { + "productVariantId": variant_id, + "name": variant.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -47,18 +73,31 @@ def test_translation_created_product_variant( def test_translation_created_collection( - collection_translation_fr, subscription_translation_created_webhook + collection_translation_fr, subscription_collection_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_collection_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "CollectionTranslation", collection_translation_fr.id ) + collection = collection_translation_fr.collection + collection_id = graphene.Node.to_global_id("Collection", collection.id) deliveries = create_deliveries_for_subscriptions( event_type, collection_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": collection_translation_fr.name, + "translatableContent": { + "collectionId": collection_id, + "name": collection.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -66,18 +105,31 @@ def test_translation_created_collection( def test_translation_created_category( - category_translation_fr, subscription_translation_created_webhook + category_translation_fr, subscription_category_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_category_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "CategoryTranslation", category_translation_fr.id ) + category = category_translation_fr.category + category_id = graphene.Node.to_global_id("Category", category.id) deliveries = create_deliveries_for_subscriptions( event_type, category_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": category_translation_fr.name, + "translatableContent": { + "categoryId": category_id, + "name": category.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -85,18 +137,31 @@ def test_translation_created_category( def test_translation_created_attribute( - translated_attribute, subscription_translation_created_webhook + translated_attribute, subscription_attribute_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_attribute_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "AttributeTranslation", translated_attribute.id ) + attribute = translated_attribute.attribute + attribute_id = graphene.Node.to_global_id("Attribute", attribute.id) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute.name, + "translatableContent": { + "attributeId": attribute_id, + "name": attribute.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -104,18 +169,33 @@ def test_translation_created_attribute( def test_translation_created_attribute_value( - translated_attribute_value, subscription_translation_created_webhook + translated_attribute_value, subscription_attribute_value_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_attribute_value_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "AttributeValueTranslation", translated_attribute_value.id ) + attribute_value = translated_attribute_value.attribute_value + attribute_value_id = graphene.Node.to_global_id( + "AttributeValue", attribute_value.id + ) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute_value, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute_value.name, + "translatableContent": { + "attributeValueId": attribute_value_id, + "name": attribute_value.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -123,18 +203,31 @@ def test_translation_created_attribute_value( def test_translation_created_page( - page_translation_fr, subscription_translation_created_webhook + page_translation_fr, subscription_page_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_page_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "PageTranslation", page_translation_fr.id ) + page = page_translation_fr.page + page_id = graphene.Node.to_global_id("Page", page.id) deliveries = create_deliveries_for_subscriptions( event_type, page_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "title": page_translation_fr.title, + "translatableContent": { + "pageId": page_id, + "title": page.title, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -142,18 +235,34 @@ def test_translation_created_page( def test_translation_created_shipping_method( - shipping_method_translation_fr, subscription_translation_created_webhook + shipping_method_translation_fr, + subscription_shipping_method_translation_created_webhook, ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_shipping_method_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "ShippingMethodTranslation", shipping_method_translation_fr.id ) + shipping_method = shipping_method_translation_fr.shipping_method + shipping_method_id = graphene.Node.to_global_id( + "ShippingMethodType", shipping_method.id + ) deliveries = create_deliveries_for_subscriptions( event_type, shipping_method_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": shipping_method_translation_fr.name, + "translatableContent": { + "shippingMethodId": shipping_method_id, + "name": shipping_method.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -161,13 +270,15 @@ def test_translation_created_shipping_method( def test_translation_created_promotion( - promotion_translation_fr, subscription_translation_created_webhook + promotion_translation_fr, subscription_promotion_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_promotion_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "PromotionTranslation", promotion_translation_fr.id ) + promotion = promotion_translation_fr.promotion + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_translation_fr, webhooks ) @@ -176,7 +287,11 @@ def test_translation_created_promotion( { "translation": { "id": translation_id, - "__typename": "PromotionTranslation", + "name": promotion_translation_fr.name, + "translatableContent": { + "promotionId": promotion_id, + "name": promotion.name, + }, } } ) @@ -188,19 +303,26 @@ def test_translation_created_promotion( def test_translation_created_promotion_converted_from_sale( promotion_converted_from_sale_translation_fr, - subscription_translation_created_webhook, + subscription_sale_translation_created_webhook, ): translation = promotion_converted_from_sale_translation_fr - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_sale_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id("SaleTranslation", translation.id) + promotion = promotion_converted_from_sale_translation_fr.promotion + promotion_id = graphene.Node.to_global_id("Sale", promotion.old_sale_id) deliveries = create_deliveries_for_subscriptions(event_type, translation, webhooks) expected_payload = json.dumps( { "translation": { - "id": translation_id, "__typename": "SaleTranslation", + "id": translation_id, + "name": promotion_converted_from_sale_translation_fr.name, + "translatableContent": { + "saleId": promotion_id, + "name": promotion.name, + }, } } ) @@ -211,18 +333,32 @@ def test_translation_created_promotion_converted_from_sale( def test_translation_created_promotion_rule( - promotion_rule_translation_fr, subscription_translation_created_webhook + promotion_rule_translation_fr, + subscription_promotion_rule_translation_created_webhook, ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_promotion_rule_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "PromotionRuleTranslation", promotion_rule_translation_fr.id ) + promotion_rule = promotion_rule_translation_fr.promotion_rule + promotion_rule_id = graphene.Node.to_global_id("PromotionRule", promotion_rule.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_rule_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": promotion_rule_translation_fr.name, + "translatableContent": { + "promotionRuleId": promotion_rule_id, + "name": promotion_rule.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -230,18 +366,31 @@ def test_translation_created_promotion_rule( def test_translation_created_voucher( - voucher_translation_fr, subscription_translation_created_webhook + voucher_translation_fr, subscription_voucher_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_voucher_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "VoucherTranslation", voucher_translation_fr.id ) + voucher = voucher_translation_fr.voucher + voucher_id = graphene.Node.to_global_id("Voucher", voucher.id) deliveries = create_deliveries_for_subscriptions( event_type, voucher_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": voucher_translation_fr.name, + "translatableContent": { + "voucherId": voucher_id, + "name": voucher.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -249,18 +398,31 @@ def test_translation_created_voucher( def test_translation_created_menu_item( - menu_item_translation_fr, subscription_translation_created_webhook + menu_item_translation_fr, subscription_menu_item_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_menu_item_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "MenuItemTranslation", menu_item_translation_fr.id ) + menu_item = menu_item_translation_fr.menu_item + menu_item_id = graphene.Node.to_global_id("MenuItem", menu_item.id) deliveries = create_deliveries_for_subscriptions( event_type, menu_item_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": menu_item_translation_fr.name, + "translatableContent": { + "menuItemId": menu_item_id, + "name": menu_item.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -268,18 +430,31 @@ def test_translation_created_menu_item( def test_translation_updated_product( - product_translation_fr, subscription_translation_updated_webhook + product_translation_fr, subscription_product_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_product_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "ProductTranslation", product_translation_fr.id ) + product = product_translation_fr.product + product_id = graphene.Node.to_global_id("Product", product.id) deliveries = create_deliveries_for_subscriptions( event_type, product_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": product_translation_fr.name, + "translatableContent": { + "productId": product_id, + "name": product.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -287,18 +462,31 @@ def test_translation_updated_product( def test_translation_updated_product_variant( - variant_translation_fr, subscription_translation_updated_webhook + variant_translation_fr, subscription_product_variant_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_product_variant_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "ProductVariantTranslation", variant_translation_fr.id ) + variant = variant_translation_fr.product_variant + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) deliveries = create_deliveries_for_subscriptions( event_type, variant_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": variant_translation_fr.name, + "translatableContent": { + "productVariantId": variant_id, + "name": variant.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -306,18 +494,31 @@ def test_translation_updated_product_variant( def test_translation_updated_collection( - collection_translation_fr, subscription_translation_updated_webhook + collection_translation_fr, subscription_collection_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_collection_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "CollectionTranslation", collection_translation_fr.id ) + collection = collection_translation_fr.collection + collection_id = graphene.Node.to_global_id("Collection", collection.id) deliveries = create_deliveries_for_subscriptions( event_type, collection_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": collection_translation_fr.name, + "translatableContent": { + "collectionId": collection_id, + "name": collection.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -325,18 +526,31 @@ def test_translation_updated_collection( def test_translation_updated_category( - category_translation_fr, subscription_translation_updated_webhook + category_translation_fr, subscription_category_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_category_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "CategoryTranslation", category_translation_fr.id ) + category = category_translation_fr.category + category_id = graphene.Node.to_global_id("Category", category.id) deliveries = create_deliveries_for_subscriptions( event_type, category_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": category_translation_fr.name, + "translatableContent": { + "categoryId": category_id, + "name": category.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -344,18 +558,31 @@ def test_translation_updated_category( def test_translation_updated_attribute( - translated_attribute, subscription_translation_updated_webhook + translated_attribute, subscription_attribute_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_attribute_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "AttributeTranslation", translated_attribute.id ) + attribute = translated_attribute.attribute + attribute_id = graphene.Node.to_global_id("Attribute", attribute.id) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute.name, + "translatableContent": { + "attributeId": attribute_id, + "name": attribute.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -363,18 +590,33 @@ def test_translation_updated_attribute( def test_translation_updated_attribute_value( - translated_attribute_value, subscription_translation_updated_webhook + translated_attribute_value, subscription_attribute_value_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_attribute_value_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "AttributeValueTranslation", translated_attribute_value.id ) + attribute_value = translated_attribute_value.attribute_value + attribute_value_id = graphene.Node.to_global_id( + "AttributeValue", attribute_value.id + ) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute_value, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute_value.name, + "translatableContent": { + "attributeValueId": attribute_value_id, + "name": attribute_value.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -382,18 +624,31 @@ def test_translation_updated_attribute_value( def test_translation_updated_page( - page_translation_fr, subscription_translation_updated_webhook + page_translation_fr, subscription_page_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_page_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "PageTranslation", page_translation_fr.id ) + page = page_translation_fr.page + page_id = graphene.Node.to_global_id("Page", page.id) deliveries = create_deliveries_for_subscriptions( event_type, page_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "title": page_translation_fr.title, + "translatableContent": { + "pageId": page_id, + "title": page.title, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -401,18 +656,34 @@ def test_translation_updated_page( def test_translation_updated_shipping_method( - shipping_method_translation_fr, subscription_translation_updated_webhook + shipping_method_translation_fr, + subscription_shipping_method_translation_updated_webhook, ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_shipping_method_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "ShippingMethodTranslation", shipping_method_translation_fr.id ) + shipping_method = shipping_method_translation_fr.shipping_method + shipping_method_id = graphene.Node.to_global_id( + "ShippingMethodType", shipping_method.id + ) deliveries = create_deliveries_for_subscriptions( event_type, shipping_method_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": shipping_method_translation_fr.name, + "translatableContent": { + "shippingMethodId": shipping_method_id, + "name": shipping_method.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -420,19 +691,30 @@ def test_translation_updated_shipping_method( def test_translation_updated_promotion( - promotion_translation_fr, subscription_translation_updated_webhook + promotion_translation_fr, subscription_promotion_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_promotion_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "PromotionTranslation", promotion_translation_fr.id ) + promotion = promotion_translation_fr.promotion + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_translation_fr, webhooks ) expected_payload = json.dumps( - {"translation": {"id": translation_id, "__typename": "PromotionTranslation"}} + { + "translation": { + "id": translation_id, + "name": promotion_translation_fr.name, + "translatableContent": { + "promotionId": promotion_id, + "name": promotion.name, + }, + } + } ) assert deliveries[0].payload.payload == expected_payload @@ -441,18 +723,32 @@ def test_translation_updated_promotion( def test_translation_updated_promotion_rule( - promotion_rule_translation_fr, subscription_translation_updated_webhook + promotion_rule_translation_fr, + subscription_promotion_rule_translation_updated_webhook, ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_promotion_rule_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "PromotionRuleTranslation", promotion_rule_translation_fr.id ) + promotion_rule = promotion_rule_translation_fr.promotion_rule + promotion_rule_id = graphene.Node.to_global_id("PromotionRule", promotion_rule.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_rule_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": promotion_rule_translation_fr.name, + "translatableContent": { + "promotionRuleId": promotion_rule_id, + "name": promotion_rule.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -460,18 +756,31 @@ def test_translation_updated_promotion_rule( def test_translation_updated_voucher( - voucher_translation_fr, subscription_translation_updated_webhook + voucher_translation_fr, subscription_voucher_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_voucher_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "VoucherTranslation", voucher_translation_fr.id ) + voucher = voucher_translation_fr.voucher + voucher_id = graphene.Node.to_global_id("Voucher", voucher.id) deliveries = create_deliveries_for_subscriptions( event_type, voucher_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": voucher_translation_fr.name, + "translatableContent": { + "voucherId": voucher_id, + "name": voucher.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -479,18 +788,31 @@ def test_translation_updated_voucher( def test_translation_updated_menu_item( - menu_item_translation_fr, subscription_translation_updated_webhook + menu_item_translation_fr, subscription_menu_item_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_menu_item_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "MenuItemTranslation", menu_item_translation_fr.id ) + menu_item = menu_item_translation_fr.menu_item + menu_item_id = graphene.Node.to_global_id("MenuItem", menu_item.id) deliveries = create_deliveries_for_subscriptions( event_type, menu_item_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": menu_item_translation_fr.name, + "translatableContent": { + "menuItemId": menu_item_id, + "name": menu_item.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) diff --git a/saleor/plugins/webhook/tests/test_shipping_webhook.py b/saleor/plugins/webhook/tests/test_shipping_webhook.py index b3c30d81e15..7e1f5bb3dc2 100644 --- a/saleor/plugins/webhook/tests/test_shipping_webhook.py +++ b/saleor/plugins/webhook/tests/test_shipping_webhook.py @@ -412,9 +412,10 @@ def test_order_available_shipping_methods( ): # given settings.PLUGINS = ["saleor.plugins.webhook.plugin.WebhookPlugin"] + shipping_method = order_with_lines.shipping_method def respond(*args, **kwargs): - return webhook_response(order_with_lines.shipping_method) + return webhook_response(shipping_method) mocked_webhook.side_effect = respond permission_group_manage_orders.user_set.add(staff_api_client.user) diff --git a/saleor/product/migrations/0187_merge_20231221_1030.py b/saleor/product/migrations/0187_merge_20231221_1030.py new file mode 100644 index 00000000000..cd674db847d --- /dev/null +++ b/saleor/product/migrations/0187_merge_20231221_1030.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2023-12-21 10:30 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0186_alter_productmedia_alt"), + ("product", "0186_remove_product_charge_taxes"), + ] + + operations = [] diff --git a/saleor/product/migrations/0188_merge_20231221_1119.py b/saleor/product/migrations/0188_merge_20231221_1119.py new file mode 100644 index 00000000000..43ba261509e --- /dev/null +++ b/saleor/product/migrations/0188_merge_20231221_1119.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2023-12-21 11:19 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0186_alter_productmedia_alt"), + ("product", "0187_auto_20230614_0838"), + ] + + operations = [] diff --git a/saleor/product/migrations/0189_merge_20240405_1121.py b/saleor/product/migrations/0189_merge_20240405_1121.py new file mode 100644 index 00000000000..c29237d1a4f --- /dev/null +++ b/saleor/product/migrations/0189_merge_20240405_1121.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-04-05 11:21 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0187_merge_20231221_1030"), + ("product", "0188_merge_20231221_1119"), + ] + + operations = [] diff --git a/saleor/product/migrations/0191_merge_20240405_1125.py b/saleor/product/migrations/0191_merge_20240405_1125.py new file mode 100644 index 00000000000..f4b0edafa5a --- /dev/null +++ b/saleor/product/migrations/0191_merge_20240405_1125.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-04-05 11:25 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0189_merge_20240405_1121"), + ("product", "0190_merge_20231221_1356"), + ] + + operations = [] diff --git a/saleor/product/migrations/0192_merge_20240405_1154.py b/saleor/product/migrations/0192_merge_20240405_1154.py new file mode 100644 index 00000000000..c6523f14c3a --- /dev/null +++ b/saleor/product/migrations/0192_merge_20240405_1154.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-04-05 11:54 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0191_merge_20240405_1125"), + ("product", "0191_productchannellisting_discounted_price_dirty"), + ] + + operations = [] diff --git a/saleor/product/tasks.py b/saleor/product/tasks.py index 1d0e3daf3d6..5de5ee0a636 100644 --- a/saleor/product/tasks.py +++ b/saleor/product/tasks.py @@ -58,8 +58,8 @@ def _variants_in_batches(variants_qs): def _update_variants_names(instance: ProductType, saved_attributes: Iterable): """Product variant names are created from names of assigned attributes. - After change in attribute value name, for all product variants using this - attributes we need to update the names. + After change in attribute value name, we update the names for all product variants + that lack names and use these attributes. """ initial_attributes = set(instance.variant_attributes.all()) attributes_changed = initial_attributes.intersection(saved_attributes) @@ -67,6 +67,7 @@ def _update_variants_names(instance: ProductType, saved_attributes: Iterable): return variants = ProductVariant.objects.filter( + name="", product__in=instance.products.all(), product__product_type__variant_attributes__in=attributes_changed, ) diff --git a/saleor/product/tests/test_tasks.py b/saleor/product/tests/test_tasks.py index 6238e8a0784..9c1346d9b64 100644 --- a/saleor/product/tests/test_tasks.py +++ b/saleor/product/tests/test_tasks.py @@ -5,6 +5,7 @@ import pytest from django.utils import timezone +from faker import Faker from ...discount import PromotionType, RewardValueType from ...discount.models import Promotion, PromotionRule @@ -230,17 +231,23 @@ def test_recalculate_discounted_price_for_products_task_re_trigger_task( assert recalculate_discounted_price_for_products_task_mock.called -@patch("saleor.product.tasks._update_variants_names") -def test_update_variants_names( - update_variants_names_mock, product_type, size_attribute -): +def test_update_variants_names(product_variant_list, size_attribute): + # given + variant_without_name = product_variant_list[0] + variant_with_name = product_variant_list[1] + random_name = Faker().word() + variant_with_name.name = random_name + variant_with_name.save() + product = variant_without_name.product + # when - update_variants_names(product_type.id, [size_attribute.id]) + update_variants_names(product.product_type_id, [size_attribute.id]) # then - args, _ = update_variants_names_mock.call_args - assert args[0] == product_type - assert {arg.pk for arg in args[1]} == {size_attribute.pk} + variant_without_name.refresh_from_db() + variant_with_name.refresh_from_db() + assert variant_without_name.name == variant_without_name.sku + assert variant_with_name.name == random_name def test_update_variants_names_product_type_does_not_exist(caplog): diff --git a/saleor/product/utils/digital_products.py b/saleor/product/utils/digital_products.py index d6f604a37a7..913ed1ff0c1 100644 --- a/saleor/product/utils/digital_products.py +++ b/saleor/product/utils/digital_products.py @@ -4,6 +4,7 @@ from django.utils.timezone import now from ...account import events as account_events +from ...core.db.connection import allow_writer from ..models import DigitalContentUrl @@ -42,6 +43,7 @@ def digital_content_url_is_valid(content_url: DigitalContentUrl) -> bool: return True +@allow_writer() def increment_download_count(content_url: DigitalContentUrl): content_url.download_num += 1 content_url.save(update_fields=["download_num"]) diff --git a/saleor/product/views.py b/saleor/product/views.py index cc75fcdac6e..75c82ec118c 100644 --- a/saleor/product/views.py +++ b/saleor/product/views.py @@ -2,6 +2,7 @@ import os from typing import Union +from django.conf import settings from django.http import FileResponse, HttpResponseNotFound from django.shortcuts import get_object_or_404 @@ -15,7 +16,9 @@ def digital_product(request, token: str) -> Union[FileResponse, HttpResponseNotFound]: """Return the direct download link to content if given token is still valid.""" - qs = DigitalContentUrl.objects.prefetch_related("line__order__user") + qs = DigitalContentUrl.objects.using( + settings.DATABASE_CONNECTION_REPLICA_NAME + ).prefetch_related("line__order__user") content_url = get_object_or_404(qs, token=token) # type: DigitalContentUrl if not digital_content_url_is_valid(content_url): return HttpResponseNotFound("Url is not valid anymore") diff --git a/saleor/settings.py b/saleor/settings.py index 77ed50553e9..b69ebaecc38 100644 --- a/saleor/settings.py +++ b/saleor/settings.py @@ -29,6 +29,7 @@ from . import PatchedSubscriberExecutionContext, __version__ from .core.languages import LANGUAGES as CORE_LANGUAGES from .core.schedules import initiated_promotion_webhook_schedule +from .graphql.executor import patch_executor django_stubs_ext.monkeypatch() @@ -231,6 +232,12 @@ def get_url_from_env(name, *, schemes=None) -> Optional[str]: "saleor.core.middleware.jwt_refresh_token_middleware", ] +ENABLE_RESTRICT_WRITER_MIDDLEWARE = get_bool_from_env( + "ENABLE_RESTRICT_WRITER_MIDDLEWARE", False +) +if ENABLE_RESTRICT_WRITER_MIDDLEWARE: + MIDDLEWARE = ["saleor.core.db.connection.log_writer_usage_middleware"] + MIDDLEWARE + INSTALLED_APPS = [ # External apps that need to go before django's "storages", @@ -799,7 +806,8 @@ def SENTRY_INIT(dsn: str, sentry_opts: dict): # # If running locally, set: # JAEGER_AGENT_HOST=localhost -if "JAEGER_AGENT_HOST" in os.environ: +JAEGER_HOST = os.environ.get("JAEGER_AGENT_HOST") +if JAEGER_HOST: jaeger_client.Config( config={ "sampler": {"type": "const", "param": 1}, @@ -807,7 +815,7 @@ def SENTRY_INIT(dsn: str, sentry_opts: dict): "reporting_port": os.environ.get( "JAEGER_AGENT_PORT", jaeger_client.config.DEFAULT_REPORTING_PORT ), - "reporting_host": os.environ.get("JAEGER_AGENT_HOST"), + "reporting_host": JAEGER_HOST, }, "logging": get_bool_from_env("JAEGER_LOGGING", False), }, @@ -871,6 +879,8 @@ def SENTRY_INIT(dsn: str, sentry_opts: dict): executor.SubscriberExecutionContext = PatchedSubscriberExecutionContext # type: ignore +patch_executor() + # Optional queue names for Celery tasks. # Set None to route to the default queue, or a string value to use a separate one # diff --git a/saleor/shipping/models.py b/saleor/shipping/models.py index 996aed42753..bebcdc01cc1 100644 --- a/saleor/shipping/models.py +++ b/saleor/shipping/models.py @@ -24,6 +24,7 @@ from .postal_codes import filter_shipping_methods_by_postal_code_rules if TYPE_CHECKING: + from ..account.models import Address from ..checkout.fetch import CheckoutLineInfo from ..checkout.models import Checkout from ..order.fetch import OrderLineInfo @@ -165,32 +166,44 @@ def applicable_shipping_methods_for_instance( instance: Union["Checkout", "Order"], channel_id, price: Money, + shipping_address: Optional["Address"] = None, country_code: Optional[str] = None, lines: Union[ Iterable["CheckoutLineInfo"], Iterable["OrderLineInfo"], None ] = None, ): - if not instance.shipping_address: + if not shipping_address: return None + if not country_code: - # TODO: country_code should come from argument - country_code = instance.shipping_address.country.code + country_code = shipping_address.country.code + if lines is None: # TODO: lines should comes from args in get_valid_shipping_methods_for_order lines = list(instance.lines.prefetch_related("variant__product").all()) # type: ignore[misc] # this is hack # noqa: E501 instance_product_ids = { line.variant.product_id for line in lines if line.variant } + + from ..checkout.models import Checkout + + if isinstance(instance, Checkout): + from ..checkout.utils import calculate_checkout_weight + + weight = calculate_checkout_weight(lines) # type: ignore[arg-type] + else: + weight = instance.weight + applicable_methods = self.applicable_shipping_methods( price=price, channel_id=channel_id, - weight=instance.get_total_weight(lines), - country_code=country_code or instance.shipping_address.country.code, + weight=weight, + country_code=country_code, product_ids=instance_product_ids, ).prefetch_related("postal_code_rules") return filter_shipping_methods_by_postal_code_rules( - applicable_methods, instance.shipping_address + applicable_methods, shipping_address ) diff --git a/saleor/shipping/utils.py b/saleor/shipping/utils.py index 7a07c3beb97..54dba131973 100644 --- a/saleor/shipping/utils.py +++ b/saleor/shipping/utils.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from django_countries import countries @@ -7,6 +7,7 @@ from .interface import ShippingMethodData if TYPE_CHECKING: + from ..tax.models import TaxClass from .models import ShippingMethod, ShippingMethodChannelListing @@ -30,12 +31,18 @@ def get_countries_without_shipping_zone(): def convert_to_shipping_method_data( - shipping_method: "ShippingMethod", listing: "ShippingMethodChannelListing" + shipping_method: "ShippingMethod", + listing: "ShippingMethodChannelListing", + tax_class: Optional["TaxClass"] = None, ) -> "ShippingMethodData": price = listing.price minimum_order_price = listing.minimum_order_price maximum_order_price = listing.maximum_order_price + if not tax_class: + # Tax class should be passed as argument, this is a fallback. + tax_class = shipping_method.tax_class + return ShippingMethodData( id=str(shipping_method.id), name=shipping_method.name, @@ -48,7 +55,7 @@ def convert_to_shipping_method_data( metadata=shipping_method.metadata, private_metadata=shipping_method.private_metadata, price=price, - tax_class=shipping_method.tax_class, + tax_class=tax_class, minimum_order_price=minimum_order_price, maximum_order_price=maximum_order_price, ) diff --git a/saleor/tax/calculations/order.py b/saleor/tax/calculations/order.py index dfa1a0b0229..d05c7bf1745 100644 --- a/saleor/tax/calculations/order.py +++ b/saleor/tax/calculations/order.py @@ -28,9 +28,11 @@ def update_order_prices_with_flat_rates( database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): country_code = get_order_country(order) - default_country_rate_obj = TaxClassCountryRate.objects.filter( - country=country_code, tax_class=None - ).first() + default_country_rate_obj = ( + TaxClassCountryRate.objects.using(database_connection_name) + .filter(country=country_code, tax_class=None) + .first() + ) default_tax_rate = ( default_country_rate_obj.rate if default_country_rate_obj else Decimal(0) ) @@ -66,18 +68,20 @@ def update_order_prices_with_flat_rates( ) order.shipping_tax_rate = normalize_tax_rate_for_db(shipping_tax_rate) - # Calculate order total. - order.undiscounted_total = undiscounted_subtotal + order.base_shipping_price - order.total = _calculate_order_total( - order, lines, database_connection_name=database_connection_name + _set_order_totals( + order, + lines, + prices_entered_with_tax, + database_connection_name=database_connection_name, ) -def _calculate_order_total( +def _set_order_totals( order: "Order", lines: Iterable["OrderLine"], + prices_entered_with_tax: bool, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, -) -> TaxedMoney: +): currency = order.currency default_value = base_calculations.base_order_total( @@ -85,16 +89,28 @@ def _calculate_order_total( ) default_value = TaxedMoney(default_value, default_value) if default_value <= zero_taxed_money(currency): - return quantize_price(default_value, currency) + order.total = quantize_price(default_value, currency) + order.undiscounted_total = quantize_price(default_value, currency) + order.subtotal = quantize_price(default_value, currency) + return - total = zero_taxed_money(currency) + subtotal = zero_taxed_money(currency) undiscounted_subtotal = zero_taxed_money(currency) for line in lines: - total += line.total_price + subtotal += line.total_price undiscounted_subtotal += line.undiscounted_total_price - total += order.shipping_price - return quantize_price(max(total, zero_taxed_money(currency)), currency) + shipping_tax_rate = order.shipping_tax_rate or 0 + undiscounted_shipping_price = calculate_flat_rate_tax( + order.base_shipping_price, + Decimal(shipping_tax_rate * 100), + prices_entered_with_tax, + ) + undiscounted_total = undiscounted_subtotal + undiscounted_shipping_price + + order.total = quantize_price(subtotal + order.shipping_price, currency) + order.undiscounted_total = quantize_price(undiscounted_total, currency) + order.subtotal = quantize_price(subtotal, currency) def _calculate_order_shipping( diff --git a/saleor/tax/tests/test_checkout_calculations.py b/saleor/tax/tests/test_checkout_calculations.py index 9b9d07239a9..c026cda0be6 100644 --- a/saleor/tax/tests/test_checkout_calculations.py +++ b/saleor/tax/tests/test_checkout_calculations.py @@ -7,7 +7,7 @@ from ...checkout.utils import add_variant_to_checkout from ...core.prices import quantize_price from ...core.taxes import zero_taxed_money -from ...discount.utils import create_discount_objects_for_order_promotions +from ...discount.utils import create_checkout_discount_objects_for_order_promotions from ...plugins.manager import get_plugins_manager from ...tax.models import TaxClassCountryRate from .. import TaxCalculationStrategy @@ -880,7 +880,7 @@ def test_calculate_checkout_line_total_discount_from_order_promotion( lines, _ = fetch_checkout_lines(checkout) checkout_info = fetch_checkout_info(checkout, lines, manager) checkout_line_info = lines[0] - create_discount_objects_for_order_promotions(checkout_info, lines) + create_checkout_discount_objects_for_order_promotions(checkout_info, lines) # when line_price = calculate_checkout_line_total( @@ -928,7 +928,7 @@ def test_calculate_checkout_line_total_discount_for_gift_line( lines, _ = fetch_checkout_lines(checkout) checkout_info = fetch_checkout_info(checkout, lines, manager) checkout_line_info = [line_info for line_info in lines if line_info.line.is_gift][0] - create_discount_objects_for_order_promotions(checkout_info, lines) + create_checkout_discount_objects_for_order_promotions(checkout_info, lines) # when line_price = calculate_checkout_line_total( diff --git a/saleor/tax/tests/test_order_calculations.py b/saleor/tax/tests/test_order_calculations.py index e8daf7c6123..1074037fcd2 100644 --- a/saleor/tax/tests/test_order_calculations.py +++ b/saleor/tax/tests/test_order_calculations.py @@ -77,7 +77,7 @@ def test_calculations_calculate_order_undiscounted_total( # then assert order.undiscounted_total == TaxedMoney( - net=Money("80.00", "USD"), gross=Money("80.00", "USD") + net=Money("65.04", "USD"), gross=Money("80.00", "USD") ) diff --git a/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py b/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py index 13edade7401..2bf3b3e34cf 100644 --- a/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py +++ b/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py @@ -124,8 +124,8 @@ def test_order_products_on_percentage_promotion_CORE_2103( product_price = order_line["undiscountedUnitPrice"]["gross"]["amount"] assert product_price == float(product_variant_price) assert discount == order_line["unitDiscount"]["amount"] - assert order_line["unitDiscountType"] == "FIXED" - assert order_line["unitDiscountValue"] == discount + assert order_line["unitDiscountType"] == "PERCENTAGE" + assert order_line["unitDiscountValue"] == discount_value assert order_line["unitDiscountReason"] == promotion_reason product_discounted_price = product_price - discount assert product_discounted_price == order_line["unitPrice"]["gross"]["amount"] diff --git a/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py b/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py index e6b719f323c..7fd8086ee15 100644 --- a/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py +++ b/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py @@ -148,8 +148,8 @@ def test_order_products_on_percentage_sale_CORE_1003( order_line = order["order"]["lines"][0] assert order_line["unitDiscount"]["amount"] == discount - assert order_line["unitDiscountValue"] == discount - assert order_line["unitDiscountType"] == "FIXED" + assert order_line["unitDiscountValue"] == sale_discount_value + assert order_line["unitDiscountType"] == "PERCENTAGE" assert draft_line["unitDiscountReason"] == f"Sale: {sale_id}" product_price = order_line["undiscountedUnitPrice"]["gross"]["amount"] assert product_price == undiscounted_price diff --git a/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py b/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py index 8965e3f9a07..4877ca7dd4d 100644 --- a/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py +++ b/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py @@ -161,8 +161,8 @@ def test_order_products_on_promotion_and_manual_order_discount_CORE_2108( ) assert product_price == product_variant_price assert order_line["unitDiscount"]["amount"] == promotion_value - assert order_line["unitDiscountType"] == "FIXED" - assert order_line["unitDiscountValue"] == promotion_value + assert order_line["unitDiscountType"] == "PERCENTAGE" + assert order_line["unitDiscountValue"] == promotion_discount_value assert order_line["unitDiscountReason"] == promotion_reason product_discounted_price = product_price - promotion_value shipping_amount = quantize_price( diff --git a/saleor/tests/fixtures.py b/saleor/tests/fixtures.py index 1c24734f9bc..f4f558fc7dd 100644 --- a/saleor/tests/fixtures.py +++ b/saleor/tests/fixtures.py @@ -488,8 +488,6 @@ def checkout_with_item_on_promotion(checkout_with_item): variant = line.variant - channel = checkout_with_item.channel - reward_value = Decimal("5") rule = promotion.rules.create( catalogue_predicate={ @@ -5071,7 +5069,122 @@ def order_with_lines_for_cc( @pytest.fixture -def order_fulfill_data(order_with_lines, warehouse): +def order_with_lines_and_catalogue_promotion( + order_with_lines, channel_USD, catalogue_promotion_without_rules +): + order = order_with_lines + promotion = catalogue_promotion_without_rules + line = order.lines.get(quantity=3) + variant = line.variant + reward_value = Decimal(3) + rule = promotion.rules.create( + name="Catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(channel_USD) + + listing = variant.channel_listings.get(channel=channel_USD) + listing.discounted_price_amount = listing.price_amount - reward_value + listing.save(update_fields=["discounted_price_amount"]) + listing.variantlistingpromotionrule.create( + promotion_rule=rule, + discount_amount=reward_value, + currency=order.currency, + ) + + line.discounts.create( + type=DiscountType.PROMOTION, + value_type=RewardValueType.FIXED, + value=reward_value, + amount_value=reward_value * line.quantity, + currency=order.currency, + promotion_rule=rule, + ) + return order + + +@pytest.fixture +def order_with_lines_and_order_promotion( + order_with_lines, + channel_USD, + order_promotion_without_rules, +): + order = order_with_lines + promotion = order_promotion_without_rules + rule = promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(25), + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(channel_USD) + + order.discounts.create( + promotion_rule=rule, + type=DiscountType.ORDER_PROMOTION, + value_type=rule.reward_value_type, + value=rule.reward_value, + amount_value=rule.reward_value, + currency=order.currency, + ) + return order + + +@pytest.fixture +def order_with_lines_and_gift_promotion( + order_with_lines, + channel_USD, + order_promotion_without_rules, + variant_with_many_stocks, +): + order = order_with_lines + variant = variant_with_many_stocks + variant_listing = variant.channel_listings.get(channel=channel_USD) + promotion = order_promotion_without_rules + rule = promotion.rules.create( + name="Gift subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_type=RewardType.GIFT, + ) + rule.channels.add(channel_USD) + rule.gifts.set([variant]) + + gift_line = order.lines.create( + quantity=1, + variant=variant, + is_gift=True, + currency=order.currency, + unit_price_net_amount=0, + unit_price_gross_amount=0, + total_price_net_amount=0, + total_price_gross_amount=0, + is_shipping_required=True, + is_gift_card=False, + ) + gift_line.discounts.create( + promotion_rule=rule, + type=DiscountType.ORDER_PROMOTION, + value_type=RewardValueType.FIXED, + value=variant_listing.price_amount, + amount_value=variant_listing.price_amount, + currency=order.currency, + ) + return order + + +@pytest.fixture +def order_fulfill_data(order_with_lines, warehouse, checkout): FulfillmentData = namedtuple("FulfillmentData", "order variables warehouse") order = order_with_lines order_id = graphene.Node.to_global_id("Order", order.id) diff --git a/saleor/tests/settings.py b/saleor/tests/settings.py index e2e376ef3f8..deb53017348 100644 --- a/saleor/tests/settings.py +++ b/saleor/tests/settings.py @@ -1,9 +1,14 @@ +import os import re from re import Pattern from typing import Union from django.utils.functional import SimpleLazyObject +# Disable Jaeger tracing should be done before importing settings. +# without this line pytest will start sending traces to Jaeger agent. +os.environ["JAEGER_AGENT_HOST"] = "" + from ..settings import * # noqa diff --git a/saleor/tests/utils.py b/saleor/tests/utils.py index bd18d48214c..bb677fc9c12 100644 --- a/saleor/tests/utils.py +++ b/saleor/tests/utils.py @@ -1,4 +1,6 @@ import json +import math +from decimal import Decimal from django.db import connections, transaction @@ -21,3 +23,11 @@ def flush_post_commit_hooks(): def dummy_editorjs(text, json_format=False): data = {"blocks": [{"data": {"text": text}, "type": "paragraph"}]} return json.dumps(data) if json_format else data + + +def round_down(price: Decimal) -> Decimal: + return Decimal(math.floor(price * 100)) / 100 + + +def round_up(price: Decimal) -> Decimal: + return Decimal(math.ceil(price * 100)) / 100 diff --git a/saleor/thumbnail/tests/test_utils.py b/saleor/thumbnail/tests/test_utils.py index 9d7cd2f1e44..db5d308ed67 100644 --- a/saleor/thumbnail/tests/test_utils.py +++ b/saleor/thumbnail/tests/test_utils.py @@ -1,8 +1,10 @@ +from unittest import mock from unittest.mock import MagicMock import graphene import pytest from django.core.files import File +from PIL.JpegImagePlugin import JpegImageFile from .. import FILE_NAME_MAX_LENGTH, ThumbnailFormat from ..models import Thumbnail @@ -115,6 +117,27 @@ def test_processed_image_preprocess_method_called(category_with_image, thumb_for preprocess_mock.assert_called_once() +@pytest.mark.parametrize("thumb_format", [ThumbnailFormat.WEBP, ThumbnailFormat.AVIF]) +@mock.patch.object(JpegImageFile, "_getexif") +def test_processed_image_preprocess_with_exif_corrupted( + mocked_getexif, category_with_image, thumb_format +): + # given + image_path = category_with_image.background_image.name + processed_image = ProcessedImage(image_path, 128, thumb_format) + preprocess_method_name = f"preprocess_{thumb_format.upper()}" + preprocess_mock = MagicMock() + preprocess_mock.side_effect = getattr(processed_image, preprocess_method_name) + setattr(processed_image, preprocess_method_name, preprocess_mock) + mocked_getexif.side_effect = SyntaxError() + + # when + processed_image.create_thumbnail() + + # then + preprocess_mock.assert_called_once() + + def test_get_filename_from_url_unique(): # given file_format = "jpg" diff --git a/saleor/thumbnail/utils.py b/saleor/thumbnail/utils.py index adca9198bea..102be5f6cc6 100644 --- a/saleor/thumbnail/utils.py +++ b/saleor/thumbnail/utils.py @@ -173,7 +173,16 @@ def preprocess(self, image, image_format): # Ensuring image is properly rotated if hasattr(image, "_getexif"): - exif_datadict = image._getexif() # returns None if no EXIF data + try: + # validation of the exif data was added in separate PR: + # https://github.com/saleor/saleor/pull/11224, it means that there is a + # possibility that we could have the file with corrupted exif data. + # exif data is only used to apply some optional action on the image, + # but without it, we are still able to create a thumbnail. + exif_datadict = image._getexif() # returns None if no EXIF data + except SyntaxError: + exif_datadict = None + if exif_datadict is not None: exif = dict(exif_datadict.items()) orientation = exif.get(self.EXIF_ORIENTATION_KEY, None) diff --git a/saleor/thumbnail/views.py b/saleor/thumbnail/views.py index 1472a6fce8d..7e7079f999d 100644 --- a/saleor/thumbnail/views.py +++ b/saleor/thumbnail/views.py @@ -2,6 +2,7 @@ from collections import namedtuple from typing import Optional +from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.http import ( HttpResponseBadRequest, @@ -12,6 +13,7 @@ from ..account.models import User from ..app.models import App, AppInstallation +from ..core.db.connection import allow_writer from ..core.utils.events import call_event from ..graphql.core.utils import from_global_id_or_error from ..plugins.manager import get_plugins_manager @@ -82,16 +84,22 @@ def handle_thumbnail( else: instance_id_lookup = model_data.thumbnail_field + "_id" - if thumbnail := Thumbnail.objects.filter( - format=format, size=size_px, **{instance_id_lookup: pk} - ).first(): + if ( + thumbnail := Thumbnail.objects.using(settings.DATABASE_CONNECTION_REPLICA_NAME) + .filter(format=format, size=size_px, **{instance_id_lookup: pk}) + .first() + ): return HttpResponseRedirect(thumbnail.image.url) try: if object_type in UUID_IDENTIFIABLE_TYPES: - instance = model_data.model.objects.get(uuid=pk) + instance = model_data.model.objects.using( + settings.DATABASE_CONNECTION_REPLICA_NAME + ).get(uuid=pk) else: - instance = model_data.model.objects.get(id=pk) + instance = model_data.model.objects.using( + settings.DATABASE_CONNECTION_REPLICA_NAME + ).get(id=pk) except ObjectDoesNotExist: return HttpResponseNotFound("Instance with the given id cannot be found.") @@ -118,16 +126,17 @@ def handle_thumbnail( thumbnail_file_name = prepare_thumbnail_file_name(image.name, size_px, format) # save image thumbnail - thumbnail = Thumbnail( - size=size_px, format=format, **{model_data.thumbnail_field: instance} - ) - thumbnail.image.save(thumbnail_file_name, thumbnail_file) - thumbnail.save() - - # set additional `instance` attribute, to easily get instance data - # for ThumbnailCreated subscription type - setattr(thumbnail, "instance", instance) - manager = get_plugins_manager(allow_replica=False) - call_event(manager.thumbnail_created, thumbnail) + with allow_writer(): + thumbnail = Thumbnail( + size=size_px, format=format, **{model_data.thumbnail_field: instance} + ) + thumbnail.image.save(thumbnail_file_name, thumbnail_file) + thumbnail.save() + + # set additional `instance` attribute, to easily get instance data + # for ThumbnailCreated subscription type + setattr(thumbnail, "instance", instance) + manager = get_plugins_manager(allow_replica=False) + call_event(manager.thumbnail_created, thumbnail) return HttpResponseRedirect(thumbnail.image.url) diff --git a/saleor/warehouse/models.py b/saleor/warehouse/models.py index e9c77243152..c18bc393712 100644 --- a/saleor/warehouse/models.py +++ b/saleor/warehouse/models.py @@ -127,7 +127,8 @@ def applicable_for_click_and_collect( ) stocks_qs = ( - Stock.objects.annotate_available_quantity() + Stock.objects.using(self.db) + .annotate_available_quantity() .annotate(line_quantity=F("available_quantity") - Subquery(lines_quantity)) .order_by("line_quantity") .filter( diff --git a/saleor/webhook/payloads.py b/saleor/webhook/payloads.py index d982a787892..578b0deac31 100644 --- a/saleor/webhook/payloads.py +++ b/saleor/webhook/payloads.py @@ -23,6 +23,7 @@ from ..checkout.fetch import CheckoutInfo, CheckoutLineInfo from ..checkout.models import Checkout from ..checkout.utils import get_checkout_metadata +from ..core.db.connection import allow_writer from ..core.prices import quantize_price, quantize_price_fields from ..core.utils import build_absolute_uri from ..core.utils.anonymization import ( @@ -139,6 +140,7 @@ def generate_meta(*, requestor_data: dict[str, Any], camel_case=False, **kwargs) return meta +@allow_writer() @traced_payload_generator def generate_metadata_updated_payload( instance: Any, requestor: Optional["RequestorOrLazyObject"] = None @@ -173,6 +175,7 @@ def prepare_order_lines_allocations_payload(line): return warehouse_id_quantity_allocated_map +@allow_writer() @traced_payload_generator def generate_order_lines_payload(lines: Iterable[OrderLine]): line_fields = ( @@ -271,6 +274,7 @@ def _generate_shipping_method_payload(shipping_method, channel): return json.loads(payload)[0] +@allow_writer() @traced_payload_generator def generate_order_payload( order: "Order", @@ -411,6 +415,7 @@ def _calculate_removed( return _calculate_added(current_catalogue, previous_catalogue, key) +@allow_writer() @traced_payload_generator def generate_sale_payload( promotion: "Promotion", @@ -459,6 +464,7 @@ def generate_sale_payload( ) +@allow_writer() @traced_payload_generator def generate_sale_toggle_payload( promotion: "Promotion", @@ -481,6 +487,7 @@ def generate_sale_toggle_payload( ) +@allow_writer() @traced_payload_generator def generate_invoice_payload( invoice: "Invoice", requestor: Optional["RequestorOrLazyObject"] = None @@ -521,6 +528,7 @@ def _generate_order_payload_for_invoice(order: "Order"): return payload +@allow_writer() @traced_payload_generator def generate_checkout_payload( checkout: "Checkout", requestor: Optional["RequestorOrLazyObject"] = None @@ -595,6 +603,7 @@ def generate_checkout_payload( return checkout_data +@allow_writer() @traced_payload_generator def generate_customer_payload( customer: "User", requestor: Optional["RequestorOrLazyObject"] = None @@ -633,6 +642,7 @@ def generate_customer_payload( return data +@allow_writer() @traced_payload_generator def generate_collection_payload( collection: "Collection", requestor: Optional["RequestorOrLazyObject"] = None @@ -704,6 +714,7 @@ def _get_charge_taxes_for_product(product: "Product") -> bool: return charge_taxes +@allow_writer() @traced_payload_generator def generate_product_payload( product: "Product", requestor: Optional["RequestorOrLazyObject"] = None @@ -747,6 +758,7 @@ def generate_product_payload( return product_payload +@allow_writer() @traced_payload_generator def generate_product_deleted_payload( product: "Product", variants_id, requestor: Optional["RequestorOrLazyObject"] = None @@ -777,6 +789,7 @@ def generate_product_deleted_payload( ) +@allow_writer() @traced_payload_generator def generate_product_variant_listings_payload(variant_channel_listings): serializer = PayloadSerializer() @@ -793,6 +806,7 @@ def generate_product_variant_listings_payload(variant_channel_listings): return channel_listing_payload +@allow_writer() @traced_payload_generator def generate_product_variant_media_payload(product_variant): return [ @@ -808,6 +822,7 @@ def generate_product_variant_media_payload(product_variant): ] +@allow_writer() @traced_payload_generator def generate_product_variant_with_stock_payload( stocks: Iterable["Stock"], requestor: Optional["RequestorOrLazyObject"] = None @@ -829,6 +844,7 @@ def generate_product_variant_with_stock_payload( return serializer.serialize(stocks, fields=[], extra_dict_data=extra_dict_data) +@allow_writer() @traced_payload_generator def generate_product_variant_payload( product_variants: Iterable["ProductVariant"], @@ -867,11 +883,13 @@ def generate_product_variant_payload( return payload +@allow_writer() @traced_payload_generator def generate_product_variant_stocks_payload(product_variant: "ProductVariant"): return product_variant.stocks.aggregate(Sum("quantity"))["quantity__sum"] or 0 +@allow_writer() @traced_payload_generator def generate_fulfillment_lines_payload(fulfillment: Fulfillment): serializer = PayloadSerializer() @@ -942,6 +960,7 @@ def generate_fulfillment_lines_payload(fulfillment: Fulfillment): ) +@allow_writer() @traced_payload_generator def generate_fulfillment_payload( fulfillment: Fulfillment, requestor: Optional["RequestorOrLazyObject"] = None @@ -989,6 +1008,7 @@ def generate_fulfillment_payload( return fulfillment_data +@allow_writer() @traced_payload_generator def generate_page_payload( page: Page, requestor: Optional["RequestorOrLazyObject"] = None @@ -1039,6 +1059,7 @@ def _generate_refund_data_payload(data): return data +@allow_writer() @traced_payload_generator def generate_payment_payload( payment_data: "PaymentData", requestor: Optional["RequestorOrLazyObject"] = None @@ -1057,6 +1078,7 @@ def generate_payment_payload( return json.dumps(data, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_list_gateways_payload( currency: Optional[str], checkout: Optional["Checkout"] @@ -1114,6 +1136,7 @@ def _generate_sample_order_payload(event_name): return generate_order_payload(anonymized_order) +@allow_writer() @traced_payload_generator def generate_sample_payload(event_name: str) -> Optional[dict]: checkout_events = [ @@ -1176,6 +1199,7 @@ def process_translation_context(context): return result +@allow_writer() @traced_payload_generator def generate_translation_payload( translation: "Translation", requestor: Optional["RequestorOrLazyObject"] = None @@ -1218,6 +1242,7 @@ def _generate_payload_for_shipping_method(method: ShippingMethodData): return payload +@allow_writer() @traced_payload_generator def generate_excluded_shipping_methods_for_order_payload( order: "Order", @@ -1234,6 +1259,7 @@ def generate_excluded_shipping_methods_for_order_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_excluded_shipping_methods_for_checkout_payload( checkout: "Checkout", @@ -1250,6 +1276,7 @@ def generate_excluded_shipping_methods_for_checkout_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_checkout_payload_for_tax_calculation( checkout_info: "CheckoutInfo", @@ -1381,6 +1408,7 @@ def _generate_order_lines_payload_for_tax_calculation(lines: QuerySet[OrderLine] ) +@allow_writer() @traced_payload_generator def generate_order_payload_for_tax_calculation(order: "Order"): serializer = PayloadSerializer() @@ -1437,6 +1465,7 @@ def generate_order_payload_for_tax_calculation(order: "Order"): return order_data +@allow_writer() @traced_payload_generator def generate_transaction_action_request_payload( transaction_data: "TransactionActionData", @@ -1496,6 +1525,7 @@ def generate_transaction_action_request_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() def generate_transaction_session_payload( transaction_process_action: "TransactionProcessActionData", transaction: "TransactionItem", @@ -1519,12 +1549,14 @@ def generate_transaction_session_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_thumbnail_payload(thumbnail: Thumbnail): thumbnail_id = graphene.Node.to_global_id("Thumbnail", thumbnail.id) return json.dumps({"id": thumbnail_id}) +@allow_writer() @traced_payload_generator def generate_product_media_payload(media: ProductMedia): product_media_id = graphene.Node.to_global_id("ProductMedia", media.id) diff --git a/saleor/webhook/transport/asynchronous/transport.py b/saleor/webhook/transport/asynchronous/transport.py index 83034ebe485..0041355f070 100644 --- a/saleor/webhook/transport/asynchronous/transport.py +++ b/saleor/webhook/transport/asynchronous/transport.py @@ -11,6 +11,7 @@ from ....celeryconf import app from ....core import EventDeliveryStatus +from ....core.db.connection import allow_writer from ....core.models import EventDelivery, EventPayload from ....core.tracing import webhooks_opentracing_trace from ....core.utils import get_domain @@ -129,8 +130,9 @@ def create_deliveries_for_subscriptions( ) ) - EventPayload.objects.bulk_create(event_payloads) - return EventDelivery.objects.bulk_create(event_deliveries) + with allow_writer(): + EventPayload.objects.bulk_create(event_payloads) + return EventDelivery.objects.bulk_create(event_deliveries) def group_webhooks_by_subscription(webhooks): @@ -140,6 +142,7 @@ def group_webhooks_by_subscription(webhooks): return regular, subscription +@allow_writer() def create_event_delivery_list_for_webhooks( webhooks: Sequence["Webhook"], event_payload: "EventPayload", @@ -190,14 +193,15 @@ def trigger_webhooks_async( elif data is None: raise NotImplementedError("No payload was provided for regular webhooks.") - payload = EventPayload.objects.create(payload=data) - deliveries.extend( - create_event_delivery_list_for_webhooks( - webhooks=regular_webhooks, - event_payload=payload, - event_type=event_type, + with allow_writer(): + payload = EventPayload.objects.create(payload=data) + deliveries.extend( + create_event_delivery_list_for_webhooks( + webhooks=regular_webhooks, + event_payload=payload, + event_type=event_type, + ) ) - ) if subscription_webhooks: deliveries.extend( create_deliveries_for_subscriptions( diff --git a/saleor/webhook/transport/synchronous/transport.py b/saleor/webhook/transport/synchronous/transport.py index 7f762838552..90eb1c66e6a 100644 --- a/saleor/webhook/transport/synchronous/transport.py +++ b/saleor/webhook/transport/synchronous/transport.py @@ -10,6 +10,7 @@ from ....celeryconf import app from ....core import EventDeliveryStatus +from ....core.db.connection import allow_writer from ....core.models import EventDelivery, EventPayload from ....core.tracing import webhooks_opentracing_trace from ....core.utils import get_domain @@ -252,13 +253,15 @@ def create_delivery_for_subscription_sync_event( # Return None so if subscription query returns no data Saleor will not crash but # log the issue and continue without creating a delivery. return None - event_payload = EventPayload.objects.create(payload=json.dumps({**data})) - event_delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + + with allow_writer(): + event_payload = EventPayload.objects.create(payload=json.dumps({**data})) + event_delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) return event_delivery @@ -283,13 +286,14 @@ def trigger_webhook_sync( if not delivery: return None else: - event_payload = EventPayload.objects.create(payload=payload) - delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + with allow_writer(): + event_payload = EventPayload.objects.create(payload=payload) + delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) kwargs = {} if timeout: @@ -337,14 +341,17 @@ def trigger_all_webhooks_sync( if not delivery: return None else: - if event_payload is None: - event_payload = EventPayload.objects.create(payload=generate_payload()) - delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + with allow_writer(): + if event_payload is None: + event_payload = EventPayload.objects.create( + payload=generate_payload() + ) + delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) response_data = send_webhook_request_sync(delivery) if parsed_response := parse_response(response_data): diff --git a/saleor/webhook/transport/utils.py b/saleor/webhook/transport/utils.py index e5e45f4e678..8357fba8d26 100644 --- a/saleor/webhook/transport/utils.py +++ b/saleor/webhook/transport/utils.py @@ -22,6 +22,7 @@ from ...app.headers import AppHeaders, DeprecatedAppHeaders from ...app.models import App +from ...core.db.connection import allow_writer from ...core.http_client import HTTPClient from ...core.models import ( EventDelivery, @@ -378,6 +379,7 @@ def catch_duration_time(): yield lambda: time() - start +@allow_writer() def create_attempt( delivery: "EventDelivery", task_id: Optional[str] = None, @@ -394,6 +396,7 @@ def create_attempt( return attempt +@allow_writer() def attempt_update( attempt: "EventDeliveryAttempt", webhook_response: "WebhookResponse", @@ -416,6 +419,7 @@ def attempt_update( ) +@allow_writer() def clear_successful_delivery(delivery: "EventDelivery"): if delivery.status == EventDeliveryStatus.SUCCESS: payload_id = delivery.payload_id @@ -424,6 +428,7 @@ def clear_successful_delivery(delivery: "EventDelivery"): EventPayload.objects.filter(pk=payload_id, deliveries__isnull=True).delete() +@allow_writer() def delivery_update(delivery: "EventDelivery", status: str): delivery.status = status delivery.save(update_fields=["status"]) @@ -486,13 +491,14 @@ def trigger_transaction_request( payload = generate_transaction_action_request_payload( transaction_data, requestor ) - event_payload = EventPayload.objects.create(payload=payload) - delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + with allow_writer(): + event_payload = EventPayload.objects.create(payload=payload) + delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) call_event( handle_transaction_request_task.delay, delivery.id,