diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 787919a..2660aac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,23 +2,9 @@ on: [push, pull_request] name: CI jobs: build: - name: "Build on Racket '${{ matrix.racket-version }}' (${{ matrix.racket-variant }})" - runs-on: ubuntu-latest - strategy: - matrix: - racket-version: ["stable", "current"] - racket-variant: ["BC", "CS"] + name: "Build on Racket CS" + runs-on: self-hosted steps: - uses: actions/checkout@v3 - - uses: Bogdanp/setup-racket@v1.9.1 - with: - architecture: x64 - distribution: full - variant: ${{ matrix.racket-variant }} - version: ${{ matrix.racket-version }} - - name: Installing malt and its dependencies - run: raco pkg install --no-docs --auto --name malt - - name: Compiling malt and building its docs - run: raco setup --check-pkg-deps --unused-pkg-deps malt - name: Testing malt - run: raco test -x -p malt + run: make diff --git a/Makefile b/Makefile index 2b8078d..25302c0 100644 --- a/Makefile +++ b/Makefile @@ -10,12 +10,60 @@ TEST_FLAGS=-q # add sources be sure to update this section. 
#---------------------------------------- +LAZY_DIR=lazy LEARNER_DIR=learner FLAT_DIR=flat-tensors +UNIFORM_DIR=uniform-tensors +ACCELERATED_DIR=accelerated-tensors NESTED_DIR=nested-tensors TOOLS_DIR=tools MALTED_DIR=malted +# lazy +LAZY_TENSORS_DIR=$(LAZY_DIR)/tensors +LAZY_AUTODIFF_DIR=$(LAZY_DIR)/autodiff +LAZY_EXT_OPS_DIR=$(LAZY_DIR)/ext-ops + +LAZY_TENSORS_SOURCES=\ + $(LAZY_TENSORS_DIR)/c0-ast.rkt\ + $(LAZY_TENSORS_DIR)/c1-racket-runtime.rkt\ + $(LAZY_TENSORS_DIR)/c2-interpreter.rkt\ + $(LAZY_TENSORS_DIR)/c3-compiler.rkt\ + $(LAZY_TENSORS_DIR)/0-lazy.rkt\ + $(LAZY_TENSORS_DIR)/1-reflect.rkt\ + $(LAZY_TENSORS_DIR)/A-equality.rkt\ + $(LAZY_DIR)/tensors.rkt + +LAZY_AUTODIFF_SOURCES=\ + $(LAZY_AUTODIFF_DIR)/A-autodiff.rkt\ + $(LAZY_AUTODIFF_DIR)/B-prims.rkt\ + $(LAZY_AUTODIFF_DIR)/C-dualized-tensor-ops.rkt\ + $(LAZY_AUTODIFF_DIR)/D-test-helpers.rkt\ + $(LAZY_DIR)/autodiff.rkt + +LAZY_EXT_OPS_SOURCES=\ + $(LAZY_EXT_OPS_DIR)/A-scalar-ops.rkt\ + $(LAZY_EXT_OPS_DIR)/B-comparators.rkt\ + $(LAZY_EXT_OPS_DIR)/C-star-2-1.rkt\ + $(LAZY_EXT_OPS_DIR)/D-sum.rkt\ + $(LAZY_EXT_OPS_DIR)/E-argmax.rkt\ + $(LAZY_EXT_OPS_DIR)/F-max.rkt\ + $(LAZY_EXT_OPS_DIR)/G-correlate.rkt\ + $(LAZY_DIR)/ext-ops.rkt + +LAZY_LOADERS=\ + $(LAZY_DIR)/no-duals-no-overrides.rkt\ + $(LAZY_DIR)/no-duals.rkt\ + $(LAZY_DIR)/no-overrides.rkt\ + $(LAZY_DIR)/tensors.rkt\ + $(LAZY_DIR)/autodiff.rkt\ + $(LAZY_DIR)/ext-ops.rkt + +LAZY_SOURCES=$(LAZY_TENSORS_SOURCES)\ + $(LAZY_AUTODIFF_SOURCES)\ + $(LAZY_EXT_OPS_SOURCES)\ + $(LAZY_LOADERS) + # learner LEARNER_TENSORS_DIR=$(LEARNER_DIR)/tensors LEARNER_AUTODIFF_DIR=$(LEARNER_DIR)/autodiff @@ -102,6 +150,7 @@ FLAT_LOADERS=\ $(FLAT_DIR)/no-overrides.rkt\ $(FLAT_DIR)/tensors.rkt\ $(FLAT_DIR)/autodiff.rkt\ + $(FLAT_DIR)/ext-impl.rkt\ $(FLAT_DIR)/ext-ops.rkt FLAT_SOURCES=$(FLAT_TENSORS_SOURCES)\ @@ -109,6 +158,103 @@ FLAT_SOURCES=$(FLAT_TENSORS_SOURCES)\ $(FLAT_EXT_OPS_SOURCES)\ $(FLAT_LOADERS) +# uniform-tensors +UNIFORM_TENSORS_DIR=$(UNIFORM_DIR)/tensors 
+UNIFORM_AUTODIFF_DIR=$(UNIFORM_DIR)/autodiff +UNIFORM_EXT_OPS_DIR=$(UNIFORM_DIR)/ext-ops + +UNIFORM_TENSORS_SOURCES=\ + $(UNIFORM_TENSORS_DIR)/0-vectors.rkt\ + $(UNIFORM_TENSORS_DIR)/1-flats.rkt\ + $(UNIFORM_TENSORS_DIR)/A-equality.rkt\ + $(UNIFORM_TENSORS_DIR)/B-tensor-basics.rkt\ + $(UNIFORM_TENSORS_DIR)/C-tensor-ops.rkt\ + $(UNIFORM_TENSORS_DIR)/D-extend.rkt\ + $(UNIFORM_DIR)/tensors.rkt + +UNIFORM_AUTODIFF_SOURCES=\ + $(UNIFORM_AUTODIFF_DIR)/A-autodiff.rkt\ + $(UNIFORM_AUTODIFF_DIR)/B-prims.rkt\ + $(UNIFORM_AUTODIFF_DIR)/C-dualized-tensor-ops.rkt\ + $(UNIFORM_AUTODIFF_DIR)/D-test-helpers.rkt\ + $(UNIFORM_AUTODIFF_DIR)/E-print.rkt\ + $(UNIFORM_DIR)/autodiff.rkt + +UNIFORM_EXT_OPS_SOURCES=\ + $(UNIFORM_EXT_OPS_DIR)/A-scalar-ops.rkt\ + $(UNIFORM_EXT_OPS_DIR)/B-comparators.rkt\ + $(UNIFORM_EXT_OPS_DIR)/C-star-2-1.rkt\ + $(UNIFORM_EXT_OPS_DIR)/D-sum.rkt\ + $(UNIFORM_EXT_OPS_DIR)/E-argmax.rkt\ + $(UNIFORM_EXT_OPS_DIR)/F-max.rkt\ + $(UNIFORM_EXT_OPS_DIR)/G-correlate.rkt\ + $(UNIFORM_EXT_OPS_DIR)/I-flatten.rkt\ + $(UNIFORM_EXT_OPS_DIR)/K-concat.rkt\ + $(UNIFORM_DIR)/ext-ops.rkt + +UNIFORM_LOADERS=\ + $(UNIFORM_DIR)/no-duals-no-overrides.rkt\ + $(UNIFORM_DIR)/no-duals.rkt\ + $(UNIFORM_DIR)/no-overrides.rkt\ + $(UNIFORM_DIR)/tensors.rkt\ + $(UNIFORM_DIR)/autodiff.rkt\ + $(UNIFORM_DIR)/ext-impl.rkt\ + $(UNIFORM_DIR)/ext-ops.rkt + +UNIFORM_SOURCES=$(UNIFORM_TENSORS_SOURCES)\ + $(UNIFORM_AUTODIFF_SOURCES)\ + $(UNIFORM_EXT_OPS_SOURCES)\ + $(UNIFORM_LOADERS) + +# accelerated-tensors +ACCELERATED_TENSORS_DIR=$(ACCELERATED_DIR)/tensors +ACCELERATED_AUTODIFF_DIR=$(ACCELERATED_DIR)/autodiff +ACCELERATED_EXT_OPS_DIR=$(ACCELERATED_DIR)/ext-ops + +ACCELERATED_TENSORS_SOURCES=\ + $(ACCELERATED_TENSORS_DIR)/0-vectors.rkt\ + $(ACCELERATED_TENSORS_DIR)/1-flats.rkt\ + $(ACCELERATED_TENSORS_DIR)/2-acc-runtime.rkt\ + $(ACCELERATED_TENSORS_DIR)/A-equality.rkt\ + $(ACCELERATED_TENSORS_DIR)/B-tensor-basics.rkt\ + $(ACCELERATED_TENSORS_DIR)/C-tensor-ops.rkt\ + 
$(ACCELERATED_TENSORS_DIR)/D-extend.rkt\ + $(ACCELERATED_DIR)/tensors.rkt + +ACCELERATED_AUTODIFF_SOURCES=\ + $(ACCELERATED_AUTODIFF_DIR)/A-autodiff.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/B-prims.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/C-dualized-tensor-ops.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/D-test-helpers.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/E-print.rkt\ + $(ACCELERATED_DIR)/autodiff.rkt + +ACCELERATED_EXT_OPS_SOURCES=\ + $(ACCELERATED_EXT_OPS_DIR)/A-scalar-ops.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/B-comparators.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/C-star-2-1.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/D-sum.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/E-argmax.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/F-max.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/G-correlate.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/I-flatten.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/K-concat.rkt\ + $(ACCELERATED_DIR)/ext-ops.rkt + +ACCELERATED_LOADERS=\ + $(ACCELERATED_DIR)/no-duals-no-overrides.rkt\ + $(ACCELERATED_DIR)/no-duals.rkt\ + $(ACCELERATED_DIR)/no-overrides.rkt\ + $(ACCELERATED_DIR)/tensors.rkt\ + $(ACCELERATED_DIR)/autodiff.rkt\ + $(ACCELERATED_DIR)/ext-impl.rkt\ + $(ACCELERATED_DIR)/ext-ops.rkt + +ACCELERATED_SOURCES=$(ACCELERATED_TENSORS_SOURCES)\ + $(ACCELERATED_AUTODIFF_SOURCES)\ + $(ACCELERATED_EXT_OPS_SOURCES)\ + $(ACCELERATED_LOADERS) + # nested-tensors NESTED_TENSORS_DIR=$(NESTED_DIR)/tensors NESTED_AUTODIFF_DIR=$(NESTED_DIR)/autodiff @@ -182,7 +328,10 @@ MALTED_SOURCES=\ # All the sources together, plus entry points SOURCES=$(LEARNER_SOURCES)\ + $(LAZY_SOURCES)\ $(FLAT_SOURCES)\ + $(UNIFORM_SOURCES)\ + $(ACCELERATED_SOURCES)\ $(NESTED_SOURCES)\ $(TOOLS_SOURCES)\ $(MALTED_SOURCES)\ @@ -244,10 +393,12 @@ build: $(SOURCES) # Test it all test: @ echo "Running tests ..." &&\ + export MALT_PREFERENCES="$(shell pwd)/local.cfg";\ $(RACO) test $(TEST_FLAGS) $(SOURCES) one: - -@ $(RACO) make $(ARG) && $(RACO) test $(ARG) + -@ export MALT_PREFERENCES="$(shell pwd)/local.cfg";\ + $(RACO) make $(ARG) && $(RACO) test $(ARG) clean: find . 
-name 'compiled' | xargs -I% rm -rf % diff --git a/accelerated-tensors.rkt b/accelerated-tensors.rkt new file mode 100644 index 0000000..9132390 --- /dev/null +++ b/accelerated-tensors.rkt @@ -0,0 +1,47 @@ +#lang racket/base + +(require + (except-in "accelerated-tensors/tensors.rkt" + rank shape reshape tref trefs tensor? tlen ref refr)) + +(require "accelerated-tensors/autodiff.rkt") +(require "accelerated-tensors/ext-ops.rkt") + +(provide + tolerance + + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + max-tensor-print-length make-printable + + (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) + (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) + (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) + (d-flatten flatten) + (d-concat concat) (d-concat-n concat-n)) + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + sum-1 argmax-1 max-1 flatten-2 concat-1-1 + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/autodiff.rkt b/accelerated-tensors/autodiff.rkt new file mode 100644 index 0000000..68168c9 --- /dev/null +++ b/accelerated-tensors/autodiff.rkt @@ -0,0 +1,23 @@ +#lang racket + +(require "autodiff/A-autodiff.rkt") +(require "autodiff/B-prims.rkt") +(require "autodiff/C-dualized-tensor-ops.rkt") +(require "autodiff/D-test-helpers.rkt") +(require "autodiff/E-print.rkt") + +(provide dual dual? ρ κ ∇ ∇¹ scalar? 
trace-print dual* map*) +(provide prim1 prim2 ext1 ext2) +(provide (rename-out (d-rank rank) + (d-shape shape) + (d-reshape reshape) + (d-tref tref) + (d-trefs trefs) + (d-tensor? tensor?) + (d-tlen tlen) + (d-ref ref) + (d-refr refr))) + +(provide check-dual-equal? check-ρ-∇) + +(provide max-tensor-print-length make-printable) diff --git a/accelerated-tensors/autodiff/A-autodiff.rkt b/accelerated-tensors/autodiff/A-autodiff.rkt new file mode 100644 index 0000000..afcd114 --- /dev/null +++ b/accelerated-tensors/autodiff/A-autodiff.rkt @@ -0,0 +1,121 @@ +#lang racket + +(require "../tensors.rkt") +(require string-interpolation) + +;;---------------------------- +;; Real part of a dual is always a tensor (of any rank) +;;---------------------------- + +(define dual? + (λ (x) + (and (vector? x) (eq? (vector-ref x 0) dual)))) + +(define dual + (λ (r k) + (vector dual r k))) + +(define dual* + (λ (d) + (dual (ρ d) end-of-chain))) + +(define ρ + (λ (d) + (cond + ((dual? d) (vector-ref d 1)) + (else d)))) + +(define κ + (λ (d) + (cond + ((dual? d) (vector-ref d 2)) + (else end-of-chain)))) + +(define scalar? + (λ (d) + (or (number? d) + (and (dual? d) + (number? (ρ d)))))) + +(define dual-like? + (λ (d) + (or (dual? d) + (number? d) + (vector? d)))) + +;;---------------------------- +;; Chain rule +;;---------------------------- + +(define end-of-chain + (λ (d z σ) + (let ((g (hash-ref σ d 0.0))) + (hash-set σ d (+-ρ z g))))) + +(define +-ρ + (ext2-ρ + (λ (a b) "@{a} + @{b}") 0 0)) + +;;---------------------------- +;; Reverse-mode AD +;;---------------------------- + +(define ∇ + (λ (f theta) + (let ((wrt (map* dual* theta))) + (∇-once (f wrt) wrt)))) + +(define ∇¹ + (λ (f) + (λ xs + (let ((wrt (map* dual* xs))) + (∇-once (apply f wrt) wrt))))) + +(define ∇-once + (λ (y wrt) + (let ((σ (∇σ y (hasheq)))) + (map* (λ (d) + (hash-ref σ d 0.0)) + wrt)))) + +(define ∇σ + (λ (y σ) + (cond + ((dual-like? y) ((κ y) y (one-like (ρ y)) σ)) + ((list? 
y) (∇σ-list y σ)) + (else (printf "Unknown: ~a~%" y))))) + +(define ∇σ-list + (λ (y σ) + (cond + ((null? y) σ) + (else + (let ((σ-hat (∇σ (ref y 0) σ))) + (∇σ-list (refr y 1) σ-hat)))))) + +;;---------------------------- +;; General helpers +;;---------------------------- + +(define map* + (λ (f y) + (cond + ((dual-like? y) (f y)) + ((list? y) + (map (λ (yi) + (map* f yi)) + y)) + (else y)))) + +(define trace-print + (λ (v port) + (cond + ((dual? v) (trace-print (ρ v) port)) + (else (fprintf port "~a~%" v))))) + +(define (one-like s) ((ext1-ρ (λ (x) 1.0) (λ (x) "1.0") 0) s)) + +(include "test/test-A-autodiff.rkt") + +(provide + dual dual? ρ κ ∇ ∇¹ dual* scalar? end-of-chain map* + trace-print) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt new file mode 100644 index 0000000..044b2cb --- /dev/null +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -0,0 +1,262 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(define ρ-function + (λ (f) (f ρ-function))) + +(define ρ-acc-function + (λ (f) (f ρ-acc-function))) + +(define ∇-function + (λ (f) (f ∇-function))) + +(define ∇-acc-function + (λ (f) (f ∇-acc-function))) + +(define shape-fn + (λ (f) (f shape-fn))) + +(define signature + (λ (f) (f signature))) + +;; For flat tensors, ρ-fn and ∇-fn +;; are of two types: functional and pre-allocated +;; When they are functional, they return values +;; When they are pre-allocated, they expect the +;; return flat-store to be pre-allocated, and simply +;; operate as fillers. +;; +;; Pre-allocated ρ and ∇ have arities +;; 6 and 7 for unary ops, and 9 and 10 for binary ops. +;; We test for this arity to determine the type. +;; +;; Generally speaking, scalar operations are functional +;; and vector operations are pre-allocated. 
+;; +;; The functions ensure-ρ-callable-1, ensure-∇-callable-1 +;; and ensure-ρ-callable-2, ensure-∇-callable-2 provide +;; the preallocation for flat-stores when a vector-op is +;; provided, but the invocation of prim1 expects functional +;; results. +;; + +;; Primitives need a unique identifier (its signature) which corresponds to a +;; unique GPU kernel name. However, for the graphics driver in a 2017 macbook +;; pro, the kernel name is limited to a length of 15 characters. This limitation +;; therefore limits how many unique primitives can be created. + +(define prim1 + (let ((id 0)) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) + (∇-callable (ensure-∇-callable-1 ∇-fn shape)) + (prim-sign (string-append "p1" (~r id #:base 16)))) + (set! id (add1 id)) + (λ (daf) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ρ-acc-function) ρ-acc-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf ∇-acc-function) ∇-acc-fn) + ((eq? daf shape-fn) shape) + ((eq? daf signature) prim-sign) + (else (prim1-dual ρ-callable ∇-callable daf)))))))) + +(define prim1-dual + (λ (ρ-fn ∇-fn da) + (let ((ra (ρ da))) + (dual (ρ-fn ra) + (λ (d z σ) + (let ((ga (∇-fn ra z))) + ((κ da) da ga σ))))))) + +(define prim2 + (let ((id 0)) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) + (∇-callable (ensure-∇-callable-2 ∇-fn shape)) + (prim-sign (string-append "p2" (~r id #:base 16)))) + (set! id (add1 id)) + (λ ds + (let ((daf (ref ds 0))) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ρ-acc-function) ρ-acc-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf ∇-acc-function) ∇-acc-fn) + ((eq? daf shape-fn) shape) + ((eq? 
daf signature) prim-sign) + (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1)))))))))) + +(define prim2-dual + (λ (ρ-fn ∇-fn da db) + (let ((ra (ρ da)) + (rb (ρ db))) + (dual (ρ-fn ra rb) + (λ (d z σ) + (let-values (((ga gb) (∇-fn ra rb z))) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat)))))))) + +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define ensure-ρ-callable-1 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra) + (apply-flat-ρ-fn-1 ρ-fn ra shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-1 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra z) + (apply-flat-∇-fn-1 ∇-fn ra z shape-fn))) + (else ∇-fn)))) + +(define ensure-ρ-callable-2 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra rb) + (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-2 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra rb z) + (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn))) + (else ∇-fn)))) + +(define apply-flat-ρ-fn-1 + (λ (ρ-fn ra shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-1 + (λ (∇-fn ra z shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (let ((g (new-vec in-size 0.0))) + (cond + ((null? 
out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g (flat-store ra) (flat-offset ra) in-size + v-z 0 1) + (flat in-shape g 0))) + (else + (∇-fn g (flat-store ra) (flat-offset ra) in-size + (flat-store z) (flat-offset z) out-size) + (flat in-shape g 0))))))) + +(define apply-flat-ρ-fn-2 + (λ (ρ-fn ra rb shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape rb)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-2 + (λ (∇-fn ra rb z shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape rb)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (let ((g0 (new-vec in-size-a 0.0)) + (g1 (new-vec in-size-b 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-z 0 1) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))) + (else + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + (flat-store z) (flat-offset z) out-size) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))))))) + +;;---------------------------- +;; Dualized tensor op creators +;;---------------------------- + +;; TODO: Figure out the behaviour when we compose ext* with ext*. Currently we +;; assume that "f" is always a non-extended primitive. 
+(define ext1 + (λ (f n) + (prim1 + (ext1-ρ (ρ-function f) (ρ-acc-function f) n (shape-fn f) + (string-append "r" (signature f))) + (ρ-acc-function f) + (ext1-∇ (∇-function f) (∇-acc-function f) n (shape-fn f) + (string-append "n" (signature f))) + (∇-acc-function f) + (shape-fn f)))) + +(define ext2 + (λ (f m n) + (prim2 + (ext2-ρ (ρ-function f) (ρ-acc-function f) m n (shape-fn f) + (string-append "r" (signature f))) + (ρ-acc-function f) + (ext2-∇ (∇-function f) (∇-acc-function f) m n (shape-fn f) + (string-append "n" (signature f))) + (∇-acc-function f) + (shape-fn f)))) + +(provide prim1 prim2 ext1 ext2 + apply-flat-ρ-fn-1 + apply-flat-ρ-fn-2 + apply-flat-∇-fn-1 + apply-flat-∇-fn-2) diff --git a/accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt b/accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt new file mode 100644 index 0000000..5b02ba3 --- /dev/null +++ b/accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + + +;;---------------------------- +;; Tensor ops, cleaned up. +;;---------------------------- + +(define d-rank + (lambda (t) + (rank (ρ t)))) + +(define d-shape + (λ (t) + (shape (ρ t)))) + +(define d-reshape + (λ (s t) + (cond + ((dual? t) + (dual (reshape s (ρ t)) + (κ t))) + (else (reshape s t))))) + +(define d-trefs + (λ (t b) + (trefs (ρ t) b))) + +(define d-tref + (λ (t i) + (tref (ρ t) i))) + +(define d-tensor? + (λ (t) + (tensor? (ρ t)))) + +(define d-tlen + (λ (t) + (tlen (ρ t)))) + +(define d-ref + (λ (l i) + (ref l (ρ i)))) + +(define d-refr + (λ (l i) + (refr l (ρ i)))) + +(provide d-rank d-shape d-reshape d-trefs d-tensor? 
d-tlen d-ref d-refr d-tref) diff --git a/accelerated-tensors/autodiff/D-test-helpers.rkt b/accelerated-tensors/autodiff/D-test-helpers.rkt new file mode 100644 index 0000000..208f2b1 --- /dev/null +++ b/accelerated-tensors/autodiff/D-test-helpers.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") +(require "E-print.ss") + +(require rackunit) + +(define-check (check-dual-equal? actual expected) + (unless (equal-wt? actual expected) + (fail-check (format "Duals failed to match.~%actual:~%~s~%expected:~s~%" + (make-printable actual) (make-printable expected))))) +(define-check (ρ-∇-checker fn args ans grads) + (let* ((y (apply fn args)) + (g (apply (∇¹ fn) args))) + (cond + ((and (equal-wt? ans (ρ y)) + (equal-wt? grads (ρ g))) (void)) + ((equal-wt? ans (ρ y)) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" + (make-printable (ρ g)) (make-printable grads)))) + (else + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" + (make-printable (ρ y)) (make-printable ans))))))) + +(define-syntax check-ρ-∇ + (syntax-rules () + [(check-both (fn args ...) ans grads) + (ρ-∇-checker fn (list args ...) ans grads)])) + +(define equal-wt? + (λ (a b) + (cond + ((and (tensor? a) (tensor? b)) + (tensor-equal? a b)) + ((dual? a) (equal-wt? (ρ a) b)) + ((dual? b) (equal-wt? a (ρ b))) + ((and (vector? a) (vector? b) + (= (vector-length a) (vector-length b))) + (vector-andmap equal-wt? a b)) + ((and (pair? a) (pair? b) + (= (length a) (length b))) + (andmap equal-wt? a b)) + (else (equal? a b))))) + +(define vector-andmap + (λ (f v1 v2) + (for/fold ([s #t]) ([v1 v1][v2 v2]) + (and s (f v1 v2))))) + +(provide check-dual-equal? 
check-ρ-∇) diff --git a/accelerated-tensors/autodiff/E-print.rkt b/accelerated-tensors/autodiff/E-print.rkt new file mode 100644 index 0000000..7f4090e --- /dev/null +++ b/accelerated-tensors/autodiff/E-print.rkt @@ -0,0 +1,87 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "A-autodiff.rkt") +(require "../tensors.rkt") + +(define max-tensor-print-length (make-parameter 5)) + +(struct fake-tensor (members) + #:transparent + #:methods gen:custom-write + ((define write-proc + (λ (fake-tensor port mode) + (let ((n (length (fake-tensor-members fake-tensor)))) + (case mode + ((#t) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (write m port)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + ((#f) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (display m port) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + (else + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (print m port mode)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)))))))) + +(define make-printable + (λ (y [max-length (max-tensor-print-length)]) + (cond + ((dual? y) (make-printable (ρ y))) + ((flat? y) (make-printable-flat y max-length)) + ((list? y) + (map (λ (le) (make-printable le max-length)) y)) + ((vector? y) + (vector-map (λ (ve) (make-printable ve max-length)) y)) + (else y)))) + +(define make-printable-flat + (λ (y max-length) + (flat->tensor-list + (flat-store y) (flat-offset y) (flat-shape y) + (strides (flat-shape y)) max-length))) + +(define flat->tensor-list + (λ (store offset shape strides max-length) + (cond + ((null? 
shape) (vref store offset)) + (else + (let ((top-len (car shape)) + (stride (car strides))) + (fake-tensor + (reverse + (call/cc + (λ (return) + (for/fold ((lst '())) ((i (in-range offset (+ offset (* top-len stride)) stride)) + (count (in-naturals 0))) + (cond + ((and (> max-length 0) (= count max-length)) (return (cons '... lst))) + (else + (cons (flat->tensor-list store i (cdr shape) (cdr strides) max-length) + lst))))))))))))) + +(include "test/test-E-print.rkt") + +(provide max-tensor-print-length + make-printable + ;; This is used in ext-impl.rkt + make-printable-flat + fake-tensor) diff --git a/accelerated-tensors/autodiff/test/test-A-autodiff.rkt b/accelerated-tensors/autodiff/test/test-A-autodiff.rkt new file mode 100644 index 0000000..b453b3e --- /dev/null +++ b/accelerated-tensors/autodiff/test/test-A-autodiff.rkt @@ -0,0 +1,15 @@ +(module+ test + (require rackunit) + (let ((k0 end-of-chain)) + (let ((dual0 0) + (dual1 (dual 1 k0))) + + (check-equal? dual1 (dual 1 k0)) + (check-true (dual? dual1)) + (check-false (dual? 1)) + (check-equal? (ρ dual1) 1) + (check-equal? (ρ dual0) 0) + (check-equal? (κ dual1) k0) + + (check-equal? 
(map* (λ (d) (ρ d)) (∇-once dual1 (list dual0 dual1))) + '(0.0 1.0))))) diff --git a/accelerated-tensors/autodiff/test/test-E-print.rkt b/accelerated-tensors/autodiff/test/test-E-print.rkt new file mode 100644 index 0000000..30f51e6 --- /dev/null +++ b/accelerated-tensors/autodiff/test/test-E-print.rkt @@ -0,0 +1,71 @@ +(module+ test + (require rackunit) + (require "../tensors.rkt") + + (define long-tensor + (tensor 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15)) + + (define dualized-long-tensor + (dual long-tensor end-of-chain)) + + (define deep-tensor + (tensor long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor)) + + (define deeper-tensor + (tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) + + (check-equal? (make-printable-flat long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? (make-printable-flat deep-tensor 3) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...))) + + (check-equal? (make-printable-flat deeper-tensor 3) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...))) + (parameterize ((max-tensor-print-length 3)) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? 
(make-printable (list long-tensor dualized-long-tensor deeper-tensor)) + (list + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...)))))) diff --git a/accelerated-tensors/ext-impl.rkt b/accelerated-tensors/ext-impl.rkt new file mode 100644 index 0000000..b93c110 --- /dev/null +++ b/accelerated-tensors/ext-impl.rkt @@ -0,0 +1,42 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require (only-in "tensors/2-acc-runtime.rkt" + ext2-∇-kernel/name + run-prim2-∇!)) +(require (only-in "tensors/B-tensor-basics.rkt" + merge-flats)) +(require (only-in "tensors/D-extend.rkt" + merge-shapes + min-shape + ext2-shapes + flat-ext1-∇ + flat-ext1-ρ + flat-ext2-ρ + functional->preallocated-1-ρ + functional->preallocated-1-∇ + functional->preallocated-2-ρ + functional->preallocated-2-∇ + functional->preallocated-1-ρ-acc + functional->preallocated-1-∇-acc + functional->preallocated-2-ρ-acc + functional->preallocated-2-∇-acc + idxs + scalarize + ensure-flat)) +(require (only-in "autodiff/B-prims.rkt" + apply-flat-ρ-fn-1 + apply-flat-ρ-fn-2 + apply-flat-∇-fn-1 + apply-flat-∇-fn-2)) +(require (only-in "autodiff/E-print.rkt" + make-printable-flat + fake-tensor)) + +(provide (all-from-out "tensors/0-vectors.rkt")) +(provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/2-acc-runtime.rkt")) +(provide (all-from-out "tensors/B-tensor-basics.rkt")) +(provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/B-prims.rkt")) +(provide (all-from-out 
"autodiff/E-print.rkt")) diff --git a/accelerated-tensors/ext-ops.rkt b/accelerated-tensors/ext-ops.rkt new file mode 100644 index 0000000..fc223f5 --- /dev/null +++ b/accelerated-tensors/ext-ops.rkt @@ -0,0 +1,40 @@ +#lang racket + +(require "ext-ops/A-scalar-ops.rkt") +(require "ext-ops/B-comparators.rkt") +(require "ext-ops/C-star-2-1.rkt") +(require "ext-ops/D-sum.rkt") +(require "ext-ops/E-argmax.rkt") +(require "ext-ops/F-max.rkt") +(require "ext-ops/G-correlate.rkt") +(require "ext-ops/I-flatten.rkt") +(require "ext-ops/K-concat.rkt") + +(provide d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) + +(provide d*-2-1 *-2-1-ρ) + +(provide sum-1 d-sum sum-ρ d-sum-cols sum-cols-ρ) + +(provide argmax-1 d-argmax argmax-ρ) + +(provide max-1 d-max max-ρ) + +(provide correlate-ρ d-correlate) + +(provide flatten-2 d-flatten flatten-ρ) + +(provide concat-1-1 d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/accelerated-tensors/ext-ops/A-scalar-ops.rkt b/accelerated-tensors/ext-ops/A-scalar-ops.rkt new file mode 100644 index 0000000..72b33f9 --- /dev/null +++ b/accelerated-tensors/ext-ops/A-scalar-ops.rkt @@ -0,0 +1,210 @@ +#lang racket + +(require string-interpolation) +(require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) +(require "../autodiff.rkt") + +(define +-0-0-ρ-acc + (λ (a b) + "@{a}+@{b}")) + +(define +-0-0 + (prim2 + + +-0-0-ρ-acc + (λ (a b z) + (values z z)) + (λ (a b z) + (values z z)))) + +(define --0-0-ρ-acc + (λ (a b) + "@{a}-@{b}")) + +(define --0-0 + (prim2 - + --0-0-ρ-acc + (λ (a b z) + (values z (- z))) + (λ (a b z) + (values z "(- @{z})")))) + +(define *-0-0-ρ-acc + (λ (a b) + "@{a}*@{b}")) + +(define *-0-0 + (prim2 * + *-0-0-ρ-acc + (λ (a b z) + (values (* b z) (* a z))) + (λ (a b z) + (values "@{b}*@{z}" 
"@{a}*@{z}")))) + +(define /-0-0-ρ-acc + (λ (a b) + "@{a}/@{b}")) + +(define /-0-0 + (prim2 / + /-0-0-ρ-acc + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))) + (λ (a b z) + (values "(@{z} * (1 / @{b}))" + "(@{z} * ((- @{a}) / (@{b} * @{b})))")))) + +(define expt-0-0-ρ-acc + (λ (a b) + "pow(@{a}, @{b})")) + +(define expt-0-0 + (prim2 expt + expt-0-0-ρ-acc + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))) + (λ (a b z) + (values "(@{z} * (@{b} * pow(@{a}, (@{b} - 1))))" + "(@{z} * (pow(@{a}, @{b}) * log(@{a})))")))) + +(define exp-0-ρ-acc + (λ (a) + "exp(@{a})")) + +(define exp-0 + (prim1 exp + exp-0-ρ-acc + (λ (a z) + (* z (exp a))) + (λ (a z) + "(@{z} * exp(@{a}))"))) + +(define log-0-ρ-acc + (λ (a) + "log(@{a})")) + +(define log-0 + (prim1 log + log-0-ρ-acc + (λ (a z) + (* z (/ 1 a))) + (λ (a z) + "(@{z} * (1 / @{a}))"))) + +(define sqrt-0-ρ-acc + (λ (a) + "sqrt(@{a})")) + +(define sqrt-0 + (prim1 sqrt + sqrt-0-ρ-acc + (λ (x z) + (/ z (* 2 (sqrt x)))) + (λ (x z) + "(@{z} / (2 * sqrt(@{x})))"))) + +(define abs-0-ρ + (λ (x) + (cond + ((< x 0) (* -1 x)) + (else x)))) + +(define abs-0-ρ-acc + (λ (x) + "fabs(@{x})")) + +(define abs-0-∇ + (λ (x z) + (cond + ((< x 0) (- z)) + (else z)))) + +(define abs-0-∇-acc + (λ (x z) + "sign(@{x}) * @{z}")) + +(define abs-0 + (prim1 abs-0-ρ abs-0-ρ-acc abs-0-∇ abs-0-∇-acc)) + +(define rectify-0-ρ + (λ (s) + (cond + ((< s 0.0) 0.0) + (else s)))) + +(define rectify-0-ρ-acc + (λ (s) + "fmax(0.0f, @{s})")) + +(define rectify-0-∇ + (λ (s z) + (cond + ((< s 0.0) 0.0) + (else z)))) + +(define rectify-0-∇-acc + (λ (s z) + "step(0, @{s}) * @{z}")) + +(define rectify-shape + (λ (s) s)) + +(define rectify-0 + (prim1 rectify-0-ρ rectify-0-ρ-acc rectify-0-∇ rectify-0-∇-acc rectify-shape)) + +;;------------------------------------ +;; differentiable extended functions. 
+;;------------------------------------ + +(define d* (ext2 *-0-0 0 0)) +(define d+ (ext2 +-0-0 0 0)) +(define d- (ext2 --0-0 0 0)) +(define d/ (ext2 /-0-0 0 0)) +(define d-expt (ext2 expt-0-0 0 0)) + +(define d-exp (ext1 exp-0 0)) +(define d-log (ext1 log-0 0)) +(define d-abs (ext1 abs-0 0)) +(define d-rectify (ext1 rectify-0 0)) +(define d-sqrt (ext1 sqrt-0 0)) + +(define d-sqr + (λ (x) + (d* x x))) + +;;------------------------------------ +;; non-differentiable extended functions. +;;------------------------------------ + +(define *-ρ (ext2-ρ * *-0-0-ρ-acc 0 0)) +(define +-ρ (ext2-ρ + +-0-0-ρ-acc 0 0)) +(define --ρ (ext2-ρ - --0-0-ρ-acc 0 0)) +(define /-ρ (ext2-ρ / /-0-0-ρ-acc 0 0)) +(define expt-ρ (ext2-ρ expt expt-0-0-ρ-acc 0 0)) + +(define exp-ρ (ext1-ρ exp exp-0-ρ-acc 0)) +(define log-ρ (ext1-ρ log log-0-ρ-acc 0)) +(define abs-ρ (ext1-ρ abs-0-ρ abs-0-ρ-acc 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ rectify-0-ρ-acc 0)) +(define sqrt-ρ (ext1-ρ sqrt sqrt-0-ρ-acc 0)) + +(define sqr-ρ + (λ (x) + (*-ρ x x))) + +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) (λ (_) "0.0") 0)) + +(include "test/test-A-scalar-ops.rkt") + +(provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/accelerated-tensors/ext-ops/B-comparators.rkt b/accelerated-tensors/ext-ops/B-comparators.rkt new file mode 100644 index 0000000..3db8e0f --- /dev/null +++ b/accelerated-tensors/ext-ops/B-comparators.rkt @@ -0,0 +1,99 @@ +#lang racket + +(require string-interpolation) +(require "../autodiff.rkt") + +;;---------------------------- +;; Boolean comparators +;;---------------------------- + +(define comparator + (λ (f) + (λ (da db) + (f (ρ da) (ρ db))))) + +(define =-0-0 + (comparator =)) + +(define <-0-0 + (comparator <)) + +(define <=-0-0 + (comparator <=)) + +(define >-0-0 + (comparator >)) + +;; Boolean comparator: true when (ρ da) is greater than or equal to (ρ db). +;; Built on >= rather than >, so equal operands compare true, matching >=-1 below. +(define >=-0-0 +
(comparator >=)) + +;;---------------------------- +;; Tensorized comparators +;;---------------------------- + +(define comparator-ρ + (λ (f) + (λ (da db) + (cond + ((f (ρ da) (ρ db)) 1.0) + (else 0.0))))) + +(define comparator-ρ-acc + (λ (f) + (λ (a b) + "@{a} @{f} @{b}"))) + +(define comparator-∇ + (λ (f) + (λ (da db z) + (cond + ((f (ρ da) (ρ db)) (values z z)) + (else (values 0.0 0.0)))))) + +(define comparator-∇-acc + (λ (f) + (λ (a b z) + (let ((bool "@{a} @{f} @{b}")) + (values "@{bool}*@{z}" "@{bool}*@{z}"))))) + +(define comparator-shape + (λ (f) + (λ (sa sb) + sa))) + +(define comparator-prim + (λ (f f-acc) + (prim2 (comparator-ρ f) (comparator-ρ-acc f-acc) + (comparator-∇ f) (comparator-∇-acc f-acc) + (comparator-shape f)))) + +(define extended-comparator + (λ (f f-acc) + (ext2 (comparator-prim f f-acc) 0 0))) + +(define =-1 + (extended-comparator = "==")) + +(define <-1 + (extended-comparator < "<")) + +(define >-1 + (extended-comparator > ">")) + +(define <=-1 + (extended-comparator <= "<=")) + +(define >=-1 + (extended-comparator >= ">=")) + +(define != + (λ (a b) + (not (= a b)))) + +(define !=-1 + (extended-comparator != "!=")) + +(include "test/test-B-comparators.rkt") + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/ext-ops/C-star-2-1.rkt b/accelerated-tensors/ext-ops/C-star-2-1.rkt new file mode 100644 index 0000000..8cb8ab8 --- /dev/null +++ b/accelerated-tensors/ext-ops/C-star-2-1.rkt @@ -0,0 +1,79 @@ +#lang racket + +(require string-interpolation) +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ)) +(require "../autodiff.rkt") + +(define *-2-1-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (vset!
v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-ρ-acc + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + #< v max) (values v (+ (- i i0) 0.0))) + (else (values max max-i)))))))) + +(define argmax-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< max) { + max = v; + max_i = i - @{i0} + 0.0; + } + } + @{v-out}[@{i-out}] = max_i; +EOF + )) + +(define argmax-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i 0.0))))) + +(define argmax-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< v max) v) + (else max))))))) + +(define max-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< v max) (values v (- i i0))) + (else (values max max-i)))))))) + +(define max-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< max) { + max = v; + max_i = i - @{i0}; + } + } + for(int i=@{i0}; i<@{i0}+@{stride0}; i++) { + if(i == @{i0}+max_i) { + @{g0}[i] += z; + } else { + @{g0}[i] += 0.0; + } + } +EOF + )) + +(define max-shape + (λ (st) + (cdr st))) + +(define max-1 + (prim1 max-1-ρ max-1-ρ-acc max-1-∇ max-1-∇-acc max-shape)) + +(define d-max + (ext1 max-1 1)) + +(define max-ρ + (ext1-ρ max-1-ρ max-1-ρ-acc 1 max-shape)) + +(include "test/test-F-max.rkt") + +(provide max-1 d-max max-ρ) diff --git a/accelerated-tensors/ext-ops/G-correlate.rkt b/accelerated-tensors/ext-ops/G-correlate.rkt new file mode 100644 index 0000000..1e2ba84 --- /dev/null +++ b/accelerated-tensors/ext-ops/G-correlate.rkt @@ -0,0 +1,158 @@ +#lang racket + +(require string-interpolation) +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ len)) +(require "../autodiff.rkt") + +;; Correlation is written taking into account how ext2 works +;; Ext2 is responsible for producing the i-out'th output from +;; v0[i0] and v1[i1], we take advantage of this. 
The shape constants +;; n b m d are pre-calculated the striding constants nd md and qd +;; are calculated. + +(define correlate-3-1-ρ + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (vset! v-out (+ i-out i) + (for/fold ([sum 0.0]) ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (cond + ((and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (+ sum (* a b)))) + (else sum)))))))))) + +(define correlate-3-1-ρ-acc + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + #<= i1_min && bi < i1_max) { + sum += @{v0}[ai] * @{v1}[bi]; + } + } + @{v-out}[@{i-out}+i] = sum; + } +EOF + ))) + +(define correlate-3-1-∇ + (λ (nd md qd) + (λ (g0 g1 + v0 i0 bmd + v1 i1 d + vz iz b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (let ((z (vref vz (+ iz i)))) + (for ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (when (and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-3-1-∇-acc + (λ (nd md qd) + (λ (g + v0 i0 bmd + v1 i1 d + vz iz b) + (values + #<= i1_min && bi < i1_max) { + @{g}[ai] += z * @{v1}[bi]; + } + } + } +EOF + + #<= i1_min && bi < i1_max) { + @{g}[bi] += z * @{v0}[ai]; + } + } + } +EOF + )))) + +(define correlate-shape + (λ (bmd nd) + (list (car bmd)))) + +(define correlate-3-1 + (λ (nd md qd) + (prim2 + (correlate-3-1-ρ nd md qd) + (correlate-3-1-ρ-acc nd md qd) + (correlate-3-1-∇ nd md qd) + (correlate-3-1-∇-acc nd md qd) + correlate-shape))) + +(define d-correlate + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. 
+ (qd (* q d)) + (md (* m d))) + ((ext2 (correlate-3-1 nd md qd) 3 1) bank signal)))) + +(define correlate-ρ + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2-ρ (correlate-3-1-ρ nd md qd) (correlate-3-1-ρ-acc nd md qd) 3 1 correlate-shape) + bank signal)))) + +(define last + (λ (n s) + (refr s (- (len s) n)))) + +(include "test/test-G-correlate.rkt") + +(provide d-correlate correlate-ρ) diff --git a/accelerated-tensors/ext-ops/I-flatten.rkt b/accelerated-tensors/ext-ops/I-flatten.rkt new file mode 100644 index 0000000..0ef02f6 --- /dev/null +++ b/accelerated-tensors/ext-ops/I-flatten.rkt @@ -0,0 +1,52 @@ +#lang racket + +(require string-interpolation) +(require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) +(require (only-in "../autodiff.rkt" prim1 ext1)) + +(define flatten-2-ρ + (λ (t) + (reshape (flatten-shape (shape t)) t))) + +(define flatten-2-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #<-0-0 a b)) + (check-true (<=-0-0 a b)) + (check-false (>=-0-0 a b)) + (check-false (=-0-0 a b)) + (check-true (=-0-0 a a)) + (check-true (zero? 0)) + (check-false (zero? 
a)))) diff --git a/accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt b/accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt new file mode 100644 index 0000000..bbb2c8b --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt @@ -0,0 +1,24 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (tensor (tensor 36 52 70 90) (tensor 84 104 126 150))) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/accelerated-tensors/ext-ops/test/test-D-sum.rkt b/accelerated-tensors/ext-ops/test/test-D-sum.rkt new file mode 100644 index 0000000..b4f34b3 --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-D-sum.rkt @@ -0,0 +1,78 @@ +(module+ test + (require rackunit) + (require "C-star-2-1.ss") + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) + + (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d-sum a) 12 + (list (tensor 1.0 1.0 1.0)))) + + (let ((a (tensor (tensor 3 4 5) + (tensor 6 7 8)))) + (check-dual-equal? (d-sum a) (tensor 12 21)) + (check-dual-equal? 
((∇¹ (λ (b) (d-sum (d* b b)))) a) + (list (tensor (tensor 6.0 8.0 10.0) + (tensor 12.0 14.0 16.0)))) + (check-ρ-∇ (d-sum-cols a) (tensor 9 11 13) + (list (tensor (tensor 1 1 1) + (tensor 1 1 1))))) + + (let ((a (tensor + (tensor (tensor (tensor 3 4 5) (tensor 6 7 8)) + (tensor (tensor 8 7 6) (tensor 5 4 3))) + (tensor (tensor (tensor 1 2 3) (tensor 6 5 4)) + (tensor (tensor 7 8 9) (tensor 9 8 7)))))) + (check-ρ-∇ (d-sum-cols a) + (tensor + (tensor (tensor 9 11 13) + (tensor 13 11 9)) + (tensor (tensor 7 7 7) + (tensor 16 16 16))) + (list (tensor + (tensor (tensor (tensor 1 1 1) (tensor 1 1 1)) + (tensor (tensor 1 1 1) (tensor 1 1 1))) + (tensor (tensor (tensor 1 1 1) (tensor 1 1 1)) + (tensor (tensor 1 1 1) (tensor 1 1 1))))))) + + (define dot-product + (λ (a b) + (d-sum (d*-2-1 a b)))) + + (define sse + (λ (a b) + (d-sum (d-sqr (d- a b))))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + + (check-ρ-∇ (sum-1 b) 14 + (list (tensor 1.0 1.0 1.0 1.0))) + + (check-ρ-∇ (dot-product a b) + (tensor 68 124) + (list (tensor (tensor 2.0 3.0 4.0 5.0) + (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + + (check-ρ-∇ (sse a b) + (tensor 4 100) + (list (tensor (tensor 2.0 2.0 2.0 2.0) + (tensor 10.0 10.0 10.0 10.0)) + (tensor -12.0 -12.0 -12.0 -12.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (dot-product a b) + (tensor (tensor 68 124) + (tensor 248 464)) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/accelerated-tensors/ext-ops/test/test-E-argmax.rkt b/accelerated-tensors/ext-ops/test/test-E-argmax.rkt new file mode 100644 index 0000000..5cf6a4d --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-E-argmax.rkt @@ -0,0 +1,21 @@ +(module+ test + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 0.0 1.0 
0.0))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor -10 -3 -2 -5))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-argmax y) (tensor 2.0 1.0 0.0 3.0) + (list + (tensor (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0)))))) diff --git a/accelerated-tensors/ext-ops/test/test-F-max.rkt b/accelerated-tensors/ext-ops/test/test-F-max.rkt new file mode 100644 index 0000000..01ab1a5 --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-F-max.rkt @@ -0,0 +1,10 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-max y) (tensor 1.0 1.0 1.0 1.0) + (list y)))) diff --git a/accelerated-tensors/ext-ops/test/test-G-correlate.rkt b/accelerated-tensors/ext-ops/test/test-G-correlate.rkt new file mode 100644 index 0000000..8afe9b6 --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-G-correlate.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor ext2-∇ check-tensor-equal?)) + + ;; for testing b = 4 + ;; m = 3 + ;; d = 2 + + ;; signal length n = 6 + + ;; (1 2) (3 4) (5 6) (7 8) (9 10) (11 12) + ;; (1 2) (3 4) (5 6) + ;; (7 8) (9 10) (11 12) + ;; (13 14) (15 16) (17 18) + ;; (19 20) (21 22) (23 24) + + ;; Signal is (n d) + (define signal (tensor (tensor 1 2) + (tensor 3 4) + (tensor 5 6) + (tensor 7 8) + (tensor 9 10) + (tensor 11 12))) + + (define bank (tensor (tensor + (tensor 1 2) + (tensor 3 4) + (tensor 5 6)) + (tensor + (tensor 7 8) + (tensor 9 10) + (tensor 11 12)) + (tensor + (tensor 13 14) + (tensor 15 16) + (tensor 17 18)) + (tensor + (tensor 19 20) + (tensor 21 22) + (tensor 23 24)))) 
+ + (define corr-ρ + (ext2-ρ (correlate-3-1-ρ 12 6 2) (correlate-3-1-ρ-acc 12 6 2) 3 1 correlate-shape)) + + (define corr-∇ + (ext2-∇ (correlate-3-1-∇ 12 6 2) (correlate-3-1-∇-acc 12 6 2) 3 1 correlate-shape)) + + (check-tensor-equal? (corr-ρ bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let-values (((filter-∇ signal-∇) + (corr-∇ bank signal (tensor (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0))))) + (check-tensor-equal? filter-∇ + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-tensor-equal? signal-∇ + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0)))) + + (check-dual-equal? (d-correlate bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let ((gs ((∇¹ d-correlate) bank signal))) + (check-dual-equal? (car gs) + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-dual-equal? 
(cadr gs) + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0))))) diff --git a/accelerated-tensors/ext-ops/test/test-I-flatten.rkt b/accelerated-tensors/ext-ops/test/test-I-flatten.rkt new file mode 100644 index 0000000..d562ccd --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-I-flatten.rkt @@ -0,0 +1,14 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0)) + + (check-dual-equal? (flatten-2 r2-t1) r1-t1) + (check-dual-equal? (d-flatten r2-t1) r1-t1) + (check-ρ-∇ ((λ (t1 t2) (d* t1 (flatten-2 t2))) r1-t1 r2-t1) + (tensor 9.0 16.0 25.0 36.0) + (list (tensor 3.0 4.0 5.0 6.0) (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))))) diff --git a/accelerated-tensors/ext-ops/test/test-K-concat.rkt b/accelerated-tensors/ext-ops/test/test-K-concat.rkt new file mode 100644 index 0000000..b427ced --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-K-concat.rkt @@ -0,0 +1,126 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t2 (tensor 5.0 6.0 7.0)) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + + (check-dual-equal? 
+ (d-concat r2-t1 r1-t2) + (tensor (tensor 3.0 4.0 5.0 6.0 7.0) + (tensor 5.0 6.0 5.0 6.0 7.0))) + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (d-concat t1 t2))) r2-t1 r1-t2 r1-t1) + (tensor (tensor 9.0 16.0 25.0 36.0 49.0) + (tensor 15.0 24.0 25.0 36.0 49.0)) + (list (tensor (tensor 3.0 4.0) (tensor 3.0 4.0)) + (tensor 10.0 12.0 14.0) + (tensor 8.0 10.0 10.0 12.0 14.0))) + (define r3-t1 + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0)))) + + + (define r2-t2 + (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0))) + + (define r1-t3 + (tensor 0.5 0.5)) + + (define concat-2 (d-concat-n 2)) + + (check-dual-equal? + (concat-2 r3-t1 r2-t2) + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)))) + + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (concat-2 t1 t2))) r3-t1 r2-t2 r1-t3) + (tensor (tensor (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 4.5 5.0) + (tensor 5.5 6.0) + (tensor 6.5 7.0) + (tensor 7.5 8.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 8.5 9.0) + (tensor 9.5 10.0) + (tensor 10.5 11.0) + (tensor 11.5 12.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0))) + (list + 
(tensor (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5))) + + (tensor (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5)) + + (tensor 192.0 216.0)))) diff --git a/accelerated-tensors/no-duals-no-overrides.rkt b/accelerated-tensors/no-duals-no-overrides.rkt new file mode 100644 index 0000000..07ca22e --- /dev/null +++ b/accelerated-tensors/no-duals-no-overrides.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + concat-ρ concat-n-ρ flatten-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/no-duals.rkt b/accelerated-tensors/no-duals.rkt new file mode 100644 index 0000000..cd1bcaf --- /dev/null +++ b/accelerated-tensors/no-duals.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? 
rank shape reshape trefs + + ;; From ext-ops + (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) + (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) + (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/no-overrides.rkt b/accelerated-tensors/no-overrides.rkt new file mode 100644 index 0000000..05844b7 --- /dev/null +++ b/accelerated-tensors/no-overrides.rkt @@ -0,0 +1,43 @@ +#lang racket/base + +(require + (except-in "tensors.rkt" + rank shape reshape trefs tref tensor? tlen ref refr)) + +(require "autodiff.rkt") +(require "ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? 
check-ρ-∇ + make-printable + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + flatten-2 concat-1-1 + + d+ d- d* d/ d-rectify + d-exp d-log d-expt d-sqrt d-sqr + d-sum d-abs d*-2-1 d-argmax + d-max d-sum-cols d-correlate + d-flatten d-concat d-concat-n + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ concat-n-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/tensors.rkt b/accelerated-tensors/tensors.rkt new file mode 100644 index 0000000..0daca26 --- /dev/null +++ b/accelerated-tensors/tensors.rkt @@ -0,0 +1,22 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require "tensors/A-equality.rkt") +(require "tensors/B-tensor-basics.rkt") +(require "tensors/C-tensor-ops.rkt") +(require "tensors/D-extend.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide tolerance tensor-equal? check-tensor-equal?) + +(provide len ref refr) +(provide tref tlen list->tensor tensor build-tensor trefs) + +(provide ext1-ρ ext2-ρ ext1-∇ ext2-∇ expects-preallocated?) + +(provide flat flat? flat-shape flat-store flat-offset size-of strides) + +;; These will get overriden by duals +(provide tensor?) +(provide rank shape reshape size-of) diff --git a/accelerated-tensors/tensors/0-vectors.rkt b/accelerated-tensors/tensors/0-vectors.rkt new file mode 100644 index 0000000..8135743 --- /dev/null +++ b/accelerated-tensors/tensors/0-vectors.rkt @@ -0,0 +1,101 @@ +#lang racket +(require ffi/vector) +(require ffi/unsafe) + +;;------------------------------------------------ +;; Raw representation of vectors +;;------------------------------------------------ + +(define vec? f32vector?) +(define vec f32vector) +(define make-vec make-f32vector) +(define vref f32vector-ref) +(define vset! f32vector-set!) 
+(define vlen f32vector-length) +(define list->vec list->f32vector) +(define build-vec + (λ (n proc) + (list->vec (map (compose exact->inexact proc) (range n))))) +(define vec->cpointer f32vector->cpointer) +(define vref-cpointer + (λ (v i) + (unless (and (<= 0 i) (< i (vlen v))) + (error 'vref-cpointer + "Index ~a out of range [0, ~a]" + i (sub1 (vlen v)))) + (ptr-add (vec->cpointer v) i _float))) + +(define-for-syntax debug-leaks? #f) +(define-syntax when-debug-leaks + (λ (x) + (syntax-case x () + ((when-debug-leaks expr) + debug-leaks? + #'expr) + ((when-debug-leaks expr) + #'(void))))) + +(define new-vec + (λ (size initial-value [context 'new-vec]) + (let ((m (make-vec size initial-value))) + (when-debug-leaks (manage-flat-vector! m context)) + m))) + +(define vcopy + (λ (dest idest src isrc n) + (for ([id (in-range idest (+ n idest))] + [is (in-range isrc (+ n isrc))]) + (vset! dest id (vref src is))))) + +(define print-vec + (λ (v (off 0) (port (current-output-port))) + (fprintf port "#(") + (for ((i (in-range off (vlen v)))) + (fprintf port "~a " (vref v i))) + (fprintf port ")~n"))) + +(provide vec? vec vref vset! vlen vcopy print-vec + list->vec build-vec vec->cpointer vref-cpointer new-vec) + +;;------------------------------------------------ +;; Memory management for flat-vectors +;;------------------------------------------------ + +(define flat-vector-manager + (make-will-executor)) + +(define manage-flat-vector! + (λ (m context) + (set-count! context (add1 (count context))) + (will-register flat-vector-manager m (flat-vector-collector context)))) + +;; Returns the will procedure registered for a managed vector: when the +;; vector is collected, the live count recorded for context is decremented. +(define flat-vector-collector + (λ (context) + (λ (v) + (cond + ;; Managed vectors come from make-vec, i.e. they are f32vectors, so test with vec? here. + ;; The generic vector? predicate is false for f32vectors, which left counts stuck. + ((vec? v) + (set-count! context (sub1 (count context)))) + (else (fprintf (current-error-port) "??
...")))))) + +(define start-vector-manager + (λ () + (when-debug-leaks + (void + (thread + (λ () + (let loop () + (will-execute flat-vector-manager) + (loop)))))))) + +(define counts (make-hash)) +(define count (λ (context) (dict-ref counts context 0))) +(define set-count! (λ (context v) (dict-set! counts context v))) +(define vector-manager-report + (λ () + (fprintf (current-error-port) "----------------------------------------------~%") + (fprintf (current-error-port) "context\t\t\tcount~%") + (for ([(context count) (in-hash counts)]) + (fprintf (current-error-port) "~a\t\t\t~a~%" context count)) + (fprintf (current-error-port) "----------------------------------------------~%"))) + +(provide start-vector-manager vector-manager-report) diff --git a/accelerated-tensors/tensors/1-flats.rkt b/accelerated-tensors/tensors/1-flats.rkt new file mode 100644 index 0000000..23aeb79 --- /dev/null +++ b/accelerated-tensors/tensors/1-flats.rkt @@ -0,0 +1,73 @@ +#lang racket + +;-------------------------------------------------------- +; Representation of tensors +;-------------------------------------------------------- + +;; A flat tensor representation is for a contiguous slice in the backing store. +;; The fields we need: +;; shape : list +;; store : vector +;; offset : start of the contiguous slice. +;; size : number of elements in the contiguous slice +;; strides : Number of elements in each dimension of the tensor. +;; rank: Number of dimensions in the tensor + + + +(define flat + (λ (shape store offset) + (vector flat shape store offset + (size-of shape) + (strides shape) + (length shape)))) + +(define flat? + (λ (v) + (and (vector? v) + (eq? 
(vector-ref v 0) flat)))) + +(define flat-shape + (λ (f) + (vector-ref f 1))) + +(define flat-store + (λ (f) + (vector-ref f 2))) + +(define flat-offset + (λ (f) + (vector-ref f 3))) + +(define flat-size + (λ (f) + (vector-ref f 4))) + +(define flat-strides + (λ (f) + (vector-ref f 5))) + +(define flat-rank + (λ (f) + (vector-ref f 6))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(provide flat flat? flat-shape flat-store + flat-offset flat-rank flat-strides flat-size + size-of strides) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt new file mode 100644 index 0000000..f815a0b --- /dev/null +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -0,0 +1,675 @@ +#lang racket + +(require ffi/cvector + ffi/unsafe + opencl/c + string-interpolation + file/xxhash32 + "0-vectors.rkt" + "../../impl-loader.rkt" + "ext2-strides.rkt") + + +;; TODO: Implement MNIST as an example along with iris and morse + +(define local-work-size (make-parameter #f)) +(define xxh32-ctx (make-xxh32)) + +(define context + (let ([context #f]) + (λ () + (or context + (begin + (set! context (clCreateContext #f (cvector->vector (devices)))) + context))))) +(define command-queue + (let ([command-queue #f]) + (λ () + (or command-queue + (begin + (set! command-queue (clCreateCommandQueue (context) (device) '())) + command-queue))))) + +(define old-exit-handler (exit-handler)) +(exit-handler + (λ (v) + (when (command-queue) + (clReleaseCommandQueue (command-queue))) + (when (context) + (clReleaseContext (context))) + (old-exit-handler v))) + +(define platform + (let ([platform #f]) + (lambda () + (or platform + (begin + (set! 
platform (cvector-ref (clGetPlatformIDs:vector) 0)) + platform))))) +(define devices + (let ([devices #f]) + (lambda () + (or devices + (begin + (set! devices (clGetDeviceIDs:vector (platform) 'CL_DEVICE_TYPE_GPU)) + devices))))) +(define device + (let ([device #f]) + (lambda () + (or device + (begin + (set! device (cvector-ref (devices) 0)) + device))))) + +(define (cvector->vector cv) + (build-vector (cvector-length cv) + (curry cvector-ref cv))) + +;; callback function to be used for debugging clBuildProgram if we expose that +;; parameter in the opencl/c library source code. +(define print-cl-build-log + (λ (program _) + (when (debug-kernel?) + (printf "Program Source:~n~a~n" + (clGetProgramInfo:generic program 'CL_PROGRAM_SOURCE)) + (printf "Build status:~a~n" + (clGetProgramBuildInfo:generic program (device) + 'CL_PROGRAM_BUILD_STATUS)) + (printf "Build log:~a~n" + (clGetProgramBuildInfo:generic program (device) + 'CL_PROGRAM_BUILD_LOG))))) + +(define (binary-expr rator rand1 rand2) + (string-append "(" rand1 " " rator " " rand2 ")")) + +(define idx-exprs + (λ (strides i0 i1) + (λ (out-i) + (for/fold ([i0 (number->string i0)] + [i1 (number->string i1)] + [x out-i] #:result (values i0 i1)) + ([stride (strides-strides strides)]) + (let ((stride-out (number->string (vector-ref stride 0))) + (stride0 (number->string (vector-ref stride 1))) + (stride1 (number->string (vector-ref stride 2)))) + (let ((idx (binary-expr "/" x stride-out)) + (next-x (binary-expr "%" x stride-out))) + (values (binary-expr "+" i0 (binary-expr "*" idx stride0)) + (binary-expr "+" i1 (binary-expr "*" idx stride1)) + next-x))))))) + +(define idx-exprs-inv + (λ (strides i-out repeats0 repeats1 s-out) + (λ (i0-var-str i1-var-str i-rep-var-str) + (let ((gen-expr + (λ (i-in-var-str stride-i repeats) + (for/fold ([i-out (number->string i-out)] + [dividend-rep i-rep-var-str] + [predivisor-rep repeats] + [x i-in-var-str] #:result i-out) + ([desc-out s-out] ;; s-out == (append descents-out sf-out) 
+ [stride (strides-strides strides)]) ;; (len strides) == (len descents-out) + (let ((stride-out (vector-ref stride 0)) + (stride-in (vector-ref stride stride-i))) + (cond + ((zero? stride-in) + (let* ((divisor-rep (quotient predivisor-rep desc-out)) + (divisor-rep-str (number->string divisor-rep)) + (scaling (binary-expr "/" dividend-rep divisor-rep-str)) + (next-dividend (binary-expr "%" + dividend-rep + divisor-rep-str))) + (values (binary-expr "+" i-out + (binary-expr "*" + scaling + (number->string + stride-out))) + next-dividend + divisor-rep + x))) + (else + (let ((stride-in-str (number->string stride-in))) + (let ((idx (binary-expr "/" x stride-in-str)) + (next-x (binary-expr "%" x stride-in-str))) + (values (binary-expr "+" i-out + (binary-expr "*" idx + (number->string + stride-out))) + dividend-rep + predivisor-rep + next-x)))))))))) + (values (gen-expr i0-var-str 1 repeats0) + (gen-expr i1-var-str 2 repeats1)))))) + +(define calc-repeats + (λ (s0 s1 r0 r1 s-out r-out) + (define size-rep0 (apply * (drop-right s0 r0))) + (define size-rep1 (apply * (drop-right s1 r1))) + (define size-rep-out (apply * (drop-right s-out r-out))) + (values (/ size-rep-out size-rep0) + (/ size-rep-out size-rep1)))) + +(define kernel-name + (lambda (fn) + "kernel_@{(~a (eq-hash-code fn))}")) + +(define (ext1-ρ-kernel/name prim1-ρ-f prim-sign) + (values + #<bytes/utf-8 + kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf-out) + (clSetKernelArg:_cl_int kernel 3 stride-out)) + (λ () + ;;TODO: Try using the local-work-size argument + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) + (make-vector 0))) + (set! 
event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf0 + (clReleaseMemObject buf0)))))) + +(define functional->preallocated-1-ρ-acc + (λ (f-acc base-shape out-shape) + (unless (and (null? base-shape) (null? out-shape)) + (error 'ρ1-functional-non-scalar-acc + (string-append "Accelerated functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input and output shape found: ~a ~a") + base-shape out-shape)) + (λ (v0 i0 stride0 v-out i-out stride-out) + (let ((a "@{v0}[@{i0}]")) + #<bytes/utf-8 + kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf-g) + (clSetKernelArg:_cl_mem kernel 1 buf0) + (clSetKernelArg:_cl_int kernel 2 stride0) + (clSetKernelArg:_cl_mem kernel 3 buf-z) + (clSetKernelArg:_cl_int kernel 4 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-z stride-z)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g + (clReleaseMemObject buf-g)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf0 + (clReleaseMemObject buf0)))) + ;)) + )) + +(define functional->preallocated-1-∇-acc + (λ (f-acc base-shape out-shape) + (unless (and (null? base-shape) (null? 
out-shape)) + (error '∇1-functional-non-scalar-acc + (string-append "Accelerated functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input and output shape found: ~a ~a") + base-shape out-shape)) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z "@{vz}[@{iz}]") + (a "@{v0}[@{i0}]")) + #<bytes/utf-8 kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf1) + (clSetKernelArg:_cl_int kernel 3 stride1) + (clSetKernelArg:_cl_mem kernel 4 buf-out) + (clSetKernelArg:_cl_int kernel 5 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))) + +(define functional->preallocated-2-ρ-acc + (λ (f-acc t-shape u-shape out-shape) + (unless (and (null? t-shape) (null? u-shape) (null? out-shape)) + (error 'ρ2-functional-non-scalar-acc + (string-append "Accelerated functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input 1, input 2 and output shape found: ~a ~a ~a") + t-shape u-shape out-shape)) + (λ (v0 i0 stride0 v1 i1 stride1 v-out i-out stride-out) + (let ((a "@{v0}[@{i0}]") + (b "@{v1}[@{i1}]")) + #<bytes/utf-8 (strides-signature strides))) + (xxh32-update! 
+ xxh32-ctx + (bytes-append (apply bytes-append + (map (λ (x) (integer->integer-bytes x 4 #f)) s0)) + (apply bytes-append + (map (λ (x) (integer->integer-bytes x 4 #f)) s1)) + (integer->integer-bytes r0 1 #f) + (integer->integer-bytes r1 1 #f) + (apply bytes-append + (map (λ (x) (integer->integer-bytes x 4 #f)) s-out)) + (integer->integer-bytes r-out 1 #f))) + (define params-hash (xxh32-digest xxh32-ctx)) + (format "~a~a" prim-sign params-hash)) + + +(define ext2-∇-kernel/name + (let ((cache (make-hash))) + (λ (prim2-∇-f prim-sign strides + s0 s1 r0 r1 s-out r-out) + (let ((kernel-name (ext2-∇-kernel-name prim-sign strides + s0 s1 r0 r1 s-out r-out))) + (cond + ((hash-has-key? cache kernel-name) + (values (hash-ref cache kernel-name) kernel-name)) + (else + (let*-values (((prim-effect0 prim-effect1) (prim2-∇-f "g" + "v0" "i0" "stride0" + "v1" "i1" "stride1" + "vz" "iz" "stride_z")) + ((repeats0 repeats1) (calc-repeats s0 s1 r0 r1 s-out r-out)) + ((generate-idxs) (idx-exprs strides 0 0)) + ((generate-idxs-inv) (idx-exprs-inv strides 0 + repeats0 repeats1 s-out)) + ((i0-expr i1-expr) (generate-idxs "iz")) + ((iz-expr0 iz-expr1) (generate-idxs-inv "i0" "i1" "i_rep"))) + (define kernel-code + #<bytes/utf-8 kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf-g0) + (clSetKernelArg:_cl_mem kernel 1 buf-g1) + (clSetKernelArg:_cl_mem kernel 2 buf0) + (clSetKernelArg:_cl_int kernel 3 stride0) + (clSetKernelArg:_cl_int kernel 4 size0) + (clSetKernelArg:_cl_mem kernel 5 buf1) + (clSetKernelArg:_cl_int kernel 6 stride1) + (clSetKernelArg:_cl_int kernel 7 size1) + (clSetKernelArg:_cl_mem kernel 8 buf-z) + (clSetKernelArg:_cl_int kernel 9 stride-z)) + (λ () + (set! 
event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 global-work-size) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g1 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size1) + (vec->cpointer g1) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g1 + (clReleaseMemObject buf-g1)) + (when buf-g0 + (clReleaseMemObject buf-g0)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))) + +(define functional->preallocated-2-∇-acc + (λ (f-acc t-shape u-shape out-shape) + (unless (and (null? t-shape) (null? u-shape) (null? out-shape)) + (error '∇2-functional-non-scalar-acc + (string-append "Accelerated functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input 1, input 2 and output shape found: ~a ~a ~a") + t-shape u-shape out-shape)) + (λ (g v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (let ((z "@{vz}[@{iz}]") + (a "@{v0}[@{i0}]") + (b "@{v1}[@{i1}]")) + (let-values (((da db) (f-acc a b z))) + (values + #<preallocated-1-ρ-acc ext1-ρ-kernel/name + run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel/name + run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel/name + run-prim2-∇! 
functional->preallocated-2-∇-acc ext2-∇-kernel/name + kernel-name local-work-size) diff --git a/accelerated-tensors/tensors/A-equality.rkt b/accelerated-tensors/tensors/A-equality.rkt new file mode 100644 index 0000000..ae320c6 --- /dev/null +++ b/accelerated-tensors/tensors/A-equality.rkt @@ -0,0 +1,71 @@ +#lang racket + +;;—————————————————–—————————————————–—————————————————– +;; Equality checks for mostly for testing. +;;—————————————————–—————————————————–—————————————————– + +(require "0-vectors.ss") +(require "1-flats.ss") +(require rackunit) + +;;—————————————————–—————————————————–—————————————————– +;; These parameters can be overriden to account for +;; different type of numbers used inside tensors. +;;—————————————————–—————————————————–—————————————————– + +(define tolerance (make-parameter 0.0001)) + +(define equal-within-tolerance? + (make-parameter + (λ (actual expected) + (< (abs (- actual expected)) (tolerance))))) + +;;—————————————————–—————————————————–—————————————————– +;; These are representation specific, but part of the +;; exported interface of the module +;;—————————————————–—————————————————–—————————————————– + +(define tensor-equal? + (λ (actual expected) + (or (equal? actual expected) + (and (real? actual) + (real? expected) + ((equal-within-tolerance?) actual expected)) + (and (flat? actual) + (flat? expected) + (equal? (flat-shape actual) + (flat-shape expected)) + (equal-elements? actual expected))))) + +(define (equal-elements? actual expected) + (let ((actual-offset (flat-offset actual)) + (expected-offset (flat-offset expected)) + (actual-size (flat-size actual)) + (expected-size (flat-size expected)) + (actual-store (flat-store actual)) + (expected-store (flat-store expected))) + (and (equal? 
actual-size expected-size) + (call/cc (λ (return) + (for/fold ([check #t]) + ([i-actual (in-range actual-offset + (+ actual-offset + actual-size))] + [i-expected (in-range expected-offset + (+ expected-offset + expected-size))]) + (cond + (((equal-within-tolerance?) + (vref actual-store i-actual) + (vref expected-store i-expected)) check) + (else (return #f))))))))) + +(define-check (check-tensor-equal? actual expected) + (unless (tensor-equal? actual expected) + (fail-check (format "Tensors failed to match.~%actual:~%~s~%expected:~s~%~%actual store:~%~s~%expected store:~s~%" + actual expected + (with-output-to-string (λ () (print-vec (flat-store actual)))) + (with-output-to-string (λ () (print-vec (flat-store expected)))))))) + +(include "test/test-A-equality.rkt") + +(provide tolerance equal-within-tolerance? tensor-equal? check-tensor-equal? equal-elements?) diff --git a/accelerated-tensors/tensors/B-tensor-basics.rkt b/accelerated-tensors/tensors/B-tensor-basics.rkt new file mode 100644 index 0000000..39ac818 --- /dev/null +++ b/accelerated-tensors/tensors/B-tensor-basics.rkt @@ -0,0 +1,184 @@ +#lang racket + +;-------------------------------------------------------- +; Memory management tools for vectors +;-------------------------------------------------------- + +(require "0-vectors.ss") +(require "1-flats.ss") + +;-------------------------------------------------------- +; Lists +;-------------------------------------------------------- +(define ref list-ref) +(define refr drop) +(define len length) + +(provide ref refr len) + +;-------------------------------------------------------- +; Tensor basics +;-------------------------------------------------------- + +(define tref + (λ (t i) + (cond + ((= 1 (flat-rank t)) + (vref (flat-store t) (+ (flat-offset t) i))) + (else + (flat (cdr (flat-shape t)) + (flat-store t) + (+ (flat-offset t) (* i (car (flat-strides t))))))))) + +(define tlen + (λ (t) + (car (flat-shape t)))) + +(define flat-ref-idx + (λ (v 
indices) + (flat-ref-idx* (flat-offset v) (flat-strides v) indices))) + +(define flat-ref-idx* + (λ (current-idx strides indices) + (cond + ((null? indices) current-idx) + (else + (flat-ref-idx* + (+ current-idx + (* (car indices) (car strides))) + (cdr strides) + (cdr indices)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define list->tensor + (λ (lst) + (cond + ((null? lst) (error 'list->flat-tensor "No elements found")) + ((number? (car lst)) + (flat (list (length lst)) (list->vec (map exact->inexact lst)) 0)) + (else + (flat-tensor-from-list lst))))) + +(define flat-tensor-from-list + (λ (lst) + (let* ([inner-shape (flat-shape (car lst))] + [inner-size (size-of inner-shape)] + [outer-shape (cons (length lst) inner-shape)] + [size (size-of outer-shape)] + [v (new-vec size 0.0 'from-list)]) + (for ([fl lst] + [i (in-naturals 0)]) + (vcopy v (* i inner-size) + (flat-store fl) (flat-offset fl) + inner-size)) + (flat outer-shape v 0)))) + +(define tensor? + (λ (t) + (or (number? t) + (flat? t)))) + +(define tensor + (λ args + (ensure-shape args) + (cond + ((number? (car args)) (flat (list (length args)) + (list->vec (map exact->inexact args)) + 0)) + (else (merge-flats args))))) + +(define merge-flats + (λ (args) + (let* ((inner-shape (flat-shape (car args))) + (outer (length args)) + + (new-shape (cons outer inner-shape)) + (stride (size-of inner-shape)) + + (new-size (size-of new-shape)) + + (v-out (new-vec new-size +nan.0 'tensor))) + (for ([i-out (in-range outer)] + [arg args]) + (vcopy v-out (* i-out stride) (flat-store arg) (flat-offset arg) stride)) + (flat new-shape v-out 0)))) + +(define ensure-shape + (λ (args) + (when (null? 
args) + (error 'tensor "Tensors cannot be empty")) + (let ((checked-shape + (λ (x) (if (flat? x) + (flat-shape x) + '())))) + (unless (and (not (null? args)) + (cond + ((number? (car args)) + (andmap number? (cdr args))) + ((flat? (car args)) + (let ((s (checked-shape (car args)))) + (andmap (λ (t) + (and (flat? t) + (equal? (checked-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Cannot construct a tensor out of these elements: ~a~%" + args))))) + +(define build-tensor + (λ (shape f) + (let* ((size (size-of shape)) + (v (new-vec size 0.0 'build-tensor)) + (strides (strides shape))) + (fill-flat-tensor v shape strides f 0 '()) + (flat shape v 0)))) + +(define fill-flat-tensor + (λ (dest shape strides f offset tidx) + (cond + ((null? (cdr shape)) + (for ([i (in-range 0 (car shape))]) + (vset! dest (+ offset i) + (exact->inexact (f (append tidx (list i))))))) + (else + (let ((stride (car strides))) + (for ([i (in-range 0 (car shape))]) + (fill-flat-tensor dest + (cdr shape) (cdr strides) f + (+ offset (* i stride)) (append tidx (list i))))))))) + +(define trefs + (λ (t b) + (let* ([st (flat-shape t)] + [est (cdr st)] + [estride (size-of est)] + [nshape (cons (length b) (cdr st))] + [size-out (size-of nshape)] + [v-out (new-vec size-out 0.0 'flat-refs)] + [vt (flat-store t)]) + (for ([ib b] + [i-out (in-range 0 size-out estride)]) + (vcopy v-out i-out vt (* ib estride) estride)) + (flat nshape v-out 0)))) + +(include "test/test-B-tensor-basics.rkt") + +(provide tref tlen list->tensor number? + tensor? tensor build-tensor trefs merge-flats) diff --git a/accelerated-tensors/tensors/C-tensor-ops.rkt b/accelerated-tensors/tensors/C-tensor-ops.rkt new file mode 100644 index 0000000..f9056b4 --- /dev/null +++ b/accelerated-tensors/tensors/C-tensor-ops.rkt @@ -0,0 +1,35 @@ +#lang racket + +(require "1-flats.ss") +(require "B-tensor-basics.ss") + +;;—————————————————– +;; Shape, rank, size-of +;;—————————————————– + +(define shape + (λ (t) + (cond + ((number? 
t) '()) + (else (flat-shape t))))) + +(define rank + (λ (t) + (len (shape t)))) + +;;—————————————————– +;; Reshape a tensor +;;—————————————————– + +(define reshape + (λ (s t) + (cond + ((= (size-of s) (flat-size t)) + (flat s (flat-store t) (flat-offset t))) + (else (error "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + + +(include "test/test-C-tensor-ops.rkt") + +(provide rank shape reshape) +(provide size-of strides) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt new file mode 100644 index 0000000..6043280 --- /dev/null +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -0,0 +1,485 @@ +#lang racket + +(require "../../impl-loader.rkt") +(require "0-vectors.ss") +(require "1-flats.ss") +(require "2-acc-runtime.ss") +(require "B-tensor-basics.ss") +(require "C-tensor-ops.ss") +(require "ext2-strides.rkt") + +;;—————————————————–—————————————————–—————————————————– +;; Unary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +;; TODO: Replace accelerate? parameter with a function which determines based on +;; the size of computation workload when to disable GPU acceleration and default +;; to running code on the CPU . + +(define ext1-ρ + (let ((id -1)) + (λ (f f-acc m + [shape-fn scalar-shape] + [prim-sign (begin + (set! id (add1 id)) + (string-append "re1" (~r id #:base 16)))]) + (λ (t) + (cond + ((number? t) (f t)) + ((expects-preallocated? f-acc) + (scalarize + (flat-ext1-ρ f f-acc m shape-fn prim-sign t))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f flat-f-acc m shape-fn prim-sign t))))))))) + +(define ext1-∇ + (let ((id -1)) + (λ (f f-acc m + [shape-fn scalar-shape] + [prim-sign (begin + (set! 
id (add1 id)) + (string-append "ne1" (~r id #:base 16)))]) + (λ (t z) + (cond + ((number? t) (f t z)) + ((expects-preallocated? f-acc) + (scalarize (flat-ext1-∇ f f-acc m shape-fn prim-sign t (ensure-flat z)))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-∇-acc + f-acc base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f flat-f-acc m + shape-fn prim-sign + t (ensure-flat z)))))))))) + +(define functional->preallocated-1-ρ + (λ (f base-shape out-shape) + (λ (v0 i0 stride0 v-out i-out stride-out) + (set-prealloc-ρ! v-out i-out out-shape + (f (arg-value base-shape v0 i0)))))) + +(define functional->preallocated-1-∇ + (λ (f base-shape out-shape) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value base-shape v0 i0))) + (set-prealloc-∇! g0 i0 base-shape (f a z)))))) + +(define set-prealloc-ρ! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out a)) + (else + (v-copy-flat! v-out i-out a))))) + +(define set-prealloc-∇! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out (+ (vref v-out i-out) a))) + (else + (v-add-flat! v-out i-out a))))) + +(define arg-value + (λ (v-shape v i) + (cond + ((null? v-shape) (vref v i)) + (else + (error 'ρ-functional-non-scalar-in + (string-append "Functional primitives can only accept scalars," + " so try defining a preallocated primitive" + " instead. In shape found: ~a") + v-shape) + #;(flat v-shape v i))))) + + +(define invoke-functional-∇ + (λ (f base-shape v0 i0) + (cond + ((null? 
base-shape) (f (vref v0 i0))) + (else (f (flat base-shape v0 i0)))))) + +;;—————————————————–—————————————————–—————————————————– +;; Binary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext2-ρ + (let ((id -1)) + (λ (f f-acc m n + [shape-fn scalar-shape] + [prim-sign (begin + (set! id (add1 id)) + (string-append "re2" (~r id #:base 16)))]) + (λ (t u) + (cond + ((and (number? t) (number? u)) (f t u)) + ((expects-preallocated? f-acc) + (scalarize + (flat-ext2-ρ f f-acc m n shape-fn prim-sign t u))) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign (ensure-flat t) u)))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t (ensure-flat u))))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t u))))))))) + +(define ext2-∇ + (let ((id -1)) + (λ (f f-acc m n + [shape-fn scalar-shape] + [prim-sign (begin + (set! 
id (add1 id)) + (string-append "ne2" (~r id #:base 16)))]) + (λ (t u z) + (let ((invoke-flat-ext2-∇ + (λ (f f-acc m n shape-fn t u z) + (let-values (((da db) (flat-ext2-∇ f f-acc m n shape-fn prim-sign t u z))) + (values (scalarize da) (scalarize db)))))) + (cond + ((and (number? t) (number? u)) (f t u z)) + ((expects-preallocated? f-acc) + (invoke-flat-ext2-∇ f f-acc m n shape-fn t u z)) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn (ensure-flat t) u z))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t (ensure-flat u) z))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t u z))))))))) + +(define functional->preallocated-2-ρ + (λ (f t-shape u-shape out-shape) + (λ (v0 i0 stride0 v1 i1 stride1 v-out i-out stride-out) + (set-prealloc-ρ! v-out i-out out-shape + (f (arg-value t-shape v0 i0) + (arg-value u-shape v1 i1)))))) + +(define functional->preallocated-2-∇ + (λ (f t-shape u-shape out-shape) + (λ (g0 g1 v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value t-shape v0 i0)) + (b (arg-value u-shape v1 i1))) + (let-values (((da db) (f a b z))) + (set-prealloc-∇! 
g0 i0 t-shape da) + (set-prealloc-∇! g1 i1 u-shape db)))))) + +(define idxs + (λ (strides out-i i0 i1) + (for/fold ([i0 i0] + [i1 i1] + [x out-i] #:result (values i0 i1)) + ([stride (strides-strides strides)]) + (let ((idx (quotient x (vector-ref stride 0))) + (next-x (remainder x (vector-ref stride 0)))) + (values (+ i0 (* idx (vector-ref stride 1))) + (+ i1 (* idx (vector-ref stride 2))) + next-x))))) + +(define merge-shapes + (λ (in-shape min-rank out-f-shape) + (append (take in-shape (- (length in-shape) min-rank)) + out-f-shape))) + +(define flat-ext1-ρ + (λ (f f-acc min-rank shape-fn f-sign t0) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sf-out (shape-fn sf0)) + (stride-out (size-of sf-out)) + (s-out (merge-shapes s0 min-rank sf-out)) + (size-out (size-of s-out)) + (v-out (new-vec size-out 0.0))) + (cond + ((accelerate?) + (let-values (((kernel-code kernel-name) (ext1-ρ-kernel/name f-acc f-sign))) + (run-prim1-ρ! kernel-code kernel-name + v0 off0 size0 stride0 + v-out size-out stride-out))) + (else + (for ([i-out (in-range 0 size-out stride-out)]) + (let ((i0 (+ off0 (* (/ i-out stride-out) stride0)))) + (f v0 i0 stride0 v-out i-out stride-out))))) + (flat s-out v-out 0)))) + +(define flat-ext1-∇ + (λ (fᵈ fᵈ-acc min-rank shape-fn fᵈ-sign t0 z) + ;; z has the same shape as the output + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sz (flat-shape z)) + (size-z (size-of sz)) + (sf-z (shape-fn sf0)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + (offz (flat-offset z)) + + (g0 (new-vec size0 0.0))) + (cond + ((accelerate?) + (let-values (((kernel-code kernel-name) (ext1-∇-kernel/name fᵈ-acc fᵈ-sign))) + (run-prim1-∇! 
kernel-code kernel-name g0 + v0 off0 size0 stride0 + vz offz size-z stride-z))) + (else + (for ([iz (in-range 0 size-z stride-z)]) + (let ((i0 (+ off0 (* (/ iz stride-z) stride0)))) + (fᵈ g0 v0 i0 stride0 vz (+ offz iz) stride-z))))) + (flat s0 g0 0)))) + +(define flat-ext2-ρ + (λ (f f-acc r0 r1 shape-fn f-sign t0 t1) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + (size0 (size-of s0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + (size1 (size-of s1)) + + (sf-out (shape-fn sf0 sf1)) + (stride0 (size-of sf0)) + (stride1 (size-of sf1)) + (stride-out (size-of sf-out))) + (ext2-shapes s0 s1 r0 r1 sf-out + (λ (s-out size-out q0 q1 strides) + (let ((out-v (new-vec size-out 0.0))) + (cond + ((accelerate?) + (let-values (((kernel-code kernel-name) (ext2-ρ-kernel/name f-acc f-sign strides))) + (run-prim2-ρ! kernel-code kernel-name + v0 off0 size0 stride0 + v1 off1 size1 stride1 + out-v size-out stride-out))) + (else + (for ([out-i (in-range 0 size-out stride-out)]) + (let-values (((i0 i1) + (idxs strides out-i off0 off1))) + (f v0 i0 stride0 v1 i1 stride1 out-v (+ 0 out-i) stride-out))))) + (flat s-out out-v 0))))))) + +(define flat-ext2-∇ + (λ (fᵈ fᵈ-acc r0 r1 shape-fn fᵈ-sign t0 t1 z) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + (size0 (size-of s0)) + (stride0 (size-of sf0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + (size1 (size-of s1)) + (stride1 (size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + (offz (flat-offset z))) + (ext2-shapes s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (size-of s0) 0.0)) + (g1 (new-vec (size-of s1) 0.0))) + (cond + ((accelerate?) 
+ (let-values (((kernel-code kernel-name) + (ext2-∇-kernel/name fᵈ-acc fᵈ-sign strides s0 s1 r0 r1 sz + (length sf-z)))) + (run-prim2-∇! kernel-code kernel-name + g0 g1 + v0 off0 size0 stride0 + v1 off1 size1 stride1 + vz offz size-z stride-z))) + (else + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))))) + (values (flat s0 g0 0) + (flat s1 g1 0)))))))) + +;;TODO: Create a caching macro to generalize caching of functions +(define ext2-shapes + (let ((cache (make-hash))) + (λ (s0 s1 r0 r1 sf-out k) + (define cache-key (equal-hash-code (list s0 s1 r0 r1 sf-out))) + (cond + [(hash-has-key? cache cache-key) (apply k (hash-ref cache cache-key))] + [else + (let ((l0 (length s0)) + (l1 (length s1)) + (k (λ args + (hash-set! cache cache-key args) + (apply k args)))) + (cond + ((and (= r0 l0) (= r1 l1)) + (k sf-out + (size-of sf-out) + (size-of s0) + (size-of s1) + strides-null)) + + ((= r0 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((= r1 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + ((and (not (null? s0)) + (not (null? 
s1)) + (= (car s0) (car s1))) + (ext2-shapes (cdr s0) (cdr s1) r0 r1 sf-out + (desc-both (car s0) k))) + + ((> l1 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((> l0 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + (else (error 'ext + "Shapes are incompatible for ext2: ~a, and ~a for min ranks ~a, and ~a~%" + s0 s1 r0 r1))))])))) + +(define desc-both + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + (* q1 d) + (strides-cons qout q0 q1 strides))))) + +(define desc-left + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + q1 + (strides-cons qout q0 0 strides))))) + +(define desc-right + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + q0 + (* q1 d) + (strides-cons qout 0 q1 strides))))) + +(define v-copy-flat! + (λ (vg ig a) + ;; copy elements from a to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (vref va (+ a-offset i))))))) + +(define v-add-flat! + (λ (vg ig a) + ;; copy elements to a to vg while adding them to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (+ (vref vg (+ ig i)) + (vref va (+ a-offset i)))))))) + +(define expects-preallocated? + (λ (f) + (let ((a (procedure-arity f))) + (and (integer? a) + (>= a 6))))) + +(define ensure-flat + (λ (z) + (cond + ((number? z) + (flat '() (new-vec 1 (exact->inexact z)) 0)) + (else z)))) + +(define scalarize + (λ (t) + (cond + ((null? (flat-shape t)) (vref (flat-store t) 0)) + (else t)))) + +(define min-shape + (λ (min-rank in-shape) + (drop in-shape (- (length in-shape) min-rank)))) + +(define scalar-shape + (λ (s0 [s1 '()]) '())) + +(include "test/test-D-extend.rkt") + +(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? 
+ functional->preallocated-1-ρ functional->preallocated-1-∇ + functional->preallocated-2-ρ functional->preallocated-2-∇ + functional->preallocated-1-ρ-acc functional->preallocated-1-∇-acc + functional->preallocated-2-ρ-acc functional->preallocated-2-∇-acc + merge-shapes min-shape ext2-shapes idxs + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/accelerated-tensors/tensors/ext2-strides.rkt b/accelerated-tensors/tensors/ext2-strides.rkt new file mode 100644 index 0000000..abdd1c2 --- /dev/null +++ b/accelerated-tensors/tensors/ext2-strides.rkt @@ -0,0 +1,36 @@ +#lang racket + +(require file/xxhash32) + +(provide strides-null + strides-cons + (rename-out (ext2-strides-strides strides-strides) + (ext2-strides-sign strides-signature))) + +(define (strides-signature! ctx strides) + (xxh32-update! + ctx + (for/fold + ((result #"")) + ((stride-vec strides)) + (match-let* ((`#(,s1 ,s2 ,s3) stride-vec)) + (bytes-append result + (integer->integer-bytes s1 4 #f) + (integer->integer-bytes s2 4 #f) + (integer->integer-bytes s3 4 #f)))))) + +(struct ext2-strides ((strides #:mutable) sign)) + +(define strides-null + (ext2-strides '() (let ((ctx (make-xxh32))) + (~r (xxh32-digest ctx) #:base 16)))) + +(define strides-cons + (λ (st-out st0 st1 strides) + (let ((new-list (cons (vector st-out st0 st1) + (ext2-strides-strides strides)))) + (ext2-strides new-list + (begin + (let ((ctx (make-xxh32))) + (strides-signature! 
ctx new-list) + (~r (xxh32-digest ctx) #:base 16))))))) diff --git a/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt b/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt new file mode 100644 index 0000000..6149676 --- /dev/null +++ b/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt @@ -0,0 +1,8 @@ +(module+ test + (require rackunit) + + (for ((_ (in-range 100))) + (λ () + (check-true (not (not (context)))) + (check-true (not (not (command-queue)))))) + ) diff --git a/accelerated-tensors/tensors/test/test-A-equality.rkt b/accelerated-tensors/tensors/test/test-A-equality.rkt new file mode 100644 index 0000000..9488d8c --- /dev/null +++ b/accelerated-tensors/tensors/test/test-A-equality.rkt @@ -0,0 +1,84 @@ +(module+ test + (require rackunit) + + (check-true ((equal-within-tolerance?) 1.00001 1.0001)) + (check-true ((equal-within-tolerance?) 1.0002 1.0001)) + (check-false ((equal-within-tolerance?) 1.0003 1.0001)) + + (define t0 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.0 i))) + 0)) + + + (define t1 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t2 + (flat '(1 2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t3 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (* (quotient i 24) i))) + 0)) + + (define t4 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (- (* 2.000001 (* (quotient i 24) i)) 48.0))) + 0)) + + (check-true (equal-elements? t0 t1)) + + (check-true (equal-elements? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (equal-elements? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (equal-elements? t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (equal-elements? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-true (tensor-equal? t0 t1)) + + (check-false (tensor-equal? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (tensor-equal? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (tensor-equal? 
t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (tensor-equal? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-tensor-equal? t0 t1) + + (check-tensor-equal? t0 (flat '(2 3 4) + (flat-store t2) + 0)) + + (check-tensor-equal? t1 (flat '(2 3 4) + (flat-store t4) + 24))) diff --git a/accelerated-tensors/tensors/test/test-B-tensor-basics.rkt b/accelerated-tensors/tensors/test/test-B-tensor-basics.rkt new file mode 100644 index 0000000..903220b --- /dev/null +++ b/accelerated-tensors/tensors/test/test-B-tensor-basics.rkt @@ -0,0 +1,33 @@ +(module+ test + (require rackunit) + (require "A-equality.ss") + + (define r0-td 3.0) + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0))) + (define r3-td + (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5)) + (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11)) + (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17)) + (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23)))) + + (check-tensor-equal? (tref r1-td 2) 5.0) + (check-equal? (tlen r1-td) 3) + (check-tensor-equal? (list->tensor (list 3.0 4.0 5.0)) r1-td) + + (check-true (and (tensor? r0-td) (tensor? r1-td))) + (check-false (tensor? '(a b c))) + + (check-tensor-equal? (build-tensor '(4 3 2) + (λ (idx) + (+ (* 6 (ref idx 0)) + (* 2 (ref idx 1)) + (ref idx 2)))) + r3-td) + + (check-tensor-equal? (build-tensor '(1 2 3) (λ (idx) (+ (list-ref idx 0) (list-ref idx 1) (list-ref idx 2)))) + (tensor (tensor (tensor 0 1 2) (tensor 1 2 3)))) + + (check-tensor-equal? 
(trefs r1-td '(0 2)) (tensor 3.0 5.0)) + + ) diff --git a/accelerated-tensors/tensors/test/test-C-tensor-ops.rkt b/accelerated-tensors/tensors/test/test-C-tensor-ops.rkt new file mode 100644 index 0000000..4509a31 --- /dev/null +++ b/accelerated-tensors/tensors/test/test-C-tensor-ops.rkt @@ -0,0 +1,63 @@ +(module+ test + (require rackunit) + (require "A-equality.ss") + + (define r0-td 3.0) + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0))) + (define r3-td + (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5)) + (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11)) + (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17)) + (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23)))) + + (define test-shape (list 2 2 3)) + + (check-equal? (shape r0-td) (list)) + (check-equal? (shape r1-td) (list 3)) + (check-equal? (shape r2-td) (list 2 3)) + + (check-equal? (rank r0-td) 0) + (check-equal? (rank r1-td) 1) + (check-equal? (rank r2-td) 2) + + (check-equal? (size-of '()) 1) + (check-equal? (size-of test-shape) 12) + + + (check-equal? (size-of '(4 3 2)) 24) + + (check-tensor-equal? (reshape '(24) r3-td) + (tensor 0 1 2 3 4 5 + 6 7 8 9 10 11 + 12 13 14 15 16 17 + 18 19 20 21 22 23)) + + (check-tensor-equal? (reshape '(4 1) (tensor 0 1 2 3)) + (tensor (tensor 0) (tensor 1) (tensor 2) (tensor 3))) + + (check-tensor-equal? (reshape '(6) r2-td) + (tensor 3.0 4.0 5.0 7.0 8.0 9.0)) + + (check-tensor-equal? (reshape '(3 2) r2-td) + (tensor (tensor 3.0 4.0) + (tensor 5.0 7.0) + (tensor 8.0 9.0))) + + + (check-exn exn:fail? + (λ () + (tensor "1 2" 1 2))) + + (check-exn exn:fail? + (λ () + (tensor))) + + (check-exn exn:fail? + (λ () + (tensor 1 (tensor 2 3)))) + + (check-exn exn:fail? 
+ (λ () + (tensor tensor (tensor 2 3)))) +) diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt new file mode 100644 index 0000000..78ba2a1 --- /dev/null +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -0,0 +1,294 @@ +(module+ test + (require rackunit) + (require string-interpolation) + (require "A-equality.rkt") + (require "B-tensor-basics.rkt") + + (define sum-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (vset! out-v iₒ + (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) + (+ sum (vref in-v i)))))) + + (define sum-f-acc + (λ (v0 i0 stride0 v-out i-out stride-out) + #<vec '(3.0 4.0 5.0)) 0)) + (define r2-td (flat '(2 3) (list->vec '(3.0 4.0 5.0 7.0 8.0 9.0)) 0)) + + (define +ᶠ +) + (define +ᵈ (λ (a b z) (values z z))) + (define +ᵈ-acc +ᵈ) + + (define sqrᶠ (λ (a) (* a a))) + (define sqrᵈ + (λ (a z) (* z 2 a))) + (define sqrᵈ-acc + (λ (a z) + "@{z} * 2.0 * @{a}")) + + (define d-sqr (ext1-∇ sqrᵈ sqrᵈ-acc 0 scalar-shape)) + + (define one-like + (λ (t) + (let* ((st (flat-shape t)) + (size-t (size-of st))) + (flat st + (new-vec size-t 1.0) + 0)))) + + (check-true (equal-elements? (d-sqr r1-td (one-like r1-td)) (tensor 6.0 8.0 10.0))) + + (let ((gsqr (d-sqr r2-td (one-like r2-td)))) + (check-tensor-equal? gsqr (reshape '(2 3) (tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + + (define d+ (ext2-∇ +ᵈ +ᵈ-acc 0 0 scalar-shape)) + + (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) + (check-tensor-equal? da (tensor 1.0 1.0 1.0)) + (check-tensor-equal? db (tensor 1.0 1.0 1.0))) + + (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) + (check-tensor-equal? da (tensor 2.0 2.0 2.0)) + (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + + (define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) + (λ (a b z) (values "@{z} * @{b}" "@{z} * @{a}" )) + 0 + 0)) + + (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0)))) + (check-tensor-equal? 
gt (tensor 1.0 2.0 3.0)) + (check-tensor-equal? gu (tensor 2.0 3.0 4.0))) + + (define sum-1-∇ + (λ (g t it st vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g i (vref vz iz))))) + (define sum-1-∇-acc + (λ (g0 v0 i0 stride0 vz iz stride-z) + #<tensor tensor build-tensor @@ -31,7 +33,7 @@ (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/flat-tensors/autodiff/B-prims.rkt b/flat-tensors/autodiff/B-prims.rkt index ccbe516..1bb8d9e 100644 --- a/flat-tensors/autodiff/B-prims.rkt +++ b/flat-tensors/autodiff/B-prims.rkt @@ -151,7 +151,7 @@ (λ (ρ-fn ra rb shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) @@ -175,7 +175,7 @@ (λ (∇-fn ra rb z shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) diff --git a/flat-tensors/autodiff/E-print.rkt b/flat-tensors/autodiff/E-print.rkt index cd0b5d7..ca3ce62 100644 --- a/flat-tensors/autodiff/E-print.rkt +++ b/flat-tensors/autodiff/E-print.rkt @@ -80,4 +80,7 @@ (include "test/test-E-print.rkt") (provide max-tensor-print-length - make-printable) + make-printable + ;; This is used in ext-impl.rkt + make-printable-flat + fake-tensor) diff --git a/flat-tensors/ext-impl.rkt b/flat-tensors/ext-impl.rkt new file mode 100644 index 0000000..6eb7d34 --- /dev/null +++ b/flat-tensors/ext-impl.rkt @@ -0,0 +1,28 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require (only-in "tensors/B-tensor-basics.rkt" + merge-flats)) 
+(require (only-in "tensors/D-extend.rkt" + merge-shapes + min-shape + ext2-shapes + flat-ext1-∇ + flat-ext1-ρ + flat-ext2-ρ + functional->preallocated-1-ρ + functional->preallocated-1-∇ + functional->preallocated-2-ρ + functional->preallocated-2-∇ + idxs + scalarize + ensure-flat)) +(require (only-in "autodiff/E-print.rkt" + make-printable-flat + fake-tensor)) + +(provide (all-from-out "tensors/0-vectors.rkt")) +(provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/B-tensor-basics.rkt")) +(provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/E-print.rkt")) diff --git a/flat-tensors/ext-ops.rkt b/flat-tensors/ext-ops.rkt index 83af7de..fc223f5 100644 --- a/flat-tensors/ext-ops.rkt +++ b/flat-tensors/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/flat-tensors/ext-ops/A-scalar-ops.rkt b/flat-tensors/ext-ops/A-scalar-ops.rkt index b251e80..9096209 100644 --- a/flat-tensors/ext-ops/A-scalar-ops.rkt +++ b/flat-tensors/ext-ops/A-scalar-ops.rkt @@ -121,6 +121,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 @@ -132,4 +135,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/flat-tensors/ext-ops/E-argmax.rkt b/flat-tensors/ext-ops/E-argmax.rkt index ad4aefe..3130843 100644 --- a/flat-tensors/ext-ops/E-argmax.rkt +++ b/flat-tensors/ext-ops/E-argmax.rkt @@ -7,7 +7,7 @@ (λ (v0 i0 stride0 v-out i-out stride-out) (vector-set! 
v-out i-out - (for/fold ([max 0.0] + (for/fold ([max -inf.0] [max-i -1] #:result max-i) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vector-ref v0 i))) diff --git a/flat-tensors/ext-ops/F-max.rkt b/flat-tensors/ext-ops/F-max.rkt index 5a54bd5..77e3d26 100644 --- a/flat-tensors/ext-ops/F-max.rkt +++ b/flat-tensors/ext-ops/F-max.rkt @@ -7,7 +7,7 @@ (λ (v0 i0 stride0 v-out i-out stride-out) (vector-set! v-out i-out - (for/fold ([max 0.0]) + (for/fold ([max -inf.0]) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vector-ref v0 i))) (cond diff --git a/flat-tensors/no-duals-no-overrides.rkt b/flat-tensors/no-duals-no-overrides.rkt index ac07a7a..07ca22e 100644 --- a/flat-tensors/no-duals-no-overrides.rkt +++ b/flat-tensors/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ concat-ρ concat-n-ρ flatten-ρ diff --git a/flat-tensors/no-duals.rkt b/flat-tensors/no-duals.rkt index 927c8c7..cd1bcaf 100644 --- a/flat-tensors/no-duals.rkt +++ b/flat-tensors/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/flat-tensors/no-overrides.rkt b/flat-tensors/no-overrides.rkt index 35dcbdd..05844b7 100644 --- a/flat-tensors/no-overrides.rkt +++ b/flat-tensors/no-overrides.rkt @@ -34,7 +34,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ diff --git 
a/flat-tensors/tensors/B-tensor-basics.rkt b/flat-tensors/tensors/B-tensor-basics.rkt index 4af5d69..60df715 100644 --- a/flat-tensors/tensors/B-tensor-basics.rkt +++ b/flat-tensors/tensors/B-tensor-basics.rkt @@ -179,4 +179,4 @@ (include "test/test-B-tensor-basics.rkt") (provide tref tlen list->tensor number? - tensor? tensor build-tensor trefs) + tensor? tensor build-tensor trefs merge-flats) diff --git a/flat-tensors/tensors/C-tensor-ops.rkt b/flat-tensors/tensors/C-tensor-ops.rkt index f9056b4..cf32d23 100644 --- a/flat-tensors/tensors/C-tensor-ops.rkt +++ b/flat-tensors/tensors/C-tensor-ops.rkt @@ -26,7 +26,7 @@ (cond ((= (size-of s) (flat-size t)) (flat s (flat-store t) (flat-offset t))) - (else (error "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + (else (error 'tensor-reshape "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) (include "test/test-C-tensor-ops.rkt") diff --git a/flat-tensors/tensors/D-extend.rkt b/flat-tensors/tensors/D-extend.rkt index bc698fa..e2e5501 100644 --- a/flat-tensors/tensors/D-extend.rkt +++ b/flat-tensors/tensors/D-extend.rkt @@ -209,11 +209,12 @@ (sf-z (shape-fn sf0)) (stride-z (size-of sf-z)) (vz (flat-store z)) + (offz (flat-offset z)) (g0 (new-vec size0 0.0))) (for ([iz (in-range 0 size-z stride-z)] [i0 (in-range off0 (+ off0 size0) stride0)]) - (fᵈ g0 v0 i0 stride0 vz iz stride-z)) + (fᵈ g0 v0 i0 stride0 vz (+ offz iz) stride-z)) (flat s0 g0 0)))) (define flat-ext2-ρ @@ -384,4 +385,8 @@ (include "test/test-D-extend.rkt") -(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated?) +(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? 
+ functional->preallocated-1-ρ functional->preallocated-1-∇ + functional->preallocated-2-ρ functional->preallocated-2-∇ + merge-shapes min-shape ext2-shapes idxs + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/flat-tensors/tensors/test/test-D-extend.rkt b/flat-tensors/tensors/test/test-D-extend.rkt index b6026f1..34d9618 100644 --- a/flat-tensors/tensors/test/test-D-extend.rkt +++ b/flat-tensors/tensors/test/test-D-extend.rkt @@ -43,8 +43,8 @@ (define dup (ext1-ρ dup-f 1 dup-shape-f)) (check-equal? (flat-store (dup t0)) - #(0 2 4 6 0 2 4 6 - 8 10 12 14 8 10 12 14 + #(0 2 4 6 0 2 4 6 + 8 10 12 14 8 10 12 14 16 18 20 22 16 18 20 22 24 26 28 30 24 26 28 30 32 34 36 38 32 34 36 38 diff --git a/impl-loader.rkt b/impl-loader.rkt index 4fc5f1b..8c1a871 100644 --- a/impl-loader.rkt +++ b/impl-loader.rkt @@ -22,7 +22,7 @@ (define init-settings (λ () - (settings (make-hash (read-preferences "local.cfg"))))) + (settings (make-hash (read-preferences (or (getenv "MALT_PREFERENCES") "local.cfg")))))) ;;-------------------------------- ;; Config params so far @@ -32,11 +32,21 @@ (λ () (car (dict-ref (settings) 'tensor-implementation)))) +(define accelerate? + (λ () + (car (dict-ref (settings) 'accelerate?)))) + +(define debug-kernel? + (λ () + (car (dict-ref (settings) 'debug-kernel?)))) + ;; Default settings (define default-preferences - `((tensor-implementation learner))) + `((tensor-implementation learner) + (accelerate? #t) + (debug-kernel? #f))) (when (not (settings)) (init-settings)) -(provide tensor-implementation) +(provide tensor-implementation accelerate? debug-kernel?) 
diff --git a/impl-no-duals-no-overrides.rkt b/impl-no-duals-no-overrides.rkt index 2fe6658..c01c9a6 100644 --- a/impl-no-duals-no-overrides.rkt +++ b/impl-no-duals-no-overrides.rkt @@ -11,12 +11,18 @@ (printf "Tensor implementation (no-duals, no-overrides): ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy/no-duals-no-overrides.rkt")) ((learner) #'(require "learner/no-duals-no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals-no-overrides.rkt")) + ((uniform-tensors) #'(require "uniform-tensors/no-duals-no-overrides.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors/no-duals-no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals-no-overrides.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy/no-duals-no-overrides.rkt"))) ((learner) #'(provide (all-from-out "learner/no-duals-no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals-no-overrides.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals-no-overrides.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors/no-duals-no-overrides.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-duals-no-overrides.rkt"))))))) (load-tensors) diff --git a/impl-no-duals.rkt b/impl-no-duals.rkt index 05043ba..a90faf4 100644 --- a/impl-no-duals.rkt +++ b/impl-no-duals.rkt @@ -11,12 +11,18 @@ (printf "Tensor implementation (no-duals): ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy/no-duals.rkt")) ((learner) #'(require "learner/no-duals.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals.rkt")) + ((uniform-tensors) #'(require "uniform-tensors/no-duals.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors/no-duals.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out 
"lazy/no-duals.rkt"))) ((learner) #'(provide (all-from-out "learner/no-duals.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors/no-duals.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-duals.rkt"))))))) (load-tensors) diff --git a/impl-no-overrides.rkt b/impl-no-overrides.rkt index 0f479e8..4ece6b3 100644 --- a/impl-no-overrides.rkt +++ b/impl-no-overrides.rkt @@ -11,12 +11,18 @@ (printf "Tensor implementation (no-overrides): ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy/no-overrides.rkt")) ((learner) #'(require "learner/no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-overrides.rkt")) + ((uniform-tensors) #'(require "uniform-tensors/no-overrides.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors/no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-overrides.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy/no-overrides.rkt"))) ((learner) #'(provide (all-from-out "learner/no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-overrides.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-overrides.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors/no-overrides.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-overrides.rkt"))))))) (load-tensors) diff --git a/impl.rkt b/impl.rkt index 4f09203..bf1c1ca 100644 --- a/impl.rkt +++ b/impl.rkt @@ -11,12 +11,18 @@ (printf "Tensor implementation: ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy.rkt")) ((learner) #'(require "learner.rkt")) ((flat-tensors) #'(require "flat-tensors.rkt")) + ((uniform-tensors) #'(require "uniform-tensors.rkt")) + 
((accelerated-tensors) #'(require "accelerated-tensors.rkt")) ((nested-tensors) #'(require "nested-tensors.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy.rkt"))) ((learner) #'(provide (all-from-out "learner.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors.rkt"))))))) (load-tensors) diff --git a/info.rkt b/info.rkt index b00546a..6b803c0 100644 --- a/info.rkt +++ b/info.rkt @@ -1,12 +1,16 @@ #lang info (define collection "malt") -(define deps '(["base" #:version "8.2"] "rackunit-lib")) +(define deps '(["base" #:version "8.2"] + "rackunit-lib" + "opencl" + "xxhash" + "string-interpolation")) (define pkg-desc "A MAchine Learning Toolkit accompanying The Little Learner: A Straight Line to Deep Learning by Daniel P. Friedman and Anurag Mendhekar") (define version "0.1") (define compile-omit-paths (list #rx"test\\\\" #rx"test/")) (define test-omit-paths (list #rx"test\\\\" #rx"test/")) -(define pkg-authors '("Anurag Mendhekar" "Daniel P. Friedman")) +(define pkg-authors '("Anurag Mendhekar" "Daniel P. Friedman" "Darshal Shetty")) (define build-deps '("scribble-lib" "racket-doc" "rackunit-lib")) (define scribblings '(("scribblings/malt.scrbl" (multi-page)))) (define license 'MIT) diff --git a/lazy.rkt b/lazy.rkt new file mode 100644 index 0000000..0505225 --- /dev/null +++ b/lazy.rkt @@ -0,0 +1,48 @@ +#lang racket/base + +(require + (except-in "lazy/tensors.rkt" + rank shape reshape tref trefs tensor? tlen ref refr)) + +(require "lazy/autodiff.rkt") +(require "lazy/ext-ops.rkt") + +(provide + tolerance + + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + print-compiler? compiler-cache + + dual dual? 
ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print make-printable max-tensor-print-length check-dual-equal? check-ρ-∇ + + (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) + (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) + (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) + (d-flatten flatten) + (d-concat concat) (d-concat-n concat-n)) + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 sqrt-0 abs-0 rectify-0 + + sum-1 argmax-1 max-1 flatten-2 concat-1-1 + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/autodiff.rkt b/lazy/autodiff.rkt new file mode 100644 index 0000000..ed89acd --- /dev/null +++ b/lazy/autodiff.rkt @@ -0,0 +1,23 @@ +#lang racket + +(require "autodiff/A-autodiff.rkt") +(require "autodiff/B-prims.rkt") +(require "autodiff/C-dualized-tensor-ops.rkt") +(require "autodiff/D-test-helpers.rkt") +(require "autodiff/E-print.rkt") + +(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual* map*) +(provide prim1 prim2 ext1 ext2) +(provide (rename-out (d-rank rank) + (d-shape shape) + (d-reshape reshape) + (d-tref tref) + (d-trefs trefs) + (d-tensor? tensor?) + (d-tlen tlen) + (d-ref ref) + (d-refr refr))) + +(provide check-dual-equal? check-ρ-∇) + +(provide make-printable max-tensor-print-length) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt new file mode 100644 index 0000000..cae86b0 --- /dev/null +++ b/lazy/autodiff/A-autodiff.rkt @@ -0,0 +1,121 @@ +#lang racket + +(require string-interpolation) +(require "../tensors.rkt") + +;;---------------------------- +;; Real part of a dual is always a tensor (of any rank) +;;---------------------------- + +(define dual? 
+ (λ (x) + (and (vector? x) (eq? (vector-ref x 0) dual)))) + +(define dual + (λ (r k) + (vector dual r k))) + +(define dual* + (λ (d) + (dual (ρ d) end-of-chain))) + +(define ρ + (λ (d) + (cond + ((dual? d) (scalarize (vector-ref d 1))) + (else (scalarize d))))) + +(define κ + (λ (d) + (cond + ((dual? d) (vector-ref d 2)) + (else end-of-chain)))) + +(define scalar? + (λ (d) + (or (number? d) + (and (dual? d) + (number? (ρ d)))))) + +(define dual-like? + (λ (d) + (or (dual? d) + (number? d) + (tensor? d)))) + +;;---------------------------- +;; Chain rule +;;---------------------------- + +(define end-of-chain + (λ (d z σ) + (let ((g (hash-ref σ d 0.0))) + (hash-set σ d (+-ρ z g))))) + +(define +-ρ + (ext2-ρ + (λ (a b) "@{a} + @{b}") 0 0)) + +;;---------------------------- +;; Reverse-mode AD +;;---------------------------- + +(define ∇ + (λ (f theta) + (let ((wrt (map* dual* theta))) + (∇-once (f wrt) wrt)))) + +(define ∇¹ + (λ (f) + (λ xs + (let ((wrt (map* dual* xs))) + (∇-once (apply f wrt) wrt))))) + +(define ∇-once + (λ (y wrt) + (let ((σ (∇σ y (hasheq)))) + (map* (λ (d) + (hash-ref σ d 0.0)) + wrt)))) + +(define ∇σ + (λ (y σ) + (cond + ((dual-like? y) ((κ y) y (one-like (ρ y)) σ)) + ((list? y) (∇σ-list y σ)) + (else (printf "Unknown: ~a~%" y))))) + +(define ∇σ-list + (λ (y σ) + (cond + ((null? y) σ) + (else + (let ((σ-hat (∇σ (ref y 0) σ))) + (∇σ-list (refr y 1) σ-hat)))))) + +;;---------------------------- +;; General helpers +;;---------------------------- + +(define map* + (λ (f y) + (cond + ((dual-like? y) (f y)) + ((list? y) + (map (λ (yi) + (map* f yi)) + y)) + (else y)))) + +(define trace-print + (λ (v port) + (cond + ((dual? v) (trace-print (ρ v) port)) + (else (fprintf port "~a~%" v))))) + +(define (one-like s) ((ext1-ρ (λ (x) 1.0) (λ (x) "1.0") 0) s)) + +(include "test/test-A-autodiff.rkt") + +(provide + dual dual? ρ κ ∇ ∇¹ dual* scalar? 
end-of-chain + trace-print map*) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt new file mode 100644 index 0000000..5b4b252 --- /dev/null +++ b/lazy/autodiff/B-prims.rkt @@ -0,0 +1,130 @@ +#lang racket + +(require (only-in "../tensors/c0-ast.rkt" + tpmake-prim1-ρ + tpmake-prim2-ρ + tpmake-prim1-∇ + tpmake-prim2-∇)) +(require "../tensors.rkt") +(require (only-in "../tensors/c1-racket-runtime.rkt" ext2-∇-result)) +(require (only-in "../tensors/c0-ast.rkt" tcomp-ds-ref)) +(require "A-autodiff.ss") + +(struct prim (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape-fn signature expects-prealloc? proc) + #:property prop:procedure (λ (this . args) + (apply (prim-proc this) args))) + +;;TODO: Add new ast nodes for the 4 forces being done in the four preallocated->functional-* functions +(define prim1 + (let ((id 0)) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (let ((prim-sign (string-append "p1" (~r id #:base 16)))) + (set! id (add1 id)) + (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? + (λ (da) + (prim1-dual (if expects-prealloc? + (preallocated->functional-1-ρ ρ-fn ρ-acc-fn prim-sign shape) + ρ-fn) + (if expects-prealloc? + (preallocated->functional-1-∇ ∇-fn ∇-acc-fn prim-sign shape) + ∇-fn) + da))))))) + +;; TODO: Convert the use of force* into the construction of an AST so that we +;; don't prematurely trigger computation. +(define prim1-dual + (λ (ρ-fn ∇-fn da) + (let ((ra (ρ da))) + (dual (ρ-fn ra) + (λ (d z σ) + (force*1 (∇-fn ra z) + (λ (ga) + ((κ da) da ga σ)))))))) + +(define prim2 + (let ((id 0)) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (let ((prim-sign (string-append "p2" (~r id #:base 16)))) + (set! id (add1 id)) + (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? + (λ (da db) + (prim2-dual (if expects-prealloc? + (preallocated->functional-2-ρ ρ-fn ρ-acc-fn prim-sign shape) + ρ-fn) + (if expects-prealloc? 
+ (preallocated->functional-2-∇ ∇-fn ∇-acc-fn prim-sign shape) + ∇-fn) + da db))))))) + +(define prim2-dual + (λ (ρ-fn ∇-fn da db) + (let ((ra (ρ da)) + (rb (ρ db))) + (dual (ρ-fn ra rb) + (λ (d z σ) + (force*2 (λ () (∇-fn ra rb z)) + (λ (ga gb) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat))))))))) + +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define preallocated->functional-1-ρ + (λ (ρ-fn ρ-fn-acc prim-sign shape-fn) + (λ (ra) + (tpmake-prim1-ρ ρ-fn ρ-fn-acc prim-sign shape-fn ra)))) + +(define preallocated->functional-1-∇ + (λ (∇-fn ∇-fn-acc prim-sign shape-fn) + (λ (ra z) + (tpmake-prim1-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra z)))) + +(define preallocated->functional-2-ρ + (λ (ρ-fn ρ-fn-acc prim-sign shape-fn) + (λ (ra rb) + (tpmake-prim2-ρ ρ-fn ρ-fn-acc prim-sign shape-fn ra rb)))) + +(define preallocated->functional-2-∇ + (λ (∇-fn ∇-fn-acc prim-sign shape-fn) + (λ (ra rb z) + (let ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) + (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) + (values + (tpmake-prim2-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra rb z out-ref0 out-ref1 0) + (tpmake-prim2-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra rb z out-ref0 out-ref1 1)))))) + +;;---------------------------- +;; Dualized tensor op creators +;;---------------------------- +(define ext1 + (λ (f n) + (unless (prim? f) + (error 'ext1-prim "Function to be extended must be a primitive. Found: ~a" f)) + (prim1 + (ext1-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "r" (prim-signature f))) + (prim-ρ-acc-fn f) + (ext1-∇ (prim-∇-fn f) (prim-∇-acc-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "n" (prim-signature f))) + (prim-∇-acc-fn f) + (prim-shape-fn f) + #f))) + +(define ext2 + (λ (f m n) + (unless (prim? f) + (error 'ext2-prim "Function to be extended must be a primitive. 
Found: ~a" f)) + (prim2 + (ext2-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "r" (prim-signature f))) + (prim-ρ-acc-fn f) + (ext2-∇ (prim-∇-fn f) (prim-∇-acc-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "n" (prim-signature f))) + (prim-∇-acc-fn f) + (prim-shape-fn f) + #f))) + +(provide prim1 prim2 ext1 ext2) diff --git a/lazy/autodiff/C-dualized-tensor-ops.rkt b/lazy/autodiff/C-dualized-tensor-ops.rkt new file mode 100644 index 0000000..988bc80 --- /dev/null +++ b/lazy/autodiff/C-dualized-tensor-ops.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + + +;;---------------------------- +;; Tensor ops, cleaned up. +;;---------------------------- + +(define d-rank + (lambda (t) + (rank (ρ t)))) + +(define d-shape + (λ (t) + (shape (ρ t)))) + +(define d-reshape + (λ (s t) + (cond + ((dual? t) + (dual (reshape s (ρ t)) + (κ t))) + (else (reshape s t))))) + +(define d-tref + (λ (t i) + (tref (ρ t) i))) + +(define d-trefs + (λ (t b) + (trefs (ρ t) b))) + +(define d-tensor? + (λ (t) + (tensor? (ρ t)))) + +(define d-tlen + (λ (t) + (tlen (ρ t)))) + +(define d-ref + (λ (l i) + (ref l (ρ i)))) + +(define d-refr + (λ (l i) + (refr l (ρ i)))) + +(provide d-rank d-shape d-reshape d-trefs d-tensor? d-tlen d-ref d-refr d-tref) diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt new file mode 100644 index 0000000..a680951 --- /dev/null +++ b/lazy/autodiff/D-test-helpers.rkt @@ -0,0 +1,59 @@ +#lang racket + +(require "../tensors.rkt") +(require "../tensors/c0-ast.rkt") +(require "A-autodiff.ss") +(require (except-in "../../accelerated-tensors/ext-impl.rkt" + scalarize)) + +(require rackunit) + +(define force-print-store + (λ (t) + (with-output-to-string + (λ () + (print-vec (flat-store (↓ t) + #;(list-ref (unbox (tpromise-dst t)) 0))))))) + +(define-binary-check (check-dual-equal? equal-wt? 
actual expected)) +(define-check (ρ-∇-checker fn args ans grads) + (let* ((y (apply fn args)) + (g (apply (∇¹ fn) args)) + (ans-ρ (ρ ans))) + (cond + ((and (equal-wt? ans-ρ (ρ y)) + (equal-wt? grads (ρ g))) (void)) + ((equal-wt? ans-ρ (ρ y)) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~%~s~%~%actual store:~%~a~%expected store:~%~a~%" + (ρ g) grads (map force-print-store (ρ g)) (map force-print-store grads)))) + (else + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~%~s~%~%actual store:~%~a~%expected store:~%~a~%" + (ρ y) ans-ρ (force-print-store (ρ y)) (force-print-store ans-ρ))))))) + +(define-syntax check-ρ-∇ + (syntax-rules () + [(check-both (fn args ...) ans grads) + (ρ-∇-checker fn (list args ...) ans grads)])) + +(define equal-wt? + (λ (a b) + (cond + ((and (tensor? a) (tensor? b)) + (tensor-equal? a b)) + ((dual? a) (equal-wt? (ρ a) b)) + ((dual? b) (equal-wt? a (ρ b))) + ((and (vector? a) (vector? b) + (= (vector-length a) (vector-length b))) + (vector-andmap equal-wt? a b)) + ((and (pair? a) (pair? b) + (= (length a) (length b))) + (andmap equal-wt? a b)) + (else (equal? a b))))) + + +(define vector-andmap + (λ (f v1 v2) + (for/fold ([s #t]) ([v1 v1][v2 v2]) + (and s (f v1 v2))))) + +(provide check-dual-equal? check-ρ-∇) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt new file mode 100644 index 0000000..b39f25a --- /dev/null +++ b/lazy/autodiff/E-print.rkt @@ -0,0 +1,26 @@ +#lang racket + +(require "A-autodiff.rkt") +(require "../tensors/0-lazy.rkt") +(require "../tensors/1-reflect.rkt") +(require (except-in "../../accelerated-tensors/ext-impl.rkt" scalarize)) + +(define max-tensor-print-length (make-parameter 5)) + +(define make-printable + (λ (y [max-length (max-tensor-print-length)]) + (cond + ((dual? y) (make-printable (ρ y))) + ((tpromise? y) + (make-printable (↓ y) max-length)) + ((flat? y) (make-printable-flat y max-length)) + ((list? 
y) + (map (λ (le) (make-printable le max-length)) y)) + ((vector? y) + (vector-map (λ (ve) (make-printable ve max-length)) y)) + (else y)))) + +(include "test/test-E-print.rkt") + +(provide max-tensor-print-length + make-printable) diff --git a/lazy/autodiff/test/test-A-autodiff.rkt b/lazy/autodiff/test/test-A-autodiff.rkt new file mode 100644 index 0000000..b453b3e --- /dev/null +++ b/lazy/autodiff/test/test-A-autodiff.rkt @@ -0,0 +1,15 @@ +(module+ test + (require rackunit) + (let ((k0 end-of-chain)) + (let ((dual0 0) + (dual1 (dual 1 k0))) + + (check-equal? dual1 (dual 1 k0)) + (check-true (dual? dual1)) + (check-false (dual? 1)) + (check-equal? (ρ dual1) 1) + (check-equal? (ρ dual0) 0) + (check-equal? (κ dual1) k0) + + (check-equal? (map* (λ (d) (ρ d)) (∇-once dual1 (list dual0 dual1))) + '(0.0 1.0))))) diff --git a/lazy/autodiff/test/test-E-print.rkt b/lazy/autodiff/test/test-E-print.rkt new file mode 100644 index 0000000..092f78b --- /dev/null +++ b/lazy/autodiff/test/test-E-print.rkt @@ -0,0 +1,71 @@ +(module+ test + (require rackunit) + (require "../tensors.rkt") + + (define long-tensor + (tensor 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15)) + + (define dualized-long-tensor + (dual long-tensor end-of-chain)) + + (define deep-tensor + (tensor long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor)) + + (define deeper-tensor + (tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) + + (check-equal? (make-printable long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? (make-printable deep-tensor 3) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...))) + + (check-equal? 
(make-printable deeper-tensor 3) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...))) + (parameterize ((max-tensor-print-length 3)) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) + (list + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...)))))) diff --git a/lazy/ext-ops.rkt b/lazy/ext-ops.rkt new file mode 100644 index 0000000..8a58d1a --- /dev/null +++ b/lazy/ext-ops.rkt @@ -0,0 +1,40 @@ +#lang racket + +(require "ext-ops/A-scalar-ops.rkt") +(require "ext-ops/B-comparators.rkt") +(require "ext-ops/C-star-2-1.rkt") +(require "ext-ops/D-sum.rkt") +(require "ext-ops/E-argmax.rkt") +(require "ext-ops/F-max.rkt") +(require "ext-ops/G-correlate.rkt") +(require "ext-ops/I-flatten.rkt") +(require "ext-ops/K-concat.rkt") + +(provide d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 sqrt-0 abs-0 rectify-0 + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) + 
+(provide d*-2-1 *-2-1-ρ) + +(provide sum-1 d-sum sum-ρ d-sum-cols sum-cols-ρ) + +(provide argmax-1 d-argmax argmax-ρ) + +(provide max-1 d-max max-ρ) + +(provide correlate-ρ d-correlate) + +(provide flatten-2 d-flatten flatten-ρ) + +(provide concat-1-1 d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/lazy/ext-ops/A-scalar-ops.rkt b/lazy/ext-ops/A-scalar-ops.rkt new file mode 100644 index 0000000..72b33f9 --- /dev/null +++ b/lazy/ext-ops/A-scalar-ops.rkt @@ -0,0 +1,210 @@ +#lang racket + +(require string-interpolation) +(require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) +(require "../autodiff.rkt") + +(define +-0-0-ρ-acc + (λ (a b) + "@{a}+@{b}")) + +(define +-0-0 + (prim2 + + +-0-0-ρ-acc + (λ (a b z) + (values z z)) + (λ (a b z) + (values z z)))) + +(define --0-0-ρ-acc + (λ (a b) + "@{a}-@{b}")) + +(define --0-0 + (prim2 - + --0-0-ρ-acc + (λ (a b z) + (values z (- z))) + (λ (a b z) + (values z "(- @{z})")))) + +(define *-0-0-ρ-acc + (λ (a b) + "@{a}*@{b}")) + +(define *-0-0 + (prim2 * + *-0-0-ρ-acc + (λ (a b z) + (values (* b z) (* a z))) + (λ (a b z) + (values "@{b}*@{z}" "@{a}*@{z}")))) + +(define /-0-0-ρ-acc + (λ (a b) + "@{a}/@{b}")) + +(define /-0-0 + (prim2 / + /-0-0-ρ-acc + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))) + (λ (a b z) + (values "(@{z} * (1 / @{b}))" + "(@{z} * ((- @{a}) / (@{b} * @{b})))")))) + +(define expt-0-0-ρ-acc + (λ (a b) + "pow(@{a}, @{b})")) + +(define expt-0-0 + (prim2 expt + expt-0-0-ρ-acc + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))) + (λ (a b z) + (values "(@{z} * (@{b} * pow(@{a}, (@{b} - 1))))" + "(@{z} * (pow(@{a}, @{b}) * log(@{a})))")))) + +(define exp-0-ρ-acc + (λ (a) + "exp(@{a})")) + +(define exp-0 + (prim1 exp + exp-0-ρ-acc + (λ (a z) + (* z (exp a))) + (λ (a z) + "(@{z} * exp(@{a}))"))) + +(define log-0-ρ-acc + (λ (a) + "log(@{a})")) + +(define log-0 + (prim1 log + log-0-ρ-acc + (λ (a z) + (* z (/ 1 a))) + (λ (a z) + "(@{z} * (1 / @{a}))"))) + +(define 
sqrt-0-ρ-acc + (λ (a) + "sqrt(@{a})")) + +(define sqrt-0 + (prim1 sqrt + sqrt-0-ρ-acc + (λ (x z) + (/ z (* 2 (sqrt x)))) + (λ (x z) + "(@{z} / (2 * sqrt(@{x})))"))) + +(define abs-0-ρ + (λ (x) + (cond + ((< x 0) (* -1 x)) + (else x)))) + +(define abs-0-ρ-acc + (λ (x) + "fabs(@{x})")) + +(define abs-0-∇ + (λ (x z) + (cond + ((< x 0) (- z)) + (else z)))) + +(define abs-0-∇-acc + (λ (x z) + "sign(@{x}) * @{z}")) + +(define abs-0 + (prim1 abs-0-ρ abs-0-ρ-acc abs-0-∇ abs-0-∇-acc)) + +(define rectify-0-ρ + (λ (s) + (cond + ((< s 0.0) 0.0) + (else s)))) + +(define rectify-0-ρ-acc + (λ (s) + "fmax(0.0f, @{s})")) + +(define rectify-0-∇ + (λ (s z) + (cond + ((< s 0.0) 0.0) + (else z)))) + +(define rectify-0-∇-acc + (λ (s z) + "step(0, @{s}) * @{z}")) + +(define rectify-shape + (λ (s) s)) + +(define rectify-0 + (prim1 rectify-0-ρ rectify-0-ρ-acc rectify-0-∇ rectify-0-∇-acc rectify-shape)) + +;;------------------------------------ +;; differentiable extended functions. +;;------------------------------------ + +(define d* (ext2 *-0-0 0 0)) +(define d+ (ext2 +-0-0 0 0)) +(define d- (ext2 --0-0 0 0)) +(define d/ (ext2 /-0-0 0 0)) +(define d-expt (ext2 expt-0-0 0 0)) + +(define d-exp (ext1 exp-0 0)) +(define d-log (ext1 log-0 0)) +(define d-abs (ext1 abs-0 0)) +(define d-rectify (ext1 rectify-0 0)) +(define d-sqrt (ext1 sqrt-0 0)) + +(define d-sqr + (λ (x) + (d* x x))) + +;;------------------------------------ +;; non-differentiable extended functions. 
+;;------------------------------------ + +(define *-ρ (ext2-ρ * *-0-0-ρ-acc 0 0)) +(define +-ρ (ext2-ρ + +-0-0-ρ-acc 0 0)) +(define --ρ (ext2-ρ - --0-0-ρ-acc 0 0)) +(define /-ρ (ext2-ρ / /-0-0-ρ-acc 0 0)) +(define expt-ρ (ext2-ρ expt expt-0-0-ρ-acc 0 0)) + +(define exp-ρ (ext1-ρ exp exp-0-ρ-acc 0)) +(define log-ρ (ext1-ρ log log-0-ρ-acc 0)) +(define abs-ρ (ext1-ρ abs-0-ρ abs-0-ρ-acc 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ rectify-0-ρ-acc 0)) +(define sqrt-ρ (ext1-ρ sqrt sqrt-0-ρ-acc 0)) + +(define sqr-ρ + (λ (x) + (*-ρ x x))) + +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) (λ (_) "0.0") 0)) + +(include "test/test-A-scalar-ops.rkt") + +(provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/lazy/ext-ops/B-comparators.rkt b/lazy/ext-ops/B-comparators.rkt new file mode 100644 index 0000000..3db8e0f --- /dev/null +++ b/lazy/ext-ops/B-comparators.rkt @@ -0,0 +1,99 @@ +#lang racket + +(require string-interpolation) +(require "../autodiff.rkt") + +;;---------------------------- +;; Boolean comparators +;;---------------------------- + +(define comparator ;; Lifts the binary predicate f so that it is applied to the real parts (ρ) of its two arguments. + (λ (f) + (λ (da db) + (f (ρ da) (ρ db))))) + +(define =-0-0 + (comparator =)) + +(define <-0-0 + (comparator <)) + +(define <=-0-0 + (comparator <=)) + +(define >-0-0 + (comparator >)) + +(define >=-0-0 + (comparator >=)) ;; Fix: this used >, which made >=-0-0 identical to >-0-0. + +;;---------------------------- +;; Tensorized comparators +;;---------------------------- + +(define comparator-ρ ;; Numeric form of the predicate: 1.0 when f holds on the real parts, 0.0 otherwise. + (λ (f) + (λ (da db) + (cond + ((f (ρ da) (ρ db)) 1.0) + (else 0.0))))) + +(define comparator-ρ-acc ;; String counterpart of comparator-ρ; here f is the operator text (for example "<="), a and b are operand strings. + (λ (f) + (λ (a b) + "@{a} @{f} @{b}"))) + +(define comparator-∇ ;; Passes z to both arguments when f holds on the real parts, and 0.0 to both otherwise. + (λ (f) + (λ (da db z) + (cond + ((f (ρ da) (ρ db)) (values z z)) + (else (values 0.0 0.0)))))) + +(define comparator-∇-acc ;; String counterpart of comparator-∇: the comparison result multiplied by z, for both arguments. + (λ (f) + (λ (a b z) + (let ((bool "(@{a} @{f} @{b})")) ;; Fix: parenthesized so the comparison is evaluated before the multiplication by z. + (values "@{bool}*@{z}" "@{bool}*@{z}"))))) + +(define
comparator-shape + (λ (f) + (λ (sa sb) + sa))) + +(define comparator-prim + (λ (f f-acc) + (prim2 (comparator-ρ f) (comparator-ρ-acc f-acc) + (comparator-∇ f) (comparator-∇-acc f-acc) + (comparator-shape f)))) + +(define extended-comparator + (λ (f f-acc) + (ext2 (comparator-prim f f-acc) 0 0))) + +(define =-1 + (extended-comparator = "==")) + +(define <-1 + (extended-comparator < "<")) + +(define >-1 + (extended-comparator > ">")) + +(define <=-1 + (extended-comparator <= "<=")) + +(define >=-1 + (extended-comparator >= ">=")) + +(define != + (λ (a b) + (not (= a b)))) + +(define !=-1 + (extended-comparator != "!=")) + +(include "test/test-B-comparators.rkt") + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/ext-ops/C-star-2-1.rkt b/lazy/ext-ops/C-star-2-1.rkt new file mode 100644 index 0000000..fab49f2 --- /dev/null +++ b/lazy/ext-ops/C-star-2-1.rkt @@ -0,0 +1,79 @@ +#lang racket + +(require string-interpolation) +(require "../../accelerated-tensors/ext-impl.rkt") +(require (only-in "../tensors.rkt" ext2-ρ)) +(require "../autodiff.rkt") + +(define *-2-1-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (vset! v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-ρ-acc + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + #< v max) (values v (+ (- i i0) 0.0))) + (else (values max max-i)))))))) + +(define argmax-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< max) { + max = v; + max_i = i - @{i0} + 0.0; + } + } + @{v-out}[@{i-out}] = max_i; +EOF + )) + +(define argmax-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! 
g0 i 0.0))))) + +(define argmax-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< v max) v) + (else max))))))) + +(define max-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< v max) (values v (- i i0))) + (else (values max max-i)))))))) + +(define max-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< max) { + max = v; + max_i = i - @{i0}; + } + } + for(int i=@{i0}; i<@{i0}+@{stride0}; i++) { + if(i == @{i0}+max_i) { + @{g0}[i] += z; + } else { + @{g0}[i] += 0.0; + } + } +EOF + )) + +(define max-shape + (λ (st) + (cdr st))) + +(define max-1 + (prim1 max-1-ρ max-1-ρ-acc max-1-∇ max-1-∇-acc max-shape #t)) + +(define d-max + (ext1 max-1 1)) + +(define max-ρ + (ext1-ρ max-1-ρ max-1-ρ-acc 1 max-shape #t)) + +(include "test/test-F-max.rkt") + +(provide max-1 d-max max-ρ) diff --git a/lazy/ext-ops/G-correlate.rkt b/lazy/ext-ops/G-correlate.rkt new file mode 100644 index 0000000..14c0cb6 --- /dev/null +++ b/lazy/ext-ops/G-correlate.rkt @@ -0,0 +1,158 @@ +#lang racket + +(require string-interpolation) +(require "../../accelerated-tensors/ext-impl.rkt") +(require (only-in "../tensors.rkt" ext2-ρ len)) +(require "../autodiff.rkt") + +;; Correlation is written taking into account how ext2 works +;; Ext2 is responsible for producing the i-out'th output from +;; v0[i0] and v1[i1], we take advantage of this. The shape constants +;; n b m d are pre-calculated the striding constants nd md and qd +;; are calculated. + +(define correlate-3-1-ρ + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (vset! 
v-out (+ i-out i) + (for/fold ([sum 0.0]) ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (cond + ((and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (+ sum (* a b)))) + (else sum)))))))))) + +(define correlate-3-1-ρ-acc + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + #<= i1_min && bi < i1_max) { + sum += @{v0}[ai] * @{v1}[bi]; + } + } + @{v-out}[@{i-out}+i] = sum; + } +EOF + ))) + +(define correlate-3-1-∇ + (λ (nd md qd) + (λ (g0 g1 + v0 i0 bmd + v1 i1 d + vz iz b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (let ((z (vref vz (+ iz i)))) + (for ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (when (and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-3-1-∇-acc + (λ (nd md qd) + (λ (g + v0 i0 bmd + v1 i1 d + vz iz b) + (values + #<= i1_min && bi < i1_max) { + @{g}[ai] += z * @{v1}[bi]; + } + } + } +EOF + + #<= i1_min && bi < i1_max) { + @{g}[bi] += z * @{v0}[ai]; + } + } + } +EOF + )))) + +(define correlate-shape + (λ (bmd nd) + (list (car bmd)))) + +(define correlate-3-1 + (λ (nd md qd) + (prim2 + (correlate-3-1-ρ nd md qd) + (correlate-3-1-ρ-acc nd md qd) + (correlate-3-1-∇ nd md qd) + (correlate-3-1-∇-acc nd md qd) + correlate-shape #t))) + +(define d-correlate + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. 
+ (qd (* q d)) + (md (* m d))) + ((ext2 (correlate-3-1 nd md qd) 3 1) bank signal)))) + +(define correlate-ρ + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2-ρ (correlate-3-1-ρ nd md qd) (correlate-3-1-ρ-acc nd md qd) 3 1 correlate-shape #t) + bank signal)))) + +(define last + (λ (n s) + (refr s (- (len s) n)))) + +(include "test/test-G-correlate.rkt") + +(provide d-correlate correlate-ρ) diff --git a/lazy/ext-ops/I-flatten.rkt b/lazy/ext-ops/I-flatten.rkt new file mode 100644 index 0000000..0ef02f6 --- /dev/null +++ b/lazy/ext-ops/I-flatten.rkt @@ -0,0 +1,52 @@ +#lang racket + +(require string-interpolation) +(require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) +(require (only-in "../autodiff.rkt" prim1 ext1)) + +(define flatten-2-ρ + (λ (t) + (reshape (flatten-shape (shape t)) t))) + +(define flatten-2-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #<-0-0 a b)) + (check-true (<=-0-0 a b)) + (check-false (>=-0-0 a b)) + (check-false (=-0-0 a b)) + (check-true (=-0-0 a a)) + (check-true (zero? 0)) + (check-false (zero? 
a)))) diff --git a/lazy/ext-ops/test/test-C-star-2-1.rkt b/lazy/ext-ops/test/test-C-star-2-1.rkt new file mode 100644 index 0000000..233466e --- /dev/null +++ b/lazy/ext-ops/test/test-C-star-2-1.rkt @@ -0,0 +1,28 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + (check-ρ-∇ (*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (tensor (tensor 36 52 70 90) (tensor 84 104 126 150))) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/lazy/ext-ops/test/test-D-sum.rkt b/lazy/ext-ops/test/test-D-sum.rkt new file mode 100644 index 0000000..9271f11 --- /dev/null +++ b/lazy/ext-ops/test/test-D-sum.rkt @@ -0,0 +1,66 @@ +(module+ test + (require rackunit) + (require "C-star-2-1.ss") + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) + + (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d-sum a) 12 + (list (tensor 1.0 1.0 1.0)))) + + (let ((a (tensor (tensor 3 4 5) + (tensor 6 7 8)))) + (check-dual-equal? (d-sum a) (tensor 12 21)) + (check-dual-equal? 
((∇¹ (λ (b) (d-sum (d* b b)))) a) + (list (tensor (tensor 6.0 8.0 10.0) + (tensor 12.0 14.0 16.0))))) + + (define dot-product + (λ (a b) + (d-sum (d*-2-1 a b)))) + + (define sse + (λ (a b) + (d-sum (d-sqr (d- a b))))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + + (check-ρ-∇ (dot-product a b) + (tensor 68 124) + (list (tensor (tensor 2.0 3.0 4.0 5.0) + (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + + (check-ρ-∇ (sse a b) + (tensor 4 100) + (list (tensor (tensor 2.0 2.0 2.0 2.0) + (tensor 10.0 10.0 10.0 10.0)) + (tensor -12.0 -12.0 -12.0 -12.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (dot-product a b) + (tensor (tensor 68 124) + (tensor 248 464)) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0))))) + (let ((a (tensor (tensor 0 1 2) + (tensor 3 4 5) + (tensor 6 7 8)))) + (check-ρ-∇ (sum-cols-2 a) (tensor 9 12 15) + (list (tensor (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0)))) + (check-ρ-∇ (d-sum-cols a) (tensor 9 12 15) + (list (tensor (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0)))))) diff --git a/lazy/ext-ops/test/test-E-argmax.rkt b/lazy/ext-ops/test/test-E-argmax.rkt new file mode 100644 index 0000000..144980a --- /dev/null +++ b/lazy/ext-ops/test/test-E-argmax.rkt @@ -0,0 +1,19 @@ +(module+ test + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (argmax-1 y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-argmax y) (tensor 2.0 1.0 0.0 3.0) + (list + (tensor (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 
0.0) + (tensor 0.0 0.0 0.0 0.0)))))) diff --git a/lazy/ext-ops/test/test-F-max.rkt b/lazy/ext-ops/test/test-F-max.rkt new file mode 100644 index 0000000..88d74a5 --- /dev/null +++ b/lazy/ext-ops/test/test-F-max.rkt @@ -0,0 +1,14 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 1.0 0.0 0.0))) + (check-ρ-∇ (max-1 y) 1.0 (list y)) + (check-ρ-∇ (d-max y) 1.0 (list y))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-max y) (tensor 1.0 1.0 1.0 1.0) + (list y)))) diff --git a/lazy/ext-ops/test/test-G-correlate.rkt b/lazy/ext-ops/test/test-G-correlate.rkt new file mode 100644 index 0000000..e3cade2 --- /dev/null +++ b/lazy/ext-ops/test/test-G-correlate.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor ext2-∇ check-tensor-equal?)) + + ;; for testing b = 4 + ;; m = 3 + ;; d = 2 + + ;; signal length n = 6 + + ;; (1 2) (3 4) (5 6) (7 8) (9 10) (11 12) + ;; (1 2) (3 4) (5 6) + ;; (7 8) (9 10) (11 12) + ;; (13 14) (15 16) (17 18) + ;; (19 20) (21 22) (23 24) + + ;; Signal is (n d) + (define signal (tensor (tensor 1 2) + (tensor 3 4) + (tensor 5 6) + (tensor 7 8) + (tensor 9 10) + (tensor 11 12))) + + (define bank (tensor (tensor + (tensor 1 2) + (tensor 3 4) + (tensor 5 6)) + (tensor + (tensor 7 8) + (tensor 9 10) + (tensor 11 12)) + (tensor + (tensor 13 14) + (tensor 15 16) + (tensor 17 18)) + (tensor + (tensor 19 20) + (tensor 21 22) + (tensor 23 24)))) + + (define corr-ρ + (ext2-ρ (correlate-3-1-ρ 12 6 2) (correlate-3-1-ρ-acc 12 6 2) 3 1 correlate-shape #t)) + + (define corr-∇ + (ext2-∇ (correlate-3-1-∇ 12 6 2) (correlate-3-1-∇-acc 12 6 2) 3 1 correlate-shape #t)) + + (check-tensor-equal? 
(corr-ρ bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let-values (((filter-∇ signal-∇) + (corr-∇ bank signal (tensor (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0))))) + (check-tensor-equal? filter-∇ + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-tensor-equal? signal-∇ + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0)))) + + (check-dual-equal? (d-correlate bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let ((gs ((∇¹ d-correlate) bank signal))) + (check-dual-equal? (car gs) + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-dual-equal? 
(cadr gs) + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0))))) diff --git a/lazy/ext-ops/test/test-I-flatten.rkt b/lazy/ext-ops/test/test-I-flatten.rkt new file mode 100644 index 0000000..f7740cf --- /dev/null +++ b/lazy/ext-ops/test/test-I-flatten.rkt @@ -0,0 +1,13 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0)) + + (check-dual-equal? (flatten-2 r2-t1) r1-t1) + (check-ρ-∇ ((λ (t1 t2) (d* t1 (flatten-2 t2))) r1-t1 r2-t1) + (tensor 9.0 16.0 25.0 36.0) + (list (tensor 3.0 4.0 5.0 6.0) (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))))) diff --git a/lazy/ext-ops/test/test-K-concat.rkt b/lazy/ext-ops/test/test-K-concat.rkt new file mode 100644 index 0000000..aa405fc --- /dev/null +++ b/lazy/ext-ops/test/test-K-concat.rkt @@ -0,0 +1,131 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t2 (tensor 5.0 6.0 7.0)) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + + (check-ρ-∇ (concat-1-1 r1-t2 r1-t1) + (tensor 5.0 6.0 7.0 3.0 4.0 5.0 6.0 7.0) + (list (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0 1.0))) + + (check-dual-equal? 
+ (d-concat r2-t1 r1-t2) + (tensor (tensor 3.0 4.0 5.0 6.0 7.0) + (tensor 5.0 6.0 5.0 6.0 7.0))) + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (d-concat t1 t2))) r2-t1 r1-t2 r1-t1) + (tensor (tensor 9.0 16.0 25.0 36.0 49.0) + (tensor 15.0 24.0 25.0 36.0 49.0)) + (list (tensor (tensor 3.0 4.0) (tensor 3.0 4.0)) + (tensor 10.0 12.0 14.0) + (tensor 8.0 10.0 10.0 12.0 14.0))) + (define r3-t1 + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0)))) + + + (define r2-t2 + (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0))) + + (define r1-t3 + (tensor 0.5 0.5)) + + (define concat-2 (d-concat-n 2)) + + (check-dual-equal? + (concat-2 r3-t1 r2-t2) + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)))) + + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (concat-2 t1 t2))) r3-t1 r2-t2 r1-t3) + (tensor (tensor (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 4.5 5.0) + (tensor 5.5 6.0) + (tensor 6.5 7.0) + (tensor 7.5 8.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 8.5 9.0) + (tensor 9.5 10.0) + (tensor 10.5 11.0) + (tensor 11.5 12.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0))) + (list + 
(tensor (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5))) + + (tensor (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5)) + + (tensor 192.0 216.0)))) diff --git a/lazy/no-duals-no-overrides.rkt b/lazy/no-duals-no-overrides.rkt new file mode 100644 index 0000000..07ca22e --- /dev/null +++ b/lazy/no-duals-no-overrides.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + concat-ρ concat-n-ρ flatten-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/no-duals.rkt b/lazy/no-duals.rkt new file mode 100644 index 0000000..cd1bcaf --- /dev/null +++ b/lazy/no-duals.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? 
rank shape reshape trefs + + ;; From ext-ops + (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) + (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) + (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/no-overrides.rkt b/lazy/no-overrides.rkt new file mode 100644 index 0000000..05844b7 --- /dev/null +++ b/lazy/no-overrides.rkt @@ -0,0 +1,43 @@ +#lang racket/base + +(require + (except-in "tensors.rkt" + rank shape reshape trefs tref tensor? tlen ref refr)) + +(require "autodiff.rkt") +(require "ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + make-printable + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + flatten-2 concat-1-1 + + d+ d- d* d/ d-rectify + d-exp d-log d-expt d-sqrt d-sqr + d-sum d-abs d*-2-1 d-argmax + d-max d-sum-cols d-correlate + d-flatten d-concat d-concat-n + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ concat-n-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt new file mode 100644 index 0000000..330a31a --- /dev/null +++ b/lazy/tensors.rkt @@ -0,0 +1,23 @@ +#lang racket +(require "tensors/0-lazy.rkt") +(require "tensors/1-reflect.rkt") +(require "tensors/A-equality.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide tolerance tensor-equal? check-tensor-equal?) 
+ +(provide len ref refr) +(provide tref tlen list->tensor tensor build-tensor trefs) + +(provide ext1-ρ ext2-ρ ext1-∇ ext2-∇) + +(provide ↓ scalarize) + +(provide print-compiler? compiler-cache) + +;; These will get overridden by duals +(provide tensor?) +(provide rank shape reshape size-of) + +(provide force*1 force*2) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt new file mode 100644 index 0000000..3cc6ff4 --- /dev/null +++ b/lazy/tensors/0-lazy.rkt @@ -0,0 +1,346 @@ +#lang racket +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) + +(require "c0-ast.rkt") +(require (only-in "c1-racket-runtime.rkt" ext2-∇-result)) + +#| +Questions: + +* How do I create a preallocated function example for ext2-∇? Also, how do the preallocated functions work? +A. Here is what the formal parameters of a binary preallocated ∇-function mean: + + - g0, g1: These are the empty gradient tensor stores which need to be filled +with the gradients of the corresponding ρ-function w.r.t. the first and second +arguments respectively. + + - t, it, st: the store, beginning offset and total size of the flat +representation of the first tensor argument respectively which we will need to loop through +the scalar elements. + + - u, iu, su: the store, beginning offset and total size of the flat +representation of the second tensor argument respectively which we will need to +loop through the scalar elements. + + - z, iz, sz: the store, beginning offset and total size of the flat +representation of the accumulator respectively which we will need to +loop through the scalar elements.
+ + Here are the invariants of the formal parameters: + + - The flat tensors corresponding to the stores g0 and g1 have the same shape +as the first and second input tensors respectively + + - The flat tensor corresponding to the store z has the same shape as the +result of invoking the corresponding ρ-function with the two tensor arguments + +* Here are a few problems which need to be addressed while checking compiler +invariants after compiling the expression "(((l2-loss plane) r2d1 (tensor 1.0 +1.0)) plane-theta-0)" from the test-C-loss.rkt file: + + - The env contains flat tensors with the shape '() i.e. they are scalars in +disguise + + - Somehow copies of a flat tensor are being added to the env rather than the +instructions referring to the same gensym variable + +|# + +#; +(: scalar? (-> Any Boolean)) +(define scalar? number?) + +#; +(: tensor (case-> (-> tpromise * tpromise) + (-> Number * tpromise))) +(define tensor ;; Variadic constructor: collects its arguments into a list and hands them to list->tpromise. + (λ args + (list->tpromise args))) +#; +(: ensure-shape (-> (U (Listof tpromise) (Listof Number)) Void)) +(define ensure-shape ;; Raises an error unless args is a non-empty list whose elements are all scalar-like or all tpromises of one shape. + (λ (args) + (when (null? args) + (error 'tensor "Tensors cannot be empty")) + (let ((checked-shape + (λ (x) (if (tpromise? x) + (tpromise-shape x) + '()))) + (scalar-like? ;; A number, or a tpromise whose shape is the empty list. + (λ (x) + (or (number? x) + (and (tpromise? x) + (null? (tpromise-shape x))))))) + (unless (and (not (null? args)) + (cond + ((scalar-like? (car args)) + (andmap scalar-like? (cdr args))) + ((tpromise? (car args)) + (let ((s (checked-shape (car args)))) + (andmap (λ (t) + (and (tpromise? t) + (equal? (checked-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Cannot construct a tensor out of these elements: ~a~%" + args))))) + +#; +(: tensor-inner-flat (-> (Listof (U tpromise Number)) + (U flat tcomp-list->tensor))) +(define tensor-inner-flat + (λ (lst) + (cond + [(andmap number? lst) (apply acc:tensor lst)] + [(andmap tpromise-flat?
lst) + (apply acc:tensor + (for/list ((tp-flat lst)) + (car (unbox (tpromise-dst tp-flat)))))] + [else lst]))) + +(define list->tpromise + (λ (lst) + (ensure-shape lst) + (let ((inner-tensor (tensor-inner-flat lst))) + (cond + ((flat? inner-tensor) + (tpmake-flat inner-tensor)) + (else + (let* ((inner-shape (tp-shape (car lst))) + (outer (length lst)) + (new-shape (cons outer inner-shape))) + (tpmake-list->tensor inner-tensor new-shape))))))) + +(define bounded-idx*^ + (λ (shape idx*) + (match `(,shape ,idx*) + [`(,_ ()) #t] + [`(() ,_) #f] + [`((,sa . ,sd) (,ia . ,id*)) + (and (< ia sa) + (>= ia 0) + (bounded-idx*^ sd id*))]))) + +(define bounded-idx*? + (λ (tp idx*) + (bounded-idx*^ (tpromise-shape tp) idx*))) + +(define tp-tref + (lambda (tp i) + (cond + [(bounded-idx*? tp (list i)) + (tpmake-tref tp i (cdr (tpromise-shape tp)))] + [else (error 'exn:tp-tref + (string-append + "Index out of bounds. ~a " + "greater than or equals length ~a~%") + i + (tp-tlen tp))]))) + +(define tp-tlen + (λ (tp) + (car (tpromise-shape tp)))) + +(define tp-shape + (lambda (v) + (cond + [(tpromise? v) (tpromise-shape v)] + [else (acc:shape v)]))) + +(define build-tpromise + (λ (s f) + (tpmake-flat (acc:build-tensor s f)))) + +(define tp-trefs + (λ (tp b) + (cond + [(ormap (λ (i) + (>= i + (car (tpromise-shape tp)))) + b) + (error 'tp-trefs + "An index was out of bounds")] + [else + (tpmake-trefs tp b + `(,(length b) + . ,(cdr (tpromise-shape tp))))]))) + +;; Default arguments shape-fn and expects-prealloc? need not be passed when f is +;; a function on scalars and doesn't expect a preallocated output vector as its +;; argument. The signature argument is only supposed to be passed within the +;; definition of ext1 and ext2 functions in B-prims.rkt. +(define tp-ext1-ρ + (let ((id -1)) + (λ (f f-acc m + [shape-fn scalar-shape] + [expects-prealloc? #f] + [prim-sign (begin + (set! 
id (add1 id)) + (string-append "re1" (~r id #:base 16)))]) + (λ (tp) + (let* ((in-shape (tp-shape tp)) + (base-shape (min-shape m in-shape)) + (shape-fn-out (shape-fn base-shape)) + (out-shape (merge-shapes in-shape m shape-fn-out))) + (cond + [(scalar? tp) (f tp)] + [(and (tpromise? tp) + (null? (tpromise-shape tp))) + (tpmake-ext1-ρ-scalar f f-acc prim-sign tp out-shape)] + [expects-prealloc? + (tpmake-ext1-ρ f f-acc prim-sign m shape-fn tp out-shape)] + [else + (let ((flat-f (functional->preallocated-1-ρ f base-shape shape-fn-out)) + (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape shape-fn-out))) + (tpmake-ext1-ρ flat-f flat-f-acc prim-sign m shape-fn tp out-shape))])))))) + +;; See comment for tp-ext1-ρ +(define tp-ext2-ρ + (let ((id -1)) + (λ (f f-acc m n + [shape-fn scalar-shape] + [expects-prealloc? #f] + [prim-sign (begin + (set! id (add1 id)) + (string-append "re2" (~r id #:base 16)))]) + (λ (tp-t tp-u) + (let* ((s0 (tp-shape tp-t)) + (s1 (tp-shape tp-u)) + (sf0 (min-shape m s0)) + (sf1 (min-shape n s1)) + (sf-out (shape-fn sf0 sf1))) + (cond + ((and (number? tp-t) (number? tp-u)) + (f tp-t tp-u)) + [(and (tpromise? tp-t) (tpromise? tp-u) + (null? (tpromise-shape tp-t)) + (null? (tpromise-shape tp-u))) + (tpmake-ext2-ρ-scalar f f-acc prim-sign tp-t tp-u sf-out)] + [expects-prealloc? + (tpmake-ext2-ρ + tp-t tp-u + f f-acc prim-sign m n shape-fn + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out)))] + [else + (let ((flat-f (functional->preallocated-2-ρ + f sf0 sf1 sf-out)) + (flat-f-acc (functional->preallocated-2-ρ-acc + f-acc sf0 sf1 sf-out))) + (tpmake-ext2-ρ + tp-t tp-u + flat-f flat-f-acc prim-sign m n shape-fn + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out))))])))))) + +(define scalar-shape + (λ (s0 [s1 '()]) '())) + +;; See comment for tp-ext1-ρ +(define tp-ext1-∇ + (let ((id -1)) + (λ (f f-acc m + [shape-fn scalar-shape] + [expects-prealloc? #f] + [prim-sign (begin + (set! 
id (add1 id)) + (string-append "ne1" (~r id #:base 16)))]) + (λ (tp zp) + ;; + (cond + ((number? tp) (f tp zp)) + (expects-prealloc? + (tpmake-ext1-∇ tp zp f f-acc prim-sign m shape-fn (tp-shape tp))) + (else + (let* ((in-shape (tpromise-shape tp)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-∇-acc f-acc base-shape out-shape))) + (tpmake-ext1-∇ tp zp flat-f flat-f-acc prim-sign m shape-fn (tp-shape tp))))))))) + +;; See comment for tp-ext1-ρ +(define tp-ext2-∇ + (let ((id -1)) + (λ (f f-acc m n + [shape-fn scalar-shape] + [expects-prealloc? #f] + [prim-sign (begin + (set! id (add1 id)) + (string-append "ne2" (~r id #:base 16)))]) + (let ((tp-f + (λ (f f-acc tp-t tp-u tp-z) + (tp-d-ext2^ f f-acc prim-sign m n shape-fn + tp-t tp-u tp-z)))) + (λ (tp-t tp-u tp-z) + (cond + (expects-prealloc? + (tp-f f f-acc tp-t tp-u tp-z)) + [else (let* ((t-shape (min-shape m (tp-shape tp-t))) + (u-shape (min-shape n (tp-shape tp-u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ + f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc + f-acc t-shape u-shape out-shape))) + (tp-f flat-f flat-f-acc tp-t tp-u tp-z))])))))) + +(define tp-d-ext2^ + (λ (fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) + (let* ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) + (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) + (values + (tpmake-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 0 (tp-shape tp-t0)) + (tpmake-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 1 (tp-shape tp-t1)))))) + +(define tp-rank + (λ (tp) + (acc:len (tp-shape tp)))) + +(define tp-reshape + (λ (s tp) + (cond + ((and (tpromise? tp) (= (acc:size-of s) (acc:size-of (tpromise-shape tp)))) + (tpmake-reshape tp s)) + [(and (acc:flat? 
tp) (= (acc:size-of s) (acc:size-of (acc:shape tp)))) + (acc:reshape s tp)] + (else (error 'shape-error "Cannot reshape ~a to ~a~%" tp s))))) + +(define tensor? + (lambda (tp) + (or (tpromise? tp) (flat? tp) (scalar? tp)))) + +(include "test/test-0-lazy.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide (rename-out + (acc:len len) + (acc:ref ref) + (acc:refr refr))) +(provide tensor + tpromise? + (rename-out + (tp-tref tref) + (tp-tlen tlen) + (list->tpromise list->tensor) + (build-tpromise build-tensor) + (tp-trefs trefs))) + +(provide (rename-out + (tp-ext1-ρ ext1-ρ) + (tp-ext2-ρ ext2-ρ) + (tp-ext1-∇ ext1-∇) + (tp-ext2-∇ ext2-∇))) + +;; These will get overriden by duals +(provide tensor?) +(provide (rename-out + (tp-rank rank) + (tp-shape shape) + (tp-reshape reshape) + (acc:size-of size-of))) diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt new file mode 100644 index 0000000..bc2a0c9 --- /dev/null +++ b/lazy/tensors/1-reflect.rkt @@ -0,0 +1,64 @@ +#lang racket +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) +(require "c0-ast.rkt") +(require (only-in "c3-compiler.rkt" + compiler-cache + print-compiler? + get-compiled + compile-tensor)) +(require (only-in "c2-interpreter.rkt" interp-racket)) + +(define ↓ + (lambda (tp) + (match tp + [(tpromise v _ _ _) + #:when (number? v) + v] + [(? tpromise-flat?) + (car (unbox (tpromise-dst tp)))] + [(tpromise t _ _ _) + #:when (tcomp? t) + (let-values (((instrs data-segment) (compile-tensor tp))) + (let ((res (interp-racket instrs data-segment))) + (cond + ((flat? res) + (set-tpromise-tensor! tp (tcomp-ds-ref #f)) + (set-box! (tpromise-dst tp) (list res)) + (set-box! (tpromise-sign tp) (list #"dsr"))) + ((number? res) + (set-tpromise-tensor! tp res) + (set-box! (tpromise-dst tp) (list)) + (set-box! 
(tpromise-sign tp) (list #"s" (number->bytes res))))) + res))] + ;; NOTE: This case runs when we use tp-scalarize to turn + ;; the tensor to a scalar + (_ tp)))) + +;; We may have to replace tp-scalarize with scalarize from flat-tensors, because +;; the ↓ used in its definition is undesirable. +(define tp-scalarize + (λ (tp) + (cond + [(and (tpromise? tp) (null? (tpromise-shape tp))) + (tp-scalarize (↓ tp))] + [(and (acc:flat? tp) (null? (acc:flat-shape tp))) + (vector-ref (acc:flat-store tp) 0)] + [else tp]))) + +(define force*1 + (λ (t f) + (f (↓ t)))) + +(define force*2 + (λ (ts f) + (let-values (((t1 t2) (ts))) + (f (↓ t1) (↓ t2))))) + +(include "test/test-1-reflect.rkt") + +(provide ↓ force*1 force*2) + +(provide print-compiler? compiler-cache get-compiled + (rename-out + (tp-scalarize scalarize))) diff --git a/lazy/tensors/A-equality.rkt b/lazy/tensors/A-equality.rkt new file mode 100644 index 0000000..8cfa8ba --- /dev/null +++ b/lazy/tensors/A-equality.rkt @@ -0,0 +1,18 @@ +#lang racket + +(require "1-reflect.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) + +(define tp-tensor-equal? + (λ (tp-actual tp-expected) + (acc:tensor-equal? (↓ tp-actual) (↓ tp-expected)))) + +(require rackunit) +(define-binary-check (tp-check-tensor-equal? tp-tensor-equal? actual expected)) + +(include "test/test-A-equality.rkt") + +(provide (rename-out + (acc:tolerance tolerance) + (tp-tensor-equal? tensor-equal?) + (tp-check-tensor-equal? 
check-tensor-equal?))) diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt new file mode 100644 index 0000000..f301c73 --- /dev/null +++ b/lazy/tensors/B-test-programs.rkt @@ -0,0 +1,489 @@ +#lang racket +(require string-interpolation) +(require "0-lazy.rkt") +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) + +(define make-tref-test-program + (λ (t) + (tref t 2))) +(define make-list->tensor-test-program + (λ (l) + (list->tensor l))) + +(struct test-program-data (prog-thunk eval-res) #:transparent) +(struct eval-res-1 (res) #:transparent) +(struct eval-res-2 (res1 res2) #:transparent) + + +;; Care must be taken while calling get-test-program within the +;; test-program-data thunk because it might lead to an infinite loop. +(define test-programs + (hasheqv + 'tensor-r1-0 (test-program-data + (λ () + (tensor 1 2 3)) + (eval-res-1 (acc:tensor 1 2 3))) + 'tensor-r1-1 (test-program-data + (λ () + (tensor 1 2 3 4 5)) + (eval-res-1 (acc:tensor 1 2 3 4 5))) + 'tensor-r1-2 (test-program-data + (λ () + (tensor 3.0 4.0 5.0)) + (eval-res-1 (acc:tensor 3.0 4.0 5.0))) + 'tensor-r2-0 (test-program-data + (λ () + (tensor (tensor 1 2 3) (tensor 4 5 6))) + (eval-res-1 (acc:tensor (acc:tensor 1 2 3) (acc:tensor 4 5 6)))) + 'tensor-r2-1 (test-program-data + (λ () + (reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) + (eval-res-1 + (acc:reshape '(2 3) (acc:tensor 3.0 4.0 5.0 7.0 8.0 9.0)))) + 'build-tensor-r1-0 (test-program-data + (λ () + (build-tensor '(6) + (λ (i) (* 3.0 (car i))))) + (eval-res-1 (acc:build-tensor '(6) + (λ (i) (* 3.0 (car i)))))) + 'build-tensor-r2-0 (test-program-data + (λ () + (build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))) + (eval-res-1 (acc:build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y)))))) + 'build-tensor-r2-1 (test-program-data + (λ () + (build-tensor '(3 6) + (λ (i) + (match-define `(,x ,y) 
i) + (* 3.0 (+ (* x 6) y))))) + (eval-res-1 (acc:build-tensor '(3 6) + (λ (i) + (match-define `(,x ,y) i) + (* 3.0 (+ (* x 6) y)))))) + 'build-tensor-r3-0 (test-program-data + (λ () + (build-tensor '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) + (eval-res-1 (acc:build-tensor + '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z))))))) + 'build-tensor-r3-1 (test-program-data + (λ () + (build-tensor '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + (eval-res-1 (acc:build-tensor + '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z))))))) + 'extract-ds-once-tref (test-program-data + (λ () + (let ((n (tref (get-test-program 'tensor-r1-0) 1))) + (+-ρ n n))) + (eval-res-1 4)) + 'extract-ds-once-trefs (test-program-data + (λ () + (let ((tp (trefs (get-test-program 'tensor-r1-0) + '(0 2)))) + (+-ρ tp tp))) + (eval-res-1 (acc:tensor 2 6))) + 'built-tensor (test-program-data + (λ () + (let ((test-build-shape '(4 3))) + (build-tensor test-build-shape + (λ (i) + (let ([row (car i)] + [column (cadr i)]) + (+ (* (sub1 (car test-build-shape)) + row) + column)))))) + (eval-res-1 (acc:tensor (acc:tensor 0 1 2) + (acc:tensor 3 4 5) + (acc:tensor 6 7 8) + (acc:tensor 9 10 11)))) + 'multi-built-tensor (test-program-data + (λ () + (+-ρ (get-test-program 'build-tensor-r2-0) + (tref (get-test-program 'build-tensor-r3-1) 0))) + (eval-res-1 ((acc:ext2-ρ * (λ (a b) "@{a} * @{b}") 0 0) 2 + (acc:build-tensor + '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))))) + 'tcomp-tref (test-program-data + (λ () + (make-tref-test-program (get-test-program 'tensor-r1-0))) + (eval-res-1 3)) + 'tcomp-tref-nested (test-program-data + (λ () + (tref (tref (get-test-program 'tensor-r2-0) 0) 2)) + (eval-res-1 3)) + 'tcomp-list->tensor (test-program-data + (λ () + (make-list->tensor-test-program '(5 6 7 8))) + (eval-res-1 (acc:tensor 5 6 7 8))) 
+ 'tcomp-nested-list->tensor (test-program-data + (λ () + (list->tensor + `(,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0)))) + (eval-res-1 (acc:tensor + (acc:tensor 1 2 3) + (acc:tensor 1 2 3) + (acc:tensor 1 2 3)))) + 'tcomp-trefs (test-program-data + (λ () + (trefs (get-test-program 'built-tensor) '(0 2))) + (eval-res-1 (acc:tensor (acc:tensor 0 1 2) + (acc:tensor 6 7 8)))) + 'tcomp-reshape (test-program-data + (λ () + (reshape '(3 2 1) + (trefs (get-test-program 'built-tensor) '(1 3)))) + (eval-res-1 (acc:tensor (acc:tensor (acc:tensor 3) + (acc:tensor 4)) + (acc:tensor (acc:tensor 5) + (acc:tensor 9)) + (acc:tensor (acc:tensor 10) + (acc:tensor 11))))) + 'sum (test-program-data + (λ () + (sum (get-test-program 'tensor-r2-0))) + (eval-res-1 (acc:tensor 6.0 15.0))) + 'sum-nested (test-program-data + (λ () + (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) + (eval-res-1 (acc:tensor 4.0 6.0 5.0))) + 'id (test-program-data + (λ () + (id-ρ (get-test-program 'tensor-r2-0))) + (eval-res-1 (acc:tensor (acc:tensor 1 2 3) + (acc:tensor 4 5 6)))) + 'id-scalar (test-program-data + (λ () + (id-ρ (sum (tensor 4 5 6)))) + (eval-res-1 15)) + 'abs-scalar (test-program-data + (λ () + (abs-ρ (tref (tensor 4 -5 6) 1))) + (eval-res-1 5)) + 'sqr (test-program-data + (λ () + (*-ρ (get-test-program 'build-tensor-r3-0) + (get-test-program 'build-tensor-r3-0))) + (eval-res-1 (acc:reshape + '(2 3 4) + (acc:tensor + 0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116)))) + 'r-1-2 (test-program-data + (λ () + (*-2-1 (get-test-program 'build-tensor-r2-0) + (get-test-program 'build-tensor-r1-0))) + (eval-res-1 (acc:reshape + '(5 6) + (acc:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0)))) + 'r-3-4 (test-program-data + (λ () + (*-2-1 (get-test-program 'build-tensor-r3-1) + 
(get-test-program 'build-tensor-r2-1))) + (eval-res-1 (acc:reshape + '(3 5 6) + (acc:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0)))) + 'r-sum-2-scalar (test-program-data + (λ () + (*-ρ (sum (get-test-program 'build-tensor-r1-0)) + (sum (tensor 2 3 4)))) + (eval-res-1 405.0)) + 'tcomp-dsqr-r1 (test-program-data + (λ () + (d-sqr r1-td (one-like r1-td))) + (eval-res-1 (acc:tensor 6.0 8.0 10.0))) + 'gsqr (test-program-data + (λ () + (let ([r2-td (get-test-program 'tensor-r2-1)]) + (d-sqr r2-td (one-like r2-td)))) + (eval-res-1 (acc:reshape + '(2 3) + (acc:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + 'g+ (test-program-data + (λ () + (d+ 2.0 3.0 1.0)) + (eval-res-2 1.0 1.0)) + 'g-twice (test-program-data + (λ () + (d+ r1-td r1-td (one-like r1-td))) + (eval-res-2 (acc:tensor 1.0 1.0 1.0) + (acc:tensor 1.0 1.0 1.0))) + 'g+-r1-r2 (test-program-data + (λ () + (let ((r2-td (get-test-program 'tensor-r2-1))) + (d+ r1-td r2-td (one-like r2-td)))) + (eval-res-2 (acc:tensor 2.0 2.0 2.0) + (acc:reshape + '(2 3) + (acc:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + 'g* (test-program-data + (λ () + (*∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0 1.0))) + (eval-res-2 (acc:tensor 1.0 2.0 3.0) + (acc:tensor 2.0 3.0 4.0))) + 'gsum-r1 (test-program-data + (λ () + (sum-∇ (tensor 2.0 3.0 4.0) + 1.0)) + (eval-res-1 (acc:tensor 1.0 1.0 1.0))) + 'gsum-r2 (test-program-data + (λ () + (sum-∇ (tensor (tensor 2.0 3.0 4.0) + (tensor 2.0 
3.0 4.0)) + (tensor 2.0 1.0))) + (eval-res-1 (acc:tensor (acc:tensor 2.0 2.0 2.0) + (acc:tensor 1.0 1.0 1.0)))) + 'gs2-r1 (test-program-data + (λ () + (s2-∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0))) + (eval-res-2 (acc:tensor 1.0 1.0 1.0) + (acc:tensor 1.0 1.0 1.0))) + 'gs2-r3 (test-program-data + (λ () + (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 5.0 6.0 6.0) + (tensor 7.0 8.0 6.0)) + (tensor (tensor 8.0 7.0 6.0) + (tensor 6.0 5.0 6.0))) + (tensor (tensor (tensor 6.0 8.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 8.0 2.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 5.0 1.0 6.0))) + (tensor (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0))))) + (eval-res-2 (acc:reshape '(3 2 3) + (acc:list->tensor (make-list 18 1.0))) + (acc:reshape '(3 2 3) + (acc:list->tensor (make-list 18 1.0))))) + 'env-flat-scalar (test-program-data + (λ () + ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) + (list (tensor 1.0) 3.0))) + (eval-res-1 (acc:tensor 3.0))) + 'common-subexpression (test-program-data + (λ () + (let ((t (tref (tensor 1 2 3) 0))) + (tensor t t))) + (eval-res-1 (acc:tensor 1.0 1.0))) + 'nested-common-subexpression (test-program-data + (λ () + (let ((t1 (tref (tensor (tensor 1 2 3) + (tensor 4 5 6)) + 0))) + (let ((t0 (tref t1 0))) + (tensor t0 t0)))) + (eval-res-1 (acc:tensor 1.0 1.0))) + )) + +(define get-test-program + (λ (name) + ((test-program-data-prog-thunk (hash-ref test-programs name))))) +(define get-test-eval-res + (λ (name) + (test-program-data-eval-res (hash-ref test-programs name)))) + +(define sum-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (vset! 
out-v iₒ + (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) + (+ sum (vref in-v i)))))) +(define sum-f-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #<tensor tcomp (lst) #:transparent) +#; +(: s (Listof Natural)) ;; non-empty +#; +(: f (-> (Listof Natural) Number)) +#; +(: tp tpromise) +#; +(: i Natural) +(struct tcomp-tref tcomp (tp i) #:transparent) +#; +(: tp tpromise) +#; +(: i (Listof Natural)) +(struct tcomp-trefs tcomp (tp b) #:transparent) +#; +(: fᵈ (U (-> Number Number (Values Number Number)) + (-> (Vector Number) Natural (Listof Natural) + (Vector Number) Natural (Listof Natural) + (Vector Number) Natural (Listof Natural)))) +(struct tcomp-ext1-ρ-scalar tcomp (f f-acc sign tp) #:transparent) +(struct tcomp-ext1-ρ tcomp (f f-acc sign m shape-fn tp) #:transparent) +(struct tcomp-ext2-ρ-scalar tcomp (f f-acc sign tp-t tp-u) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u f f-acc sign m n shape-fn) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp f f-acc sign m shape-fn) #:transparent) +(struct tcomp-ext2-∇ tcomp (fᵈ fᵈ-acc + sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z + out-ref0 out-ref1 i) + #:transparent) +(struct tcomp-prim1-ρ tcomp (f f-acc sign shape-fn tp) #:transparent) +(struct tcomp-prim1-∇ tcomp (f f-acc sign shape-fn tp zp) #:transparent) +(struct tcomp-prim2-ρ tcomp (f f-acc sign shape-fn tp-t tp-u) #:transparent) +(struct tcomp-prim2-∇ tcomp (f f-acc sign shape-fn tp-t tp-u zp out-ref0 out-ref1 i) #:transparent) +(struct tcomp-reshape tcomp (s tp) #:transparent) +(struct tcomp-let tcomp (lhs rhs body) #:transparent) +(struct tcomp-var tcomp (name) #:transparent) +(struct tcomp-ds-ref tcomp (index) #:transparent) + +(struct tpromise ((tensor #:mutable) shape dst (sign #:mutable)) + #:guard + (λ (tensor shape data-segment-tree signature name) + (unless (or (tcomp? tensor) (number? tensor)) + (error 'make-tpromise + (string-append + "First argument must be either a" + " number or a tcomp. 
Got ~a") + tensor)) + (unless ((listof positive-integer?) shape) + (error 'make-tpromise + (string-append + "Second argument must be a list" + " of positive integers. Got ~a") + shape)) + (unless (and (box? data-segment-tree) + (list? (unbox data-segment-tree))) + (error 'make-tpromise + (string-append + "Third argument must be a box containing a list. Got ~a") + data-segment-tree)) + (unless (and (box? signature) (list? (unbox signature))) + (error 'make-tpromise + (string-append + "Fourth argument must be a box containing a list." + " Got ~a") + signature)) + (values tensor shape data-segment-tree signature)) + #:transparent) + +(define dst->data-segment + (λ (dst) + (apply vector-append (map dst-member->ds (unbox dst))))) + +(define dst-member->ds + (λ (dstm) + (cond + ((or (eqv? dstm 'uncalculated) + (number? dstm) (flat? dstm)) + (vector dstm)) + ((box? dstm) (dst->data-segment dstm)) + (else (error 'malformed-dst-member "Invalid signature. Got ~a" dstm))))) + +(define gdst-list->tensor + (λ (lst) + (box + (for/list ((l lst) + #:when (tpromise? l)) + (tpromise-dst l))))) + +(define gdst-tref + (λ (tp i) + (box (list (tpromise-dst tp) i)))) + +(define gdst-trefs + (λ (tp i-lst) + (box (list (tpromise-dst tp) (acc:list->tensor i-lst))))) + +(define gdst-ext2-∇ + (λ (tp-t0 tp-t1 tp-z) + (let ((dsn0 (tpromise-dst tp-t0)) + (dsn1 (tpromise-dst tp-t1)) + (dsnz (tpromise-dst tp-z))) + (box (list dsn0 dsn1 dsnz 'uncalculated))))) + +(define sign + (λ (ss) + (let ((xxh32-ctx (make-xxh32))) + (xxh32-reset! xxh32-ctx 0) + (sign-traverse-list! (unbox ss) xxh32-ctx) + (xxh32-digest xxh32-ctx)))) + +(define sign-traverse-list! + (λ (ss ctx) + (for ((s ss)) + (sign-traverse-member! s ctx)))) + +(define sign-traverse-member! + (λ (s ctx) + (cond + ((bytes? s) (xxh32-update! ctx (bytes-append #"_" s))) + ((box? s) (sign-traverse-list! (unbox s) ctx)) + (else (error 'malformed-sign-member "Invalid signature. 
Got ~a" s))))) + +(define number->bytes + (λ (n) + (string->bytes/utf-8 (number->string n)))) + +(define string->bytes string->bytes/utf-8) + +(define gs-list->tensor + (λ (lst) + (box (list* #"l>t" (map (λ (l) + (cond + ((tpromise? l) (tpromise-sign l)) + ((number? l) (bytes-append #"s_" (number->bytes l))) + (else (error 'gs-list->tensor "Unexpected: ~a" l)))) + lst))))) + +(define gs-tref + (λ (tp) + (box (list #"tr" (tpromise-sign tp) #"dsr")))) + +(define gs-trefs + (λ (tp) + (box (list #"trs" (tpromise-sign tp) #"dsr")))) + +(define gs-ext1-ρ-scalar + (λ (signature tp) + (box (list #"e1rs" (string->bytes signature) (tpromise-sign tp))))) + +(define gs-ext1-ρ + (λ (signature m tp) + (box (list #"e1r" (string->bytes signature) + (number->bytes m) (tpromise-sign tp))))) + +(define gs-ext2-ρ-scalar + (λ (signature tp-t tp-u) + (box (list #"e2rs" (string->bytes signature) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-ext2-ρ + (λ (signature m n tp-t tp-u) + (box (list #"e2r" (string->bytes signature) (number->bytes m) (number->bytes n) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-ext1-∇ + (λ (signature m tp zp) + (box (list #"e1n" (string->bytes signature) (number->bytes m) + (tpromise-sign tp) (tpromise-sign zp))))) + +(define gs-ext2-∇ + (λ (signature r0 r1 tp-t0 tp-t1 tp-z i) + (box (list #"e2n" (string->bytes signature) + (number->bytes r0) (number->bytes r1) + (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) + #"dsr" (number->bytes i))))) + +(define gs-prim1-ρ + (λ (prim-sign tp) + (box (list #"p1r" (string->bytes prim-sign) (tpromise-sign tp))))) + +(define gs-prim2-ρ + (λ (signature tp-t tp-u) + (box (list #"p2r" (string->bytes signature) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-prim1-∇ + (λ (signature tp zp) + (box (list #"p1n" (string->bytes signature) (tpromise-sign tp) (tpromise-sign zp))))) + +(define gs-prim2-∇ + (λ (signature tp-t0 tp-t1 tp-z i) + (box (list #"p2n" (string->bytes 
signature) + (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) + #"dsr" (number->bytes i))))) + +(define gs-reshape + (λ (shape tp) + (box (list* #"r" (tpromise-sign tp) (map number->bytes shape))))) + +(define tpromise-flat? + (λ (v) + (and (tpromise? v) + (tcomp-ds-ref? (tpromise-tensor v)) + (flat? (car (unbox (tpromise-dst v))))))) + +(define tpmake-flat + (λ (ft) + (tpromise (tcomp-ds-ref #f) (flat-shape ft) + (box (list ft)) (box (list #"dsr"))))) + +(define tpmake-list->tensor + (λ (lst shape) + (let ((tcomp-node (tcomp-list->tensor lst))) + (tpromise tcomp-node shape + (gdst-list->tensor lst) + (gs-list->tensor lst))))) + +(define tpmake-tref + (λ (tp i shape) + (tpromise (tcomp-tref tp (tcomp-ds-ref #f)) + shape + (gdst-tref tp i) + (gs-tref tp)))) + +(define tpmake-trefs + (λ (tp b shape) + (tpromise (tcomp-trefs tp (tcomp-ds-ref #f)) shape + (gdst-trefs tp b) + (gs-trefs tp)))) + +(define tpmake-ext1-ρ-scalar + (λ (f f-acc prim-sign tp shape) + (tpromise (tcomp-ext1-ρ-scalar f f-acc prim-sign tp) shape + (box (list (tpromise-dst tp))) + (gs-ext1-ρ-scalar prim-sign tp)))) + +(define tpmake-ext1-ρ + (λ (f f-acc prim-sign m shape-fn tp shape) + (tpromise (tcomp-ext1-ρ f f-acc prim-sign m shape-fn tp) + shape + (box (list (tpromise-dst tp))) + (gs-ext1-ρ prim-sign m tp)))) + +(define tpmake-ext2-ρ-scalar + (λ (f f-acc prim-sign tp-t tp-u shape) + (tpromise (tcomp-ext2-ρ-scalar f f-acc prim-sign tp-t tp-u) + shape + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-ext2-ρ-scalar prim-sign tp-t tp-u)))) + +(define ensure-tpromise + (λ (v) + (cond + ((number? v) (tpmake-flat (ensure-flat v))) + ((flat? 
v) (tpmake-flat v)) + (else v)))) + +(define tpmake-ext2-ρ + (λ (tp-t tp-u f f-acc prim-sign m n shape-fn shape) + (let ((tp-t (ensure-tpromise tp-t)) + (tp-u (ensure-tpromise tp-u))) + (tpromise + (tcomp-ext2-ρ tp-t tp-u f f-acc prim-sign m n shape-fn) + shape + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-ext2-ρ prim-sign m n tp-t tp-u))))) + +;; we invoke ensure-tpromise on just zp because it's the result of calling +;; force*1 which forces zp to be a non-tpromise value. We can ensure tp to +;; be a tpromise as well, but currently in our workflow we never force tp +;; before passing it to this function, nor do we need scalar tp to be wrapped in +;; a tpromise. +(define tpmake-ext1-∇ + (λ (tp zp f f-acc prim-sign m shape-fn shape) + (let ((zp (ensure-tpromise zp))) + (tpromise + (tcomp-ext1-∇ tp zp f f-acc prim-sign m shape-fn) + shape + (box (list (tpromise-dst tp) (tpromise-dst zp))) + (gs-ext1-∇ prim-sign m tp zp))))) + +(define tpmake-ext2-∇ + (λ (fᵈ fᵈ-acc prim-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) + (let ((tp-t0 (ensure-tpromise tp-t0)) + (tp-t1 (ensure-tpromise tp-t1)) + (tp-z (ensure-tpromise tp-z))) + (tpromise + (tcomp-ext2-∇ fᵈ fᵈ-acc prim-sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + shape + (gdst-ext2-∇ tp-t0 tp-t1 tp-z) + (gs-ext2-∇ prim-sign r0 r1 tp-t0 tp-t1 tp-z i))))) + +(define tpmake-prim1-ρ + (λ (f f-acc prim-sign shape-fn tp) + (tpromise (tcomp-prim1-ρ f f-acc prim-sign shape-fn tp) + (shape-fn (tpromise-shape tp)) + (box (list (tpromise-dst tp))) + (gs-prim1-ρ prim-sign tp)))) + +(define tpmake-prim2-ρ + (λ (f f-acc prim-sign shape-fn tp-t tp-u) + (tpromise + (tcomp-prim2-ρ f f-acc prim-sign shape-fn tp-t tp-u) + (shape-fn (tpromise-shape tp-t) (tpromise-shape tp-u)) + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-prim2-ρ prim-sign tp-t tp-u)))) + +(define tpmake-prim1-∇ + (λ (f f-acc prim-sign shape-fn tp zp) + (let ((zp (ensure-tpromise zp))) + (tpromise + (tcomp-prim1-∇ 
f f-acc prim-sign shape-fn tp zp) + (tpromise-shape tp) + (box (list (tpromise-dst tp) (tpromise-dst zp))) + (gs-prim1-∇ prim-sign tp zp))))) + +(define tpmake-prim2-∇ + (λ (fᵈ fᵈ-acc prim-sign shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + (let ((tp-t0 (ensure-tpromise tp-t0)) + (tp-t1 (ensure-tpromise tp-t1)) + (tp-z (ensure-tpromise tp-z))) + (tpromise + (tcomp-prim2-∇ fᵈ fᵈ-acc prim-sign shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + (if (zero? i) (tpromise-shape tp-t0) (tpromise-shape tp-t1)) + (gdst-ext2-∇ tp-t0 tp-t1 tp-z) ;; dst constucted in the same way as ext2-∇ + (gs-prim2-∇ prim-sign tp-t0 tp-t1 tp-z i))))) + +(define tpmake-reshape + (λ (tp shape) + (tpromise + (tcomp-reshape shape tp) shape + (tpromise-dst tp) + (gs-reshape shape tp)))) + +(provide (struct-out tcomp) + (struct-out tcomp-list->tensor) + (struct-out tcomp-tref) + (struct-out tcomp-trefs) + (struct-out tcomp-ext1-ρ-scalar) + (struct-out tcomp-ext1-ρ) + (struct-out tcomp-ext2-ρ-scalar) + (struct-out tcomp-ext2-ρ) + (struct-out tcomp-ext1-∇) + (struct-out tcomp-ext2-∇) + (struct-out tcomp-prim1-ρ) + (struct-out tcomp-prim2-ρ) + (struct-out tcomp-prim1-∇) + (struct-out tcomp-prim2-∇) + (struct-out tcomp-reshape) + (struct-out tcomp-let) + (struct-out tcomp-var) + (struct-out tcomp-ds-ref) + (struct-out tpromise) + dst->data-segment + sign + number->bytes + tpromise-flat? 
+         tpmake-flat
+         tpmake-list->tensor
+         tpmake-tref
+         tpmake-trefs
+         tpmake-ext1-ρ-scalar
+         tpmake-ext1-ρ
+         tpmake-ext2-ρ-scalar
+         tpmake-ext2-ρ
+         tpmake-ext1-∇
+         tpmake-ext2-∇
+         tpmake-prim1-ρ
+         tpmake-prim2-ρ
+         tpmake-prim1-∇
+         tpmake-prim2-∇
+         tpmake-reshape)
diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt
new file mode 100644
index 0000000..70561b8
--- /dev/null
+++ b/lazy/tensors/c1-racket-runtime.rkt
@@ -0,0 +1,135 @@
+#lang racket
+
+(require ffi/vector)
+(require "../../impl-loader.rkt")
+(require "../../accelerated-tensors/ext-impl.rkt")
+(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt"))
+
+;; Mutable holder for the output reference of one gradient of a binary ∇ node.
+(struct ext2-∇-result (res) #:mutable #:transparent)
+
+;; Computes the gradients of a binary extended function with respect to both
+;; of its arguments and stores them in the current data segment.
+;;  - fᵈ is the preallocated ∇-function used on the non-accelerated path;
+;;    fᵈ-acc and fᵈ-sign are handed to ext2-∇-kernel/name when (accelerate?)
+;;    holds.
+;;  - r0 and r1 are passed to min-shape to obtain the base shapes of t0 and t1;
+;;    shape-fn maps those base shapes to the base shape of z.
+;;  - t0, t1 and z are each normalised with ensure-flat before their shape,
+;;    store and offset are read.
+;;  - out-idx0 and out-idx1 are data-segment indices; a gradient whose index is
+;;    #f is computed but not stored.
+(define ext2-∇-forcer!
+  (λ (fᵈ fᵈ-acc fᵈ-sign r0 r1 shape-fn t0 t1 z out-idx0 out-idx1)
+    (let* ((f0 (ensure-flat t0))
+           (f1 (ensure-flat t1))
+           (fz (ensure-flat z))
+
+           (s0 (flat-shape f0))
+           (sf0 (min-shape r0 s0))
+           (stride0 (acc:size-of sf0))
+
+           ;; Take the shape from f1, as s0 does from f0: t1 may still be a bare number.
+           (s1 (flat-shape f1))
+           (sf1 (min-shape r1 s1))
+           (stride1 (acc:size-of sf1))
+
+           (sf-z (shape-fn sf0 sf1))
+           (stride-z (acc:size-of sf-z))
+
+           (v0 (flat-store f0))
+           (v1 (flat-store f1))
+           (vz (flat-store fz))
+
+           (off0 (flat-offset f0))
+           (off1 (flat-offset f1))
+           (offz (flat-offset fz)))
+      (ext2-shapes
+       s0 s1 r0 r1 sf-z
+       (λ (sz size-z q0 q1 strides)
+         (let ((g0 (new-vec (acc:size-of s0) 0.0))
+               (g1 (new-vec (acc:size-of s1) 0.0)))
+           (cond
+             ((accelerate?)
+              (let-values (((kernel-code kernel-name)
+                            (ext2-∇-kernel/name fᵈ-acc fᵈ-sign strides s0 s1 r0 r1 sz
+                                                (length sf-z))))
+                (run-prim2-∇! kernel-code kernel-name
+                              g0 g1
+                              v0 off0 (acc:size-of s0) stride0
+                              v1 off1 (acc:size-of s1) stride1
+                              vz offz size-z stride-z)))
+             (else
+              (for ([iz (in-range 0 size-z stride-z)])
+                (let-values (((i0 i1) (idxs strides iz off0 off1)))
+                  (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z)))))
+           (when out-idx0
+             (data-segment-set! out-idx0 (scalarize (flat s0 g0 0))))
+           (when out-idx1
+             (data-segment-set! out-idx1 (scalarize (flat s1 g1 0))))))))))
+
+;; Variant of ext2-∇-forcer! that calls fᵈ once over the whole of t0, t1 and z.
+;; When shape-fn yields '(), z is placed in a fresh one-element vector;
+;; otherwise the store and offset of z are used. t0 and t1 are read with
+;; flat-shape / flat-store directly, without ensure-flat. fᵈ-acc and fᵈ-sign
+;; are accepted but not used here.
+(define prim2-∇-forcer!
+  (λ (fᵈ fᵈ-acc fᵈ-sign shape-fn t0 t1 z out-idx0 out-idx1)
+    (let* ((in-shape-a (flat-shape t0))
+           (in-size-a (size-of in-shape-a))
+           (in-shape-b (flat-shape t1))
+           (in-size-b (size-of in-shape-b))
+           (out-shape (shape-fn in-shape-a in-shape-b))
+           (out-size (size-of out-shape)))
+      (let ((g0 (new-vec in-size-a 0.0))
+            (g1 (new-vec in-size-b 0.0)))
+        (cond
+          ((null? out-shape)
+           (let ((v-z (new-vec 1 z)))
+             (fᵈ g0 g1
+                 (flat-store t0) (flat-offset t0) in-size-a
+                 (flat-store t1) (flat-offset t1) in-size-b
+                 v-z 0 1)))
+          (else
+           (fᵈ g0 g1
+               (flat-store t0) (flat-offset t0) in-size-a
+               (flat-store t1) (flat-offset t1) in-size-b
+               (flat-store z) (flat-offset z) out-size)))
+        (when out-idx0
+          (data-segment-set! out-idx0 (scalarize (flat in-shape-a g0 0))))
+        (when out-idx1
+          (data-segment-set! out-idx1 (scalarize (flat in-shape-b g1 0))))))))
+
+;; trefs where the indices b arrive as a rank-1 flat tensor: its store is read
+;; as an f32vector and each entry is made exact before calling acc:trefs.
+;; Raises trefs-err for any other rank.
+(define rt:trefs
+  (λ (ft b)
+    (cond
+      ((= (acc:rank b) 1) (acc:trefs ft (map inexact->exact (f32vector->list (flat-store b)))))
+      (else (error 'trefs-err "~a should be a tensor¹" b)))))
+
+;; Parameter holding the data segment that the two accessors below index with
+;; vector-set! / vector-ref; its default value is #f.
+(define data-segment
+  (make-parameter #f))
+
+(define data-segment-set!
+  (λ (i v)
+    (vector-set! (data-segment) i v)))
+
+(define data-segment-ref
+  (λ (i)
+    (vector-ref (data-segment) i)))
+
+;; Namespace of this module, captured through the anchor a.
+(define-namespace-anchor a)
+(define runtime
+  (namespace-anchor->namespace a))
+
+(provide runtime flat? acc:build-tensor acc:list->tensor
+         acc:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res!
+         ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ
+         flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref
+         apply-flat-ρ-fn-1 apply-flat-ρ-fn-2 apply-flat-∇-fn-1 apply-flat-∇-fn-2
+         prim2-∇-forcer!)
diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt
new file mode 100644
index 0000000..7bbba2c
--- /dev/null
+++ b/lazy/tensors/c2-interpreter.rkt
@@ -0,0 +1,157 @@
+#lang racket
+
+(require "c0-ast.rkt")
+(require (only-in "c1-racket-runtime.rkt"
+                  runtime flat? acc:build-tensor acc:list->tensor
+                  set-ext2-∇-result-res! acc:tref rt:trefs ext2-∇-result-res
+                  ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ
+                  flat flat-store flat-offset flat-ext1-ρ data-segment
+                  apply-flat-ρ-fn-1 apply-flat-ρ-fn-2 apply-flat-∇-fn-1 apply-flat-∇-fn-2
+                  data-segment-ref prim2-∇-forcer!))
+
+;; Evaluates a tcomp node directly.  env is an association list from names
+;; bound by tcomp-let to their already-interpreted values (looked up with assv
+;; for tcomp-var); leaf data is read with data-segment-ref.
+(define interp-tcomp
+  (λ (tc env)
+    (match tc
+      [(tcomp-list->tensor lst)
+       (let ((eval-list
+              (for/list ((arg lst))
+                (cond
+                  ((tpromise? arg) (interp-tpromise arg env))
+                  ((number? arg) arg)
+                  (else (error 'interp-list->tensor "Unexpected: ~a" arg))))))
+         (acc:list->tensor eval-list))]
+      [(tcomp-tref tp (and i (tcomp-ds-ref _)))
+       (acc:tref (interp-tpromise tp env)
+                 (interp-tcomp i env))]
+      [(tcomp-trefs tp (and b (tcomp-ds-ref _)))
+       (rt:trefs (interp-tpromise tp env)
+                 (interp-tcomp b env))]
+      [(tcomp-ext2-∇ fᵈ fᵈ-acc f-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i)
+       (let ((t0-instrs (interp-tpromise tp-t0 env))
+             (t1-instrs (interp-tpromise tp-t1 env))
+             (z-instrs (interp-tpromise tp-z env)))
+         (cond
+           ((and (eqv? i 0)
+                 (not (tcomp-ds-ref-index (ext2-∇-result-res out0))))
+            (set-ext2-∇-result-res! out0 (tcomp-ds-ref (current-ds-ref-index)))
+            (current-ds-ref-index (add1 (current-ds-ref-index))))
+           ((and (eqv? i 1)
+                 (not (tcomp-ds-ref-index (ext2-∇-result-res out1))))
+            (set-ext2-∇-result-res! out1 (tcomp-ds-ref (current-ds-ref-index)))
+            (current-ds-ref-index (add1 (current-ds-ref-index)))))
+         (let* ((out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0)))
+                (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))
+                (index (if (zero? i) out-idx0 out-idx1))
+                (v (data-segment-ref index)))
+           (cond
+             ((eqv? v 'uncalculated)
+              (ext2-∇-forcer! fᵈ fᵈ-acc f-sign r0 r1 shape-fn
+                              t0-instrs t1-instrs
+                              z-instrs out-idx0 out-idx1)
+              (data-segment-ref index))
+             (else v))))]
+      [(tcomp-ext1-∇ tp zp f f-acc f-sign m shape-fn)
+       (scalarize
+        (flat-ext1-∇ f f-acc m shape-fn f-sign
+                     (ensure-flat (interp-tpromise tp env))
+                     (ensure-flat (interp-tpromise zp env))))]
+      [(tcomp-ext2-ρ-scalar f f-acc _ tp-t tp-u)
+       (f (interp-tpromise tp-t env) (interp-tpromise tp-u env))]
+      [(tcomp-ext2-ρ tp-t tp-u f f-acc f-sign m n shape-fn)
+       (scalarize
+        (flat-ext2-ρ f f-acc m n shape-fn f-sign
+                     (ensure-flat (interp-tpromise tp-t env))
+                     (ensure-flat (interp-tpromise tp-u env))))]
+      [(tcomp-ext1-ρ-scalar f f-acc _ tp)
+       (f (interp-tpromise tp env))]
+      [(tcomp-ext1-ρ f f-acc f-sign m shape-fn tp)
+       (scalarize
+        (flat-ext1-ρ f f-acc m shape-fn f-sign
+                     (ensure-flat (interp-tpromise tp env))))]
+      ;; The prim cases interpret their operands with the current env, exactly
+      ;; like the ext cases above (interp-tpromise always takes two arguments).
+      [(tcomp-prim1-ρ f f-acc sign shape-fn tp)
+       (apply-flat-ρ-fn-1 f (interp-tpromise tp env) shape-fn)]
+      [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u)
+       (apply-flat-ρ-fn-2 f (interp-tpromise tp-t env) (interp-tpromise tp-u env) shape-fn)]
+      [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp)
+       (apply-flat-∇-fn-1 f (interp-tpromise tp env) (scalarize (interp-tpromise zp env)) shape-fn)]
+      [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i)
+       (let ((t0-instrs (interp-tpromise tp-t0 env))
+             (t1-instrs (interp-tpromise tp-t1 env))
+             (z-instrs (interp-tpromise tp-z env)))
+         (cond
+           ((and (eqv? i 0)
+                 (not (tcomp-ds-ref-index (ext2-∇-result-res out0))))
+            (set-ext2-∇-result-res! out0 (tcomp-ds-ref (current-ds-ref-index)))
+            (current-ds-ref-index (add1 (current-ds-ref-index))))
+           ((and (eqv? i 1)
+                 (not (tcomp-ds-ref-index (ext2-∇-result-res out1))))
+            (set-ext2-∇-result-res! out1 (tcomp-ds-ref (current-ds-ref-index)))
+            (current-ds-ref-index (add1 (current-ds-ref-index)))))
+         (let* ((out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0)))
+                (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))
+                (index (if (zero? i) out-idx0 out-idx1))
+                (v (data-segment-ref index)))
+           (cond
+             ((eqv? v 'uncalculated)
+              ;; fᵈ-acc is passed through so the call supplies all nine arguments.
+              (prim2-∇-forcer! fᵈ fᵈ-acc f-sign shape-fn
+                               t0-instrs t1-instrs
+                               z-instrs out-idx0 out-idx1)
+              (data-segment-ref index))
+             (else v))))]
+      [(tcomp-reshape s tp)
+       (let ([interp-tp (interp-tpromise tp env)])
+         (flat s (flat-store interp-tp) (flat-offset interp-tp)))]
+      [(tcomp-let lhs rhs body)
+       (interp-tpromise
+        body
+        (cons
+         (cons lhs
+               (interp-tpromise rhs env))
+         env))]
+      [(tcomp-var name)
+       (cond
+         ((assv name env)
+          =>
+          (λ (p) (cdr p)))
+         (else (error 'interpret-free "Free variable: ~a" name)))]
+      [(tcomp-ds-ref #f)
+       (let ([out (data-segment-ref (current-ds-ref-index))])
+         (current-ds-ref-index (add1 (current-ds-ref-index)))
+         out)]
+      [(tcomp-ds-ref index)
+       ;; This case is run only for languages where the tcomp-ds-ref indices are
+       ;; generated by the generate-ds-refs pass.
+       (data-segment-ref index)])))
+
+;; Unwraps a tpromise and interprets the tcomp it carries under env.
+(define interp-tpromise
+  (λ (t env)
+    (match t
+      [(tpromise tc _ _ _) (interp-tcomp tc env)])))
+
+;; Index of the next data-segment slot; read and incremented when a
+;; (tcomp-ds-ref #f) is interpreted or a ∇ result slot is first assigned.
+(define current-ds-ref-index (make-parameter #f))
+;; Interprets tp from scratch: slot counter at 0, data segment built from the
+;; dst of tp, and an empty environment.
+(define interp-tensor
+  (λ (tp)
+    (parameterize ([current-ds-ref-index 0]
+                   [data-segment (dst->data-segment (tpromise-dst tp))])
+      (interp-tpromise tp '()))))
+
+;; Evaluates instrs in the runtime namespace with ds installed as the data
+;; segment.
+(define interp-racket
+  (lambda (instrs ds)
+    (parameterize ((data-segment ds))
+      (eval instrs runtime))))
+
+(include "test/test-c2-interpreter.rkt")
+(provide interp-racket interp-tensor)
diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt
new file mode 100644
index 0000000..15d171c
--- /dev/null
+++ b/lazy/tensors/c3-compiler.rkt
@@ -0,0 +1,612 @@
+#lang racket
+
+(require "c0-ast.rkt")
+(require (only-in "c2-interpreter.rkt" interp-tensor interp-racket))
+(require (only-in "c1-racket-runtime.rkt"
+                  runtime ext2-∇-result-res
+                  set-ext2-∇-result-res!))
+(require rackunit)
+
+(struct counter-data (binding-name ref-count)
+  #:transparent)
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+;; Compiler Passes
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +;; The data segment is a vector that contains +;; * scalars (arguments to tref) +;; * flat tensors +;; * flat tensor¹ of indices that will be the arguments to trefs +;; * the symbol 'uncalculated as an initial placeholder for the output of +;; tcomp-ext2-∇ which will be later replaced by the flat tensor output + +(define generate-ds-refs + (λ (t) + (let-values (((t^ ref) (gdr-tpromise t 0 (make-hasheq)))) + t^))) + +(define gdr-tpromise + (λ (tp ref memo) + (match tp + ((tpromise tc s dss sign) + (let-values (((tc^ ref^) (gdr-tcomp tc ref memo))) + (values (tpromise tc^ s dss sign) ref^)))))) + +(define gdr-tcomp + (λ (tc ref memo) + (cond + ((hash-ref memo tc #f) + => + (λ (res/ref-count) + (match-let (((cons res ref-count) res/ref-count)) + (values res (+ ref ref-count))))) + (else + (let-values + (((res ref^) + (match tc + ((? number?) (values tc ref)) + [(tcomp-list->tensor lst) + (for/fold + ((tcs '()) + (ref^ ref) + #:result (values (tcomp-list->tensor (reverse tcs)) ref^)) + ((l lst)) + (let-values (((tc ref^^) + (cond + ((tpromise? l) (gdr-tpromise l ref^ memo)) + ((number? l) (values l ref^)) + (else (error 'gdr-list->tensor + "Unexpected: ~a" l))))) + (values (cons tc tcs) ref^^)))] + [(tcomp-tref tp (tcomp-ds-ref #f)) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-tref tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] + [(tcomp-trefs tp (tcomp-ds-ref #f)) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-trefs tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + out-ref0 + out-ref1 i) + (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref memo)) + ((tp-t1^ ref^^) (gdr-tpromise tp-t1 ref^ memo)) + ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^ memo))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref0)))) + (set-ext2-∇-result-res! out-ref0 (tcomp-ds-ref ref^^^))) + ((and (eqv? 
i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref1)))) + (set-ext2-∇-result-res! out-ref1 (tcomp-ds-ref ref^^^)))) + (values (tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ + out-ref0 out-ref1 i) + (add1 ref^^^)))] + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) + (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) + ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) + (values (tcomp-ext1-∇ tp^ zp^ f f-acc sign m shape-fn) ref^^))] + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-ext2-ρ-scalar f f-acc sign tp-t^ tp-u^) ref^^))] + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-ext2-ρ tp-t^ tp-u^ f f-acc sign m n shape-fn) ref^^))] + [(tcomp-ext1-ρ-scalar f f-acc sign tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-ext1-ρ-scalar f f-acc sign tp^) ref^))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-ext1-ρ f f-acc sign m shape-fn tp^) ref^))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-prim1-ρ f f-acc sign shape-fn tp^) ref^))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-prim2-ρ f f-acc sign shape-fn tp-t^ tp-u^) ref^^))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) + ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) + (values (tcomp-prim1-∇ f f-acc sign shape-fn tp^ zp^) ref^^))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref memo)) + ((tp-t1^ ref^^) (gdr-tpromise 
tp-t1 ref^ memo)) + ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^ memo))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out0)))) + (set-ext2-∇-result-res! out0 (tcomp-ds-ref ref^^^))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (set-ext2-∇-result-res! out1 (tcomp-ds-ref ref^^^)))) + (values (tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0^ tp-t1^ tp-z^ + out0 out1 i) + (add1 ref^^^)))] + [(tcomp-reshape s tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-reshape s tp^) ref^))] + [(tcomp-ds-ref #f) (values (tcomp-ds-ref ref) (add1 ref))] + ;;need these cases for testing compiler invariant + [(tcomp-let lhs rhs body) + (let*-values (((rhs^ ref^) (gdr-tpromise rhs ref memo)) + ((body^ ref^^) (gdr-tpromise body ref^ memo))) + (values (tcomp-let lhs rhs^ body^) ref^^))] + [(tcomp-var name) (values (tcomp-var name) ref)]))) + (hash-set! memo tc (cons res (- ref^ ref))) + (values res ref^)))))) + +;; Count references so that the tcomp AST nodes that refer to the same memory +;; location i.e. common AST nodes get extracted by let-binding them in the +;; compiled output racket code. +(define count-references + (λ (t) + (let-values (((counter uid) (cr-tpromise t (hasheq) 0))) + counter))) + +;; TODO: Try using the signature field of tpromise struct as keys instead tcomp +;; references. The naive way to do this might be inefficient because of the +;; constant conversion between the tree representation of the signature and the +;; numeric hash signature. +(define cr-tpromise + (λ (t counter uid) + (match t + ((tpromise tc _ _ _) + (cr-tcomp tc counter uid))))) + +(define cr-tcomp + (λ (tc counter uid) + (cond + ((number? 
tc) (values counter uid)) + (else + (match-let (((counter-data tc-binding-name tc-ref-count) + (hash-ref counter tc + (λ () + (let-values (((st _) (struct-info tc))) + (let-values (((tcomp-name _0 _1 _2 _3 _4 _5 _6) + (struct-type-info st))) + (counter-data (string->symbol + (format "~a_~a" tcomp-name uid)) + 0))))))) + (let* ((new-count (add1 tc-ref-count)) + (counter^ (hash-set counter tc + (counter-data tc-binding-name + new-count))) + (uid^ (add1 uid))) + (cond + ((> new-count 1) + ;; No need to increase reference count of children if parent occurs + ;; more than once. This helps avoid creating extra tcom-var later in + ;; ecs if child tcomp occurs only once within the parent tcomp, but + ;; the parent tcomp itself occurs more than once. + (values counter^ uid^)) + (else + (match tc + [(tcomp-list->tensor lst) + (for/fold + ((counter^^ counter^) + (uid^^ uid^)) + ((l lst)) + (cond + ((tpromise? l) (cr-tpromise l counter^^ uid^^)) + ((number? l) (values counter^^ uid^^)) + (else (error 'cr-list->tensor "Unexpected: ~a" l))))] + [(tcomp-tref tp (and i (tcomp-ds-ref _))) + (cr-tpromise tp counter^ uid^)] + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) + (cr-tpromise tp counter^ uid^)] + [(tcomp-ext2-∇ fᵈ _ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) + ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) + (cr-tpromise tp-t1 counter-2 uid-2))] + [(tcomp-ext1-∇ tp zp f _ sign m shape-fn) + (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) + (cr-tpromise zp counter-1 uid-1))] + [(tcomp-ext2-ρ-scalar f _ sign tp-t tp-u) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-ext2-ρ tp-t tp-u f _ sign m n shape-fn) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-ext1-ρ-scalar f _ sign tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-ext1-ρ f _ sign m 
shape-fn tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) + (cr-tpromise zp counter-1 uid-1))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) + ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) + (cr-tpromise tp-t1 counter-2 uid-2))] + [(tcomp-reshape s tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-ds-ref index) (values counter^ uid^)] + ;;need these cases for testing compiler invariant + [(tcomp-let lhs rhs body) + (let-values (((counter-1 uid-1) (cr-tpromise rhs counter^ uid^))) + (cr-tpromise body counter-1 uid-1))] + [(tcomp-var name) (values counter^ uid^)]))))))))) + +(define extract-common-subexpressions + (λ (t counter) + (let-values (((instrs bindings) + (run-compiler-ecs (ecs-tpromise t counter) '()))) + (for/fold ((body instrs) + #:result ;; set-box! the data-segment of result so that + ;; applying it to interp-tensor works + (begin + (set-box! (tpromise-dst body) (unbox (tpromise-dst t))) + (set-box! (tpromise-sign body) (unbox (tpromise-sign t))) + body)) + ((binding bindings)) + (tpromise (tcomp-let (car binding) + (tpromise (cdr binding) '() (box '()) (box '())) + body) + '() (box '()) (box '())))))) + +(define ecs-tpromise + (λ (tc counter) + (match tc + [(tpromise tc s dss sign) + (->ecs + (ecs-tcomp tc counter) + (λ (instrs) + (inj-ecs-val (tpromise instrs s dss sign))))]))) + +(define ecs-tcomp + (λ (tc counter) + (let ((tc-counter-data + (hash-ref counter tc + (λ () + (counter-data (gensym 'illegal) 0))))) + (match tc + [tc #:when (number? 
tc) + (inj-ecs-val tc)] + [(tcomp-list->tensor lst) + (let ((instrs-list-compiler + (for/foldr + ((list-compiler (inj-ecs-val '()))) + ((arg lst)) + (->ecs + (cond + ((tpromise? arg) (ecs-tpromise arg counter)) + ((number? arg) (inj-ecs-val arg)) + (else (error 'ecs-list->tensor "Unexpected: ~a" arg))) + (λ (instrs) + (->ecs + list-compiler + (λ (instrs-list) + (inj-ecs-val (cons instrs instrs-list))))))))) + (->ecs + instrs-list-compiler + (λ (instrs-list) + (inj-ecs-tcomp (tcomp-list->tensor instrs-list) tc-counter-data))))] + [(tcomp-tref tp (and i (tcomp-ds-ref _))) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-tref instrs i) tc-counter-data)))] + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-trefs instrs b) tc-counter-data)))] + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (->ecs + (ecs-tpromise tp-t0 counter) + (λ (t0-instrs) + (->ecs + (ecs-tpromise tp-t1 counter) + (λ (t1-instrs) + (->ecs + (ecs-tpromise tp-z counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn + t0-instrs t1-instrs z-instrs + out0 out1 i) + tc-counter-data)))))))] + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) + (->ecs + (ecs-tpromise tp counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise zp counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-ext1-∇ t-instrs z-instrs f f-acc sign m shape-fn) + tc-counter-data)))))] + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) + (->ecs + (ecs-tpromise tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-ext2-ρ-scalar f f-acc sign t-instrs u-instrs) + tc-counter-data)))))] + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) + (->ecs + (ecs-tpromise tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-ext2-ρ t-instrs u-instrs f f-acc sign m n shape-fn) + tc-counter-data)))))] + 
[(tcomp-ext1-ρ-scalar f f-acc sign tp) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f f-acc sign instrs) tc-counter-data)))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-ext1-ρ f f-acc sign m shape-fn instrs) tc-counter-data)))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-prim1-ρ f f-acc sign shape-fn instrs) tc-counter-data)))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (->ecs + (ecs-tpromise tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-prim2-ρ f f-acc sign shape-fn t-instrs u-instrs) + tc-counter-data)))))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (->ecs + (ecs-tpromise tp counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise zp counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-prim1-∇ f f-acc sign shape-fn t-instrs z-instrs) + tc-counter-data)))))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (->ecs + (ecs-tpromise tp-t0 counter) + (λ (t0-instrs) + (->ecs + (ecs-tpromise tp-t1 counter) + (λ (t1-instrs) + (->ecs + (ecs-tpromise tp-z counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn + t0-instrs t1-instrs z-instrs + out0 out1 i) + tc-counter-data)))))))] + [(tcomp-reshape s tp) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-reshape s instrs) tc-counter-data)))] + [(tcomp-ds-ref index) (inj-ecs-tcomp tc tc-counter-data)])))) + +(struct CompilerECS (run-compiler) #:transparent) + +(define run-compiler-ecs + (λ (c bindings) + ((CompilerECS-run-compiler c) bindings))) + +(define inj-ecs-val + (λ (v) + (CompilerECS (λ (bindings) (values v bindings))))) + +(define inj-ecs-tcomp + (λ (instrs cd) + (match-let (((counter-data binding-var ref-count) cd)) + (CompilerECS (λ (bindings) + (cond + ((<= ref-count 1) + 
(values instrs bindings)) + ((assv binding-var bindings) + (values (tcomp-var binding-var) bindings)) + (else + (values (tcomp-var binding-var) + (extend-bindings binding-var instrs bindings))))))))) + +(define ->ecs + (λ (c f) + (CompilerECS + (λ (bindings) + (let-values (((instrs bindings^) (run-compiler-ecs c bindings))) + (run-compiler-ecs (f instrs) bindings^)))))) + +(define extend-env + (λ (k v env) + `((,k . ,v) . ,env))) + +(define extend-bindings extend-env) + +(define exists-in-env? + (λ (ft env) + (match env + ('() #f) + (`((,k . ,v) . ,_) #:when (eq? ft v) k) + (`(,_ . ,rest-env) (exists-in-env? ft rest-env))))) + +(define generate-racket + (λ (t) + (gr-tpromise t))) + +(define gr-tpromise + (λ (t) + (match t + [(tpromise tc _ _ _) (gr-tcomp tc)]))) + +(define gr-tcomp + (λ (tc) + (match tc + [v #:when (number? v) v] + [(tcomp-list->tensor lst) + (let ((instrs-list + (map (λ (t) + (cond + ((tpromise? t) (gr-tpromise t)) + ((number? t) t) + (else (error 'gr-list->tensor "Unexpected: ~a" t)))) + lst))) + `(acc:list->tensor (list ,@instrs-list)))] + [(tcomp-tref tp (and i (tcomp-ds-ref _))) + (let ((instrs (gr-tpromise tp)) + (i-instrs (gr-tcomp i))) + `(acc:tref ,instrs ,i-instrs))] + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) + (let ((instrs (gr-tpromise tp)) + (b-instrs (gr-tcomp b))) + `(rt:trefs ,instrs ,b-instrs))] + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (gr-tpromise tp-t0)) + (t1-instrs (gr-tpromise tp-t1)) + (z-instrs (gr-tpromise tp-z)) + (out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (let ((index (if (zero? i) out-idx0 out-idx1))) + `(let* ([index ,index] + [v (data-segment-ref index)]) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer! 
,fᵈ ,fᵈ-acc ,sign ,r0 ,r1 ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out-idx0 ,out-idx1) + (data-segment-ref index)) + (else v)))))] + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) + (let ((t-instrs (gr-tpromise tp)) + (z-instrs (gr-tpromise zp))) + `(scalarize + (flat-ext1-∇ ,f ,f-acc ,m ,shape-fn ,sign + (ensure-flat ,t-instrs) + (ensure-flat ,z-instrs))))] + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) + `(,f ,t-instrs ,u-instrs))] + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) + `(scalarize + (flat-ext2-ρ ,f ,f-acc ,m ,n ,shape-fn ,sign + (ensure-flat ,t-instrs) + (ensure-flat ,u-instrs))))] + [(tcomp-ext1-ρ-scalar f f-acc sign tp) + (let ((instrs (gr-tpromise tp))) + `(,f ,instrs))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) + (let ((instrs (gr-tpromise tp))) + `(scalarize + (flat-ext1-ρ ,f ,f-acc ,m ,shape-fn ,sign + (ensure-flat ,instrs))))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (let ((instrs (gr-tpromise tp))) + `(apply-flat-ρ-fn-1 ,f ,instrs ,shape-fn))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) + `(apply-flat-ρ-fn-2 ,f ,t-instrs ,u-instrs ,shape-fn))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let ((t-instrs (gr-tpromise tp)) + (z-instrs (gr-tpromise zp))) + `(apply-flat-∇-fn-1 ,f ,t-instrs (scalarize ,z-instrs) ,shape-fn))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (gr-tpromise tp-t0)) + (t1-instrs (gr-tpromise tp-t1)) + (z-instrs (gr-tpromise tp-z)) + (out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (let ((index (if (zero? i) out-idx0 out-idx1))) + `(let* ([index ,index] + [v (data-segment-ref index)]) + (cond + ((eqv? v 'uncalculated) + (prim2-∇-forcer! 
,fᵈ ,fᵈ-acc ,f-sign ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out-idx0 ,out-idx1) + (data-segment-ref index)) + (else v)))))] + [(tcomp-reshape s tp) + (let ((instrs (gr-tpromise tp))) + `(flat ',s + (flat-store ,instrs) + (flat-offset ,instrs)))] + [(tcomp-let lhs rhs body) + (let ((rhs-instrs (gr-tpromise rhs)) + (body-instrs (gr-tpromise body))) + `(let ((,lhs ,rhs-instrs)) + ,body-instrs))] + [(tcomp-var name) name] + [(tcomp-ds-ref index) `(data-segment-ref ,index)]))) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Composing Compiler Passes +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(define print-compiler? (make-parameter #f)) +(define display-compiler-trace + (λ (title value) + (when (or (equal? (print-compiler?) #t) + (and (list? (print-compiler?)) (member title (print-compiler?)))) + (printf "~a:~n" title) + (displayln "--------------") + (pretty-print value) + (displayln "")))) +(define cache + (make-parameter (make-hash))) +(define compile-tensor + (λ (t) + (display-compiler-trace 'Source-Tensor t) + (let ((ds (dst->data-segment (tpromise-dst t)))) + (display-compiler-trace 'Data-Segment ds) + (let ((signature (sign (tpromise-sign t)))) + (display-compiler-trace 'Signature signature) + (cond + ((hash-has-key? (cache) signature) + (let ((compiled (hash-ref (cache) signature))) + (display-compiler-trace 'Cache-Hit signature) + (values compiled ds))) + (else + (let ((instrs-dsr (generate-ds-refs t))) + (display-compiler-trace 'Generate-DS-Refs instrs-dsr) + (let ((counter (count-references instrs-dsr))) + (display-compiler-trace 'Count-References counter) + (let ((extracted (extract-common-subexpressions instrs-dsr counter))) + (display-compiler-trace 'Extract-Common-Subexpressions extracted) + (let* ((gr (generate-racket extracted)) + (rkt (compile-racket gr))) + (display-compiler-trace 'Generate-Racket gr) + (hash-set! 
(cache) signature rkt) + (values rkt ds))))))))))) + +(define compile-racket + (λ (r) + (parameterize ([current-namespace runtime]) + (compile-syntax (expand r))))) + +(define get-compiled + (λ (t) + (let-values (((instrs env) + (compile-tensor t))) + `(parameterize ((data-segment ,env)) + ,instrs)))) + +(include "test/test-c3-compiler.rkt") +(provide get-compiled compile-tensor print-compiler? + (rename-out (cache compiler-cache))) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt new file mode 100644 index 0000000..5592dbe --- /dev/null +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -0,0 +1,7 @@ +(module+ test + (require rackunit) + + (define test-nested-tensor (tensor (tensor 1 2 3) (tensor 4 5 6) (tensor 7 8 9))) + (check-true (bounded-idx*? test-nested-tensor (list 0 1))) + (check-false (bounded-idx*? test-nested-tensor (list 1 3))) + (check-false (bounded-idx*? test-nested-tensor (list 1 1 0)))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt new file mode 100644 index 0000000..e266e65 --- /dev/null +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require ffi/vector) + (require "0-lazy.rkt") + (require "B-test-programs.rkt") + + (define evaluated-tpromise? + (λ (tp) + (or (tpromise-flat? tp) + (number? (tpromise-tensor tp))))) + + (for (((test-name test-data) (in-hash test-programs))) + (match-define (test-program-data th res) test-data) + (match res + ((eval-res-1 res) + (let* ((tp (th)) + (forced (↓ tp))) + (acc:check-tensor-equal? + forced res + (format "Expected result doesn't match in test case ~a" + test-name)) + (check-pred evaluated-tpromise? tp) + (check-equal? (tpromise-shape tp) (acc:shape forced)))) + ((eval-res-2 res1 res2) + (let*-values (((tp1 tp2) (th)) + ((forced1) (↓ tp1)) + ((forced2) (↓ tp2))) + (acc:check-tensor-equal? 
+ forced1 res1 + (format "Expected first result doesn't match in test case ~a" + test-name)) + (check-pred evaluated-tpromise? tp1) + (check-equal? (tpromise-shape tp1) (acc:shape forced1)) + (acc:check-tensor-equal? + forced2 res2 + (format "Expected second result doesn't match in test case ~a" + test-name)) + (check-pred evaluated-tpromise? tp2) + (check-equal? (tpromise-shape tp2) (acc:shape forced2)))))) + + + (define test-tensor-r1-0 (get-test-program 'tensor-r1-0)) + (check-false (acc:flat? (tpromise-tensor test-tensor-r1-0))) + (check-true (acc:flat? (car (unbox (tpromise-dst test-tensor-r1-0))))) + (check-exn exn:fail? (λ () (tensor test-tensor-r1-0 4))) + (check-exn exn:fail? (λ () (tensor 4 test-tensor-r1-0))) + + (define test-tcomp-tref (get-test-program 'tcomp-tref)) + (check-exn exn:fail? (λ () (tref test-tensor-r1-0 5))) + + (define test-nested-tensor (get-test-program 'tensor-r2-0)) + (check-exn exn:fail? (λ () (tref (tref test-nested-tensor 2) 0))) + (check-exn exn:fail? (λ () (tref test-nested-tensor 2))) + (check-exn exn:fail? (λ () (tensor test-nested-tensor test-nested-tensor test-tensor-r1-0))) + + (check-equal? (tlen test-tensor-r1-0) 3) + (check-equal? (tlen test-nested-tensor) 2) + + (define test-nested-list->tensor + (get-test-program 'tcomp-nested-list->tensor)) + (check-equal? (tpromise-shape test-nested-list->tensor) '(3 3)) + + (define test-tcomp-partial-eval + (begin + (↓ test-nested-list->tensor) + (↓ test-nested-tensor) + (↓ test-tensor-r1-0) + (tref + (tref (tensor (tensor (tensor 1 2 3) (tensor 4 5 6) (tensor 7 8 9)) + test-nested-list->tensor + (list->tensor (list (tref test-nested-tensor 0) + (tref test-nested-tensor 1) + test-tensor-r1-0))) + 1) + 2))) + (acc:check-tensor-equal? (↓ test-tcomp-partial-eval) + (↓ (tensor 1 2 3))) + + (define test-id-scalar (get-test-program 'id-scalar)) + (define test-force-scalar + (+-ρ test-id-scalar + (get-test-program 'sum-nested))) + (void (↓ test-id-scalar)) + (acc:check-tensor-equal? 
(↓ test-force-scalar) + (↓ (tensor 19 21 20))) + + (define test-force-subexpr + (+-ρ (get-test-program 'id-scalar) + (get-test-program 'sum-nested))) + (define test-force-mutate + (+-ρ test-force-subexpr + (+-ρ (get-test-program 'sum-nested) + (get-test-program 'sum-nested)))) + (void (↓ test-force-subexpr)) + (acc:check-tensor-equal? (↓ test-force-mutate) + (↓ (tensor 27 33 30))) + + (define test-tp-r1 (tensor -1 -2 -3)) + (define test-force-supexpr (abs-ρ test-tp-r1)) + (void (↓ test-force-supexpr)) + (acc:check-tensor-equal? (↓ test-tp-r1) + (↓ (tensor -1 -2 -3))) + + (define test-trefs (get-test-program 'tcomp-trefs)) + (check-true (tcomp? (tpromise-tensor test-trefs))) + (check-exn exn:fail? (λ () (trefs test-nested-tensor '(0 4)))) + + (define test-reshape (get-test-program 'tcomp-reshape)) + (check-exn exn:fail? (λ () (reshape '(4 5) test-reshape))) + + (check-pred + (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) + (f32vector->list (acc:flat-store (↓ test-build-random))) + "Side-effect of generating random tensor must only be run once") + + (acc:check-tensor-equal? 
(↓ (get-test-program 'multi-built-tensor)) + (eval-res-1-res (get-test-eval-res 'multi-built-tensor))) +) diff --git a/lazy/tensors/test/test-A-equality.rkt b/lazy/tensors/test/test-A-equality.rkt new file mode 100644 index 0000000..7d5e04d --- /dev/null +++ b/lazy/tensors/test/test-A-equality.rkt @@ -0,0 +1,50 @@ +(module+ test + (require rackunit) + (require "0-lazy.rkt") + + (define t0 + (reshape '(2 3 4) + (build-tensor + '(24) + (λ (i) + (* 2.0 (car i)))))) + + + (define t1 + (reshape '(2 3 4) + (build-tensor + '(24) + (λ (i) + (* 2.000001 (car i)))))) + + (define t2 + (reshape '(1 2 3 4) + (build-tensor + '(24) + (λ (i) + (* 2.000001 (car i)))))) + + (define t3 + (reshape '(2 2 3 4) + (build-tensor + '(48) + (λ (i) + (* (quotient (car i) 24) (car i)))))) + + (define t4 + (reshape '(2 2 3 4) + (build-tensor + '(48) + (λ (i) + (- (* 2.000001 (* (quotient (car i) 24) (car i))) 48.0))))) + + (check-true (tp-tensor-equal? t0 t1)) + + (check-false (tp-tensor-equal? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (tp-tensor-equal? t0 (reshape '(2 3 4) + t2))) + + (tp-check-tensor-equal? t0 t1) + + (tp-check-tensor-equal? t0 (reshape '(2 3 4) t2))) diff --git a/lazy/tensors/test/test-c2-interpreter.rkt b/lazy/tensors/test/test-c2-interpreter.rkt new file mode 100644 index 0000000..be70371 --- /dev/null +++ b/lazy/tensors/test/test-c2-interpreter.rkt @@ -0,0 +1,32 @@ +(module+ test + (require rackunit) + (require "B-test-programs.rkt") + (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) + + (for (((test-name test-data) (in-hash test-programs))) + (match-define (test-program-data th res) test-data) + (match res + ((eval-res-1 res) + (let* ((tp (th)) + (interped (interp-tensor tp))) + (acc:check-tensor-equal? + interped res + (format "Expected result doesn't match in test case ~a" + test-name)) + (check-equal? 
(tpromise-shape tp) (acc:shape interped)))) + ((eval-res-2 res1 res2) + (let*-values (((tp1 tp2) (th)) + ((interped1) (interp-tensor tp1)) + ((interped2) (interp-tensor tp2))) + (acc:check-tensor-equal? + interped1 res1 + (format "Expected first result doesn't match in test case ~a" + test-name)) + (check-equal? (tpromise-shape tp1) (acc:shape interped1)) + (acc:check-tensor-equal? + interped2 res2 + (format "Expected second result doesn't match in test case ~a" + test-name)) + (check-equal? (tpromise-shape tp2) (acc:shape interped2)))))) + +) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt new file mode 100644 index 0000000..6bae788 --- /dev/null +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -0,0 +1,213 @@ +(module+ test + (require rackunit) + (require "B-test-programs.rkt") + (require "0-lazy.rkt") + (require "c2-interpreter.rkt") + (require (prefix-in acc: (only-in "../../accelerated-tensors/autodiff.rkt" + make-printable))) + (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) + + (define current-test-program-name (make-parameter #f)) + (define-check (check-compiler-invariants tp) + (define ds (dst->data-segment (tpromise-dst tp))) + (define signature (sign (tpromise-sign tp))) + (define interp-tp (interp-tensor tp)) + (with-check-info + (('data-segment ds) + ('signature signature) + ('input-computation (tpromise-tensor tp)) + ('expected-interpretation (acc:make-printable interp-tp)) + ('test-name (current-test-program-name))) + (for ((d ds)) + (unless (or (number? d) + (acc:flat? d) + (eqv? d 'uncalculated)) + (fail-check (format (string-append "Data segment should only contain flat tensors " + ", the symbol 'uncalculated or numbers." + " Found: ~a") + d)))) + (parameterize ((cache (make-hash))) + (let* ((instrs-dsr (generate-ds-refs tp)) + (interp-dsr (interp-tensor instrs-dsr))) + (unless (acc:tensor-equal? 
interp-dsr interp-tp) + (fail-check (format + (string-append + "Result of interpreting pass generate-ds-ref doesn't" + " match expected interpretation. Actual " + "interpretation: ~a~n") + interp-dsr))) + (let ((counter (count-references instrs-dsr))) + (let* ((extracted (extract-common-subexpressions instrs-dsr counter)) + (interp-extracted (interp-tensor extracted))) + (unless (acc:tensor-equal? interp-extracted interp-tp) + (fail-check (format + (string-append + "Result of interpreting pass" + " extract-common-subexpression doesn't" + " match expected interpretation. Actual " + "interpretation: ~a~n") + (acc:make-printable interp-extracted)))) + (let* ((gr (generate-racket extracted)) + (rkt (compile-racket gr)) + (interp-rkt (interp-racket rkt ds))) + (unless (acc:tensor-equal? interp-rkt interp-tp) + (fail-check (format + (string-append + "Result of interpreting compiled racket code doesn't" + " match expected interpretation. Actual " + "interpretation: ~a~n") + (acc:make-printable interp-rkt)))) + (hash-set! (cache) signature rkt) + (compile-tensor tp) + (unless (eqv? (hash-count (cache)) 1) + (fail-check (format + (string-append + "Compiling the same tpromise again shouldn't" + " change the number of entries in the cache." + " Number of cache entries: ~a~n") + (hash-count (cache)))))))))))) + + (for (((test-name test-data) (in-hash test-programs))) + (match-define (test-program-data th res) test-data) + (parameterize ((current-test-program-name test-name)) + (match res + ((eval-res-1 res) + (let* ((tp (th))) + (check-compiler-invariants tp))) + ((eval-res-2 res1 res2) + (let*-values (((tp1 tp2) (th))) + (check-compiler-invariants tp1) + (check-compiler-invariants tp2)))))) + + (define-check (check-signatures-equal? t1 t2) + (let ((sig1 (tpromise-sign t1)) + (sig2 (tpromise-sign t2))) + (with-check-info + (('signature-1 sig1) + ('signature-2 sig2)) + (unless (equal? 
sig1 sig2) + (fail-check "signature mismatch"))))) + + (define-check (check-signatures-not-equal? t1 t2) + (let ((sig1 (tpromise-sign t1)) + (sig2 (tpromise-sign t2))) + (with-check-info + (('signature-1 sig1) + ('signature-2 sig2)) + (when (equal? sig1 sig2) + (fail-check "signatures musn't match"))))) + + (define test-tensor-r1-1 (get-test-program 'tensor-r1-1)) + (define test-tcomp-tref (get-test-program 'tcomp-tref)) + (check-signatures-equal? test-tcomp-tref + (make-tref-test-program test-tensor-r1-1)) + (check-signatures-not-equal? test-tcomp-tref + (make-list->tensor-test-program `(,test-tensor-r1-1))) + + (define tensor-r1 (get-test-program 'tensor-r1-0)) + (check-signatures-equal? (*-ρ 2 tensor-r1) (*-ρ 3 tensor-r1)) + + (define v^ (random-tensor (list 10 4))) + (define r^ (random-tensor (list 10 4 2))) + (check-signatures-equal? mean-v (mean v^)) + (check-signatures-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (tensor (tensor 12 23 44) + (tensor 23 46 57)))) + (check-signatures-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (tensor (tensor 12 23 44) + (tensor 23 46 57) + (tensor 67 32 58)))) + (check-signatures-not-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (reshape '(2 3) (tensor 1 2 3 4 5 6)))) + (check-signatures-equal? variance-v (variance v^)) + (check-signatures-equal? mean-r (mean r^)) + (check-signatures-equal? variance-r (variance r^)) + (check-signatures-not-equal? mean-v mean-r) + (check-signatures-not-equal? mean-v variance-v) + (check-signatures-not-equal? variance-v mean-r) + (check-signatures-equal? (+-ρ mean-v (tensor 0 1 2 3 4 5 6 7 8 9)) + (+-ρ (mean v^) (tensor 0 1 2 3 4 5 6 7 8 9))) + + (let ((a 2) + (b 3)) + (let*-values (((da- db-) (d- a b 1.0)) + ((da+ db+) (d+ a b 1.0))) + (check-signatures-not-equal? da- da+) + (check-signatures-not-equal? db- db+))) + + (let-values (((rkt ds) (compile-tensor (get-test-program 'extract-ds-once-tref)))) + (check-pred + (λ (ds) + (eqv? 
(set-count (list->seteq (vector->list ds))) 2)) + ds + (string-append "eq? equivalent flat tensors and tref indices" + " used to construct the source AST must" + " be eq? equivalent in the data segment as well."))) + (let-values (((rkt ds) (compile-tensor (get-test-program 'extract-ds-once-trefs)))) + (check-pred + (λ (ds) + (eqv? (set-count (list->seteq (vector->list ds))) 2)) + ds + (string-append "eq? equivalent flat tensors and trefs index lists" + " used to construct the source AST must" + " be eq? equivalent in the data segment as well."))) + + (define count-tcomp-var + (λ (tp) + (ctv-tcomp (tpromise-tensor tp)))) + + (define ctv-tcomp + (λ (tc) + (match tc + ((? number?) 0) + [(tcomp-list->tensor lst) + (for/sum + ((l lst)) + (cond + ((tpromise? l) (count-tcomp-var l)) + ((number? l) 0) + (else (error 'cdsr-list->tensor "Unexpected: ~a" l))))] + [(tcomp-tref tp _) (count-tcomp-var tp)] + [(tcomp-trefs tp _) (count-tcomp-var tp)] + [(tcomp-ext2-∇ fᵈ _ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + out-ref0 + out-ref1 i) + (let ((c0 (count-tcomp-var tp-t0)) + (c1 (count-tcomp-var tp-t1)) + (cz (count-tcomp-var tp-z))) + (+ c0 c1 cz))] + [(tcomp-ext1-∇ tp zp f _ sign m shape-fn) + (let ((ct (count-tcomp-var tp)) + (cz (count-tcomp-var zp))) + (+ ct cz))] + [(tcomp-ext2-ρ-scalar f _ sign tp-t tp-u) + (let ((ct (count-tcomp-var tp-t)) + (cu (count-tcomp-var tp-u))) + (+ ct cu))] + [(tcomp-ext2-ρ tp-t tp-u f _ sign m n shape-fn) + (let ((ct (count-tcomp-var tp-t)) + (cu (count-tcomp-var tp-u))) + (+ ct cu))] + [(tcomp-ext1-ρ-scalar f _ sign tp) (count-tcomp-var tp)] + [(tcomp-ext1-ρ f _ sign m shape-fn tp) (count-tcomp-var tp)] + [(tcomp-reshape s tp) (count-tcomp-var tp)] + [(tcomp-ds-ref i) 0] + [(tcomp-let lhs rhs body) + (let ((cr (count-tcomp-var rhs)) + (cb (count-tcomp-var body))) + (+ cr cb))] + [(tcomp-var name) 1]))) + + (define get-common-subexprs + (λ (tp) + (let ((instrs (generate-ds-refs tp))) + (extract-common-subexpressions instrs (count-references 
instrs))))) + + (check-equal? + (count-tcomp-var (get-common-subexprs (get-test-program 'common-subexpression))) + 2) + (check-equal? + (count-tcomp-var + (get-common-subexprs (get-test-program 'nested-common-subexpression))) + 2) +) diff --git a/learner.rkt b/learner.rkt index 0068aa4..e49333d 100644 --- a/learner.rkt +++ b/learner.rkt @@ -13,6 +13,8 @@ (error "ext2-∇ is not provided by the learner implementation"))) (provide + tolerance + len ref refr tref tlen tmap list->tensor tensor build-tensor @@ -39,7 +41,7 @@ rectify flatten concat concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/learner/ext-ops.rkt b/learner/ext-ops.rkt index 5db1276..d1c8413 100644 --- a/learner/ext-ops.rkt +++ b/learner/ext-ops.rkt @@ -21,7 +21,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - sqrt-ρ sqr-ρ) + sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/learner/ext-ops/J-nd-ops.rkt b/learner/ext-ops/J-nd-ops.rkt index 107e896..fff3489 100644 --- a/learner/ext-ops/J-nd-ops.rkt +++ b/learner/ext-ops/J-nd-ops.rkt @@ -35,9 +35,12 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1 (λ (_) 0.0) 0)) + (provide +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (define *-2-1-ρ (ext2 *-ρ 2 1)) diff --git a/learner/no-duals-no-overrides.rkt b/learner/no-duals-no-overrides.rkt index 5b40500..ee52a1b 100644 --- a/learner/no-duals-no-overrides.rkt +++ b/learner/no-duals-no-overrides.rkt @@ -18,7 +18,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ diff --git a/learner/no-duals.rkt b/learner/no-duals.rkt index b20d585..361dd22 100644 --- a/learner/no-duals.rkt +++ b/learner/no-duals.rkt 
@@ -21,7 +21,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/learner/no-overrides.rkt b/learner/no-overrides.rkt index d2d521f..13640e5 100644 --- a/learner/no-overrides.rkt +++ b/learner/no-overrides.rkt @@ -33,7 +33,7 @@ (rename-out (concat-n d-concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/malted/A-core.rkt b/malted/A-core.rkt index 6933b20..4139f6b 100644 --- a/malted/A-core.rkt +++ b/malted/A-core.rkt @@ -1,6 +1,7 @@ #lang racket (require "../base.rkt") +(require (only-in "../lazy/tensors.rkt" ↓)) (define dot-product (λ (w t) @@ -14,4 +15,4 @@ (include "test/test-A-core.rkt") -(provide dot-product dot-product-2-1) +(provide dot-product dot-product-2-1 ↓) diff --git a/malted/D-gradient-descent.rkt b/malted/D-gradient-descent.rkt index 33a3d28..259fe3d 100644 --- a/malted/D-gradient-descent.rkt +++ b/malted/D-gradient-descent.rkt @@ -12,12 +12,30 @@ (declare-hyper revs) (declare-hyper alpha) +;;TODO: abstract away the lazy implementation specific ↓ using a with-aspect + +;; For lazy implementation +#; +(define-syntax with-aspect + (syntax-rules () + [(_ 'gd-update f) + (lambda (pa g) (map* ↓ (f pa g)))] + [(_ _ f) f])) + +;; For other implementations +#; +(define-syntax with-aspect + (syntax-rules () + [(_ _ f) + f])) + (define gradient-descent (lambda (inflate deflate update) (λ (obj theta) (let ((ctr 0)) (let ((f (λ (big-theta) - (map update + (map #;(with-aspect 'gd-update update) + (lambda (pa g) (map* ↓ (update pa g))) big-theta 
(gradient-of obj (map deflate big-theta)))))) diff --git a/malted/E-gd-common.rkt b/malted/E-gd-common.rkt index 35a9549..84a75f3 100644 --- a/malted/E-gd-common.rkt +++ b/malted/E-gd-common.rkt @@ -3,9 +3,6 @@ ;; Extended operators are non-dualized (require "../base-no-duals.rkt") -(define zeroes - (ext1-ρ (λ (_) 0.0) 0)) - (define smooth (λ (decay-rate average g) (+ (* decay-rate average) diff --git a/malted/I-adam.rkt b/malted/I-adam.rkt index 818e300..bb72940 100644 --- a/malted/I-adam.rkt +++ b/malted/I-adam.rkt @@ -19,7 +19,7 @@ (let ((r (smooth beta (ref pa 2) (sqr g)))) (let ((alpha-hat (/ alpha (+ (sqrt r) epsilon))) (v (smooth mu (ref pa 1) g))) - (list (- (ref pa 0) (* alpha-hat v)) v r))))) + (list (- (ref pa 0) (* alpha-hat v)) v r))))) (define adam-gradient-descent (gradient-descent diff --git a/malted/test/test-A-core.rkt b/malted/test/test-A-core.rkt index 79d363d..596ebde 100644 --- a/malted/test/test-A-core.rkt +++ b/malted/test/test-A-core.rkt @@ -1,5 +1,6 @@ (module+ test (require rackunit) + (require "../base.rkt") (let ((a 7) (b 13)) @@ -39,9 +40,11 @@ (tensor -0.04142 -0.03111)))) (let ((a (tensor 7 8 9))) - (check-dual-equal? (exp a) (tensor 1096.6331 2980.9579 8103.0839)) - (check-dual-equal? ((∇¹ exp) a) - (list (tensor 1096.6331 2980.9579 8103.0839))) + ;; Lower tolerance because the openCL implementation of exp gives a slightly different answer + (parameterize ((tolerance 0.001)) + (check-dual-equal? (exp a) (tensor 1096.6332 2980.9579 8103.0839)) + (check-dual-equal? ((∇¹ exp) a) + (list (tensor 1096.6332 2980.9579 8103.0839)))) (check-dual-equal? (log a) (tensor 1.9459 2.0794 2.1972)) (check-dual-equal? ((∇¹ log) a) (list (tensor 0.1428 0.125 0.1111))) (check-dual-equal? 
(sqrt a) (tensor 2.6457 2.8284 3.0)) diff --git a/malted/test/test-D-gradient-descent.rkt b/malted/test/test-D-gradient-descent.rkt index b21193e..554cb57 100644 --- a/malted/test/test-D-gradient-descent.rkt +++ b/malted/test/test-D-gradient-descent.rkt @@ -12,8 +12,9 @@ (define naked-gd (gradient-descent id id (λ (theta g) (--ρ theta (*-ρ g alpha))))) - (check-equal? + (check-within (with-hypers ((revs 500) (alpha 0.01)) (naked-gd obj (list 3.0))) - '(29.998892352401082))) + '(29.998892352401082) + (tolerance))) diff --git a/malted/test/test-E-gd-common.rkt b/malted/test/test-E-gd-common.rkt index 092a062..f105a4b 100644 --- a/malted/test/test-E-gd-common.rkt +++ b/malted/test/test-E-gd-common.rkt @@ -1,12 +1,13 @@ (module+ test (require rackunit) + (require "../impl.rkt") - (check-equal? (zeroes (tensor 1 2 3)) + (check-dual-equal? (zeroes (tensor 1 2 3)) (tensor 0.0 0.0 0.0)) - (check-equal? (smooth 0.9 31 -8) 27.1) + (check-dual-equal? (smooth 0.9 31 -8) 27.1) - (check-equal? (smooth 0.9 27.1 4) 24.79) + (check-dual-equal? (smooth 0.9 27.1 4) 24.79) (with-hypers ((mu 0.5) (beta 0.3)) - (check-equal? (+ mu beta) 0.8))) + (check-dual-equal? (+ mu beta) 0.8))) diff --git a/malted/test/test-F-naked.rkt b/malted/test/test-F-naked.rkt index 2f0ebf7..d59ad59 100644 --- a/malted/test/test-F-naked.rkt +++ b/malted/test/test-F-naked.rkt @@ -5,8 +5,9 @@ (define obj (λ (theta) (sqr (- 30 (ref theta 0))))) - (check-equal? + (check-within (with-hypers ((revs 400) (alpha 0.01)) (naked-gradient-descent obj (list 3.0))) - '(29.991647931623252))) + '(29.991647931623252) + (tolerance))) diff --git a/malted/test/test-G-velocity.rkt b/malted/test/test-G-velocity.rkt index 8502003..154de9d 100644 --- a/malted/test/test-G-velocity.rkt +++ b/malted/test/test-G-velocity.rkt @@ -5,9 +5,10 @@ (define obj (λ (theta) (sqr (- 30 (ref theta 0))))) - (check-dual-equal? 
+ (check-within (with-hypers ((revs 70) (alpha 0.01) (mu 0.9)) (velocity-gradient-descent obj (list 3.0))) - '(30.686162582787535))) + '(30.686162582787535) + (tolerance))) diff --git a/malted/test/test-H-rms.rkt b/malted/test/test-H-rms.rkt index 750352b..cb7e84b 100644 --- a/malted/test/test-H-rms.rkt +++ b/malted/test/test-H-rms.rkt @@ -5,10 +5,10 @@ (define obj (λ (theta) (sqr (- 30 (ref theta 0))))) - (check-dual-equal? + (check-within (with-hypers ((revs 170) (alpha 0.1) (beta 0.999)) (rms-gradient-descent obj (list 3.0))) - '(29.990436450964964)) - ) + '(29.990436450964964) + (tolerance))) diff --git a/malted/test/test-O-init.rkt b/malted/test/test-O-init.rkt index 5b142c0..0316e13 100644 --- a/malted/test/test-O-init.rkt +++ b/malted/test/test-O-init.rkt @@ -1,17 +1,24 @@ (module+ test (require rackunit) + (require (only-in "../base.rkt" ρ)) + (define v (init-shape (list 1000 4))) - (define mean-v (abs (/ (sum (sum v)) 4000))) - (define variance-v (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v))) - (check-true (< mean-v 0.05)) - (check-true (and (>= variance-v 0.4) - (<= variance-v 0.6))) + (define mean-v + (abs (/ (sum (sum v)) 4000))) + (define variance-v + (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v))) + (check-true (< (ρ mean-v) 0.05)) + (check-true (let ((forced (ρ variance-v))) + (and (>= forced 0.4) + (<= forced 0.6)))) ;; Here variance will be 2/8 = 0.25 (define r (init-shape (list 1000 4 2))) (define mean-r (abs (/ (sum (sum (sum r))) 8000))) - (define variance-r (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r))) + (define variance-r (- (/ (sum (sum (sum (* r r)))) 8000) + (* mean-r mean-r))) - (check-true (< mean-r 0.05)) - (check-true (and (>= variance-r 0.22) - (<= variance-r 0.28)))) + (check-true (< (ρ mean-r) 0.05)) + (check-true (let ((forced (ρ variance-r))) + (and (>= forced 0.22) + (<= forced 0.28))))) diff --git a/nested-tensors.rkt b/nested-tensors.rkt index 277061c..b1db486 100644 --- a/nested-tensors.rkt +++ 
b/nested-tensors.rkt @@ -8,6 +8,8 @@ (require "nested-tensors/ext-ops.rkt") (provide + tolerance + len ref refr tref tlen tmap list->tensor tensor build-tensor @@ -30,7 +32,7 @@ (d-flatten flatten) (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ diff --git a/nested-tensors/ext-ops.rkt b/nested-tensors/ext-ops.rkt index 41428e1..473f37f 100644 --- a/nested-tensors/ext-ops.rkt +++ b/nested-tensors/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/nested-tensors/ext-ops/A-scalar-ops.rkt b/nested-tensors/ext-ops/A-scalar-ops.rkt index 3e98bd1..bcd1596 100644 --- a/nested-tensors/ext-ops/A-scalar-ops.rkt +++ b/nested-tensors/ext-ops/A-scalar-ops.rkt @@ -119,6 +119,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + (include "test/test-A-scalar-ops.rkt") (provide d+ d- d* d/ @@ -130,4 +133,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/nested-tensors/no-duals-no-overrides.rkt b/nested-tensors/no-duals-no-overrides.rkt index c3bbad0..ac909a5 100644 --- a/nested-tensors/no-duals-no-overrides.rkt +++ b/nested-tensors/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ diff --git a/nested-tensors/no-duals.rkt b/nested-tensors/no-duals.rkt index baa635a..d9dcf7d 100644 --- a/nested-tensors/no-duals.rkt +++ b/nested-tensors/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ 
expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) diff --git a/nested-tensors/no-overrides.rkt b/nested-tensors/no-overrides.rkt index 3d39c8b..417a96e 100644 --- a/nested-tensors/no-overrides.rkt +++ b/nested-tensors/no-overrides.rkt @@ -35,7 +35,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ diff --git a/set-impl.rkt b/set-impl.rkt index 20f67b5..469e7b5 100644 --- a/set-impl.rkt +++ b/set-impl.rkt @@ -7,7 +7,12 @@ (define set-impl (λ (impl) - (when (not (member impl '(learner nested-tensors flat-tensors))) + (when (not (member impl '(learner + nested-tensors + flat-tensors + uniform-tensors + accelerated-tensors + lazy))) (error "Unknown implementation: ~a~%" impl)) (setup #:collections (list (list "malt")) #:clean? #t) (write-implementation-to-config-file impl) diff --git a/tools/C-logging.rkt b/tools/C-logging.rkt index 2150a6c..33780ec 100644 --- a/tools/C-logging.rkt +++ b/tools/C-logging.rkt @@ -50,7 +50,9 @@ ((eq? data 'reset) (loop 0.0 0.0 0.0 0.0 0.0 0 sampling-frequency)) ((and data (= sampling-count 1)) - (print-average (/ (sum-all d0 d1 d2 d3 d4 data) (* 6 (ρ (product (shape data))))) count) + (print-average (/ (ρ (sum-all d0 d1 d2 d3 d4 data)) + (* 6 (ρ (product (shape data))))) + count) (loop d1 d2 d3 d4 data (add1 count) sampling-frequency)) (data (loop d1 d2 d3 d4 data (add1 count) (sub1 sampling-count))) diff --git a/uniform-tensors.rkt b/uniform-tensors.rkt new file mode 100644 index 0000000..461c9f4 --- /dev/null +++ b/uniform-tensors.rkt @@ -0,0 +1,47 @@ +#lang racket/base + +(require + (except-in "uniform-tensors/tensors.rkt" + rank shape reshape tref trefs tensor? 
tlen ref refr)) + +(require "uniform-tensors/autodiff.rkt") +(require "uniform-tensors/ext-ops.rkt") + +(provide + tolerance + + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + max-tensor-print-length make-printable + + (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) + (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) + (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) + (d-flatten flatten) + (d-concat concat) (d-concat-n concat-n)) + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + sum-1 argmax-1 max-1 flatten-2 concat-1-1 + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/autodiff.rkt b/uniform-tensors/autodiff.rkt new file mode 100644 index 0000000..68168c9 --- /dev/null +++ b/uniform-tensors/autodiff.rkt @@ -0,0 +1,23 @@ +#lang racket + +(require "autodiff/A-autodiff.rkt") +(require "autodiff/B-prims.rkt") +(require "autodiff/C-dualized-tensor-ops.rkt") +(require "autodiff/D-test-helpers.rkt") +(require "autodiff/E-print.rkt") + +(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual* map*) +(provide prim1 prim2 ext1 ext2) +(provide (rename-out (d-rank rank) + (d-shape shape) + (d-reshape reshape) + (d-tref tref) + (d-trefs trefs) + (d-tensor? tensor?) + (d-tlen tlen) + (d-ref ref) + (d-refr refr))) + +(provide check-dual-equal? 
check-ρ-∇) + +(provide max-tensor-print-length make-printable) diff --git a/uniform-tensors/autodiff/A-autodiff.rkt b/uniform-tensors/autodiff/A-autodiff.rkt new file mode 100644 index 0000000..03b9c93 --- /dev/null +++ b/uniform-tensors/autodiff/A-autodiff.rkt @@ -0,0 +1,120 @@ +#lang racket + +(require "../tensors.rkt") + +;;---------------------------- +;; Real part of a dual is always a tensor (of any rank) +;;---------------------------- + +(define dual? + (λ (x) + (and (vector? x) (eq? (vector-ref x 0) dual)))) + +(define dual + (λ (r k) + (vector dual r k))) + +(define dual* + (λ (d) + (dual (ρ d) end-of-chain))) + +(define ρ + (λ (d) + (cond + ((dual? d) (vector-ref d 1)) + (else d)))) + +(define κ + (λ (d) + (cond + ((dual? d) (vector-ref d 2)) + (else end-of-chain)))) + +(define scalar? + (λ (d) + (or (number? d) + (and (dual? d) + (number? (ρ d)))))) + +(define dual-like? + (λ (d) + (or (dual? d) + (number? d) + (vector? d)))) + +;;---------------------------- +;; Chain rule +;;---------------------------- + +(define end-of-chain + (λ (d z σ) + (let ((g (hash-ref σ d 0.0))) + (hash-set σ d (+-ρ z g))))) + +(define +-ρ + (ext2-ρ + 0 0)) + +;;---------------------------- +;; Reverse-mode AD +;;---------------------------- + +(define ∇ + (λ (f theta) + (let ((wrt (map* dual* theta))) + (∇-once (f wrt) wrt)))) + +(define ∇¹ + (λ (f) + (λ xs + (let ((wrt (map* dual* xs))) + (∇-once (apply f wrt) wrt))))) + +(define ∇-once + (λ (y wrt) + (let ((σ (∇σ y (hasheq)))) + (map* (λ (d) + (hash-ref σ d 0.0)) + wrt)))) + +(define ∇σ + (λ (y σ) + (cond + ((dual-like? y) ((κ y) y (one-like (ρ y)) σ)) + ((list? y) (∇σ-list y σ)) + (else (printf "Unknown: ~a~%" y))))) + +(define ∇σ-list + (λ (y σ) + (cond + ((null? y) σ) + (else + (let ((σ-hat (∇σ (ref y 0) σ))) + (∇σ-list (refr y 1) σ-hat)))))) + +;;---------------------------- +;; General helpers +;;---------------------------- + +(define map* + (λ (f y) + (cond + ((dual-like? y) (f y)) + ((list? 
y) + (map (λ (yi) + (map* f yi)) + y)) + (else y)))) + +(define trace-print + (λ (v port) + (cond + ((dual? v) (trace-print (ρ v) port)) + (else (fprintf port "~a~%" v))))) + +(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) + +(include "test/test-A-autodiff.rkt") + +(provide + dual dual? ρ κ ∇ ∇¹ dual* scalar? end-of-chain map* + trace-print) diff --git a/uniform-tensors/autodiff/B-prims.rkt b/uniform-tensors/autodiff/B-prims.rkt new file mode 100644 index 0000000..277e9ea --- /dev/null +++ b/uniform-tensors/autodiff/B-prims.rkt @@ -0,0 +1,221 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(define ρ-function + (λ (f) (f ρ-function))) + +(define ∇-function + (λ (f) (f ∇-function))) + +(define shape-fn + (λ (f) (f shape-fn))) + +;; For flat tensors, ρ-fn and ∇-fn +;; are of two types: functional and pre-allocated +;; When they are functional, they return values +;; When they are pre-allocated, they expect the +;; return flat-store to be pre-allocated, and simply +;; operate as fillers. +;; +;; Pre-allocated ρ and ∇ have arities +;; 6 and 7 for unary ops, and 9 and 11 for binary ops. +;; We test for this arity to determine the type. +;; +;; Generally speaking, scalar operations are functional +;; and vector operations are pre-allocated. +;; +;; The functions ensure-ρ-callable-1, ensure-∇-callable-1 +;; and ensure-ρ-callable-2, ensure-∇-callable-2 provide +;; the preallocation for flat-stores when a vector-op is +;; provided, but the invocation of prim1 expects functional +;; results. +;; + +(define prim1 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) + (∇-callable (ensure-∇-callable-1 ∇-fn shape))) + (λ (daf) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? 
daf shape-fn) shape) + (else (prim1-dual ρ-callable ∇-callable daf))))))) + +(define prim1-dual + (λ (ρ-fn ∇-fn da) + (let ((ra (ρ da))) + (dual (ρ-fn ra) + (λ (d z σ) + (let ((ga (∇-fn ra z))) + ((κ da) da ga σ))))))) + +(define prim2 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) + (∇-callable (ensure-∇-callable-2 ∇-fn shape))) + (λ ds + (let ((daf (ref ds 0))) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1))))))))) + +(define prim2-dual + (λ (ρ-fn ∇-fn da db) + (let ((ra (ρ da)) + (rb (ρ db))) + (dual (ρ-fn ra rb) + (λ (d z σ) + (let-values (((ga gb) (∇-fn ra rb z))) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat)))))))) + +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define ensure-ρ-callable-1 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra) + (apply-flat-ρ-fn-1 ρ-fn ra shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-1 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra z) + (apply-flat-∇-fn-1 ∇-fn ra z shape-fn))) + (else ∇-fn)))) + +(define ensure-ρ-callable-2 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra rb) + (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-2 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra rb z) + (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn))) + (else ∇-fn)))) + +(define apply-flat-ρ-fn-1 + (λ (ρ-fn ra shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (cond + ((null? 
out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-1 + (λ (∇-fn ra z shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (let ((g (new-vec in-size 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g (flat-store ra) (flat-offset ra) in-size + v-z 0 1) + (flat in-shape g 0))) + (else + (∇-fn g (flat-store ra) (flat-offset ra) in-size + (flat-store z) (flat-offset z) out-size) + (flat in-shape g 0))))))) + +(define apply-flat-ρ-fn-2 + (λ (ρ-fn ra rb shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape rb)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-2 + (λ (∇-fn ra rb z shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape rb)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (let ((g0 (new-vec in-size-a 0.0)) + (g1 (new-vec in-size-b 0.0))) + (cond + ((null? 
out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-z 0 1) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))) + (else + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + (flat-store z) (flat-offset z) out-size) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))))))) + +;;---------------------------- +;; Dualized tensor op creators +;;---------------------------- +(define ext1 + (λ (f n) + (prim1 + (ext1-ρ (ρ-function f) n (shape-fn f)) + (ext1-∇ (∇-function f) n (shape-fn f)) + (shape-fn f)))) + +(define ext2 + (λ (f m n) + (prim2 + (ext2-ρ (ρ-function f) m n (shape-fn f)) + (ext2-∇ (∇-function f) m n (shape-fn f)) + (shape-fn f)))) + +(provide prim1 prim2 ext1 ext2) diff --git a/uniform-tensors/autodiff/C-dualized-tensor-ops.rkt b/uniform-tensors/autodiff/C-dualized-tensor-ops.rkt new file mode 100644 index 0000000..5b02ba3 --- /dev/null +++ b/uniform-tensors/autodiff/C-dualized-tensor-ops.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + + +;;---------------------------- +;; Tensor ops, cleaned up. +;;---------------------------- + +(define d-rank + (lambda (t) + (rank (ρ t)))) + +(define d-shape + (λ (t) + (shape (ρ t)))) + +(define d-reshape + (λ (s t) + (cond + ((dual? t) + (dual (reshape s (ρ t)) + (κ t))) + (else (reshape s t))))) + +(define d-trefs + (λ (t b) + (trefs (ρ t) b))) + +(define d-tref + (λ (t i) + (tref (ρ t) i))) + +(define d-tensor? + (λ (t) + (tensor? (ρ t)))) + +(define d-tlen + (λ (t) + (tlen (ρ t)))) + +(define d-ref + (λ (l i) + (ref l (ρ i)))) + +(define d-refr + (λ (l i) + (refr l (ρ i)))) + +(provide d-rank d-shape d-reshape d-trefs d-tensor? 
d-tlen d-ref d-refr d-tref) diff --git a/uniform-tensors/autodiff/D-test-helpers.rkt b/uniform-tensors/autodiff/D-test-helpers.rkt new file mode 100644 index 0000000..5b21797 --- /dev/null +++ b/uniform-tensors/autodiff/D-test-helpers.rkt @@ -0,0 +1,47 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(require rackunit) + +(define-binary-check (check-dual-equal? equal-wt? actual expected)) +(define-check (ρ-∇-checker fn args ans grads) + (let* ((y (apply fn args)) + (g (apply (∇¹ fn) args))) + (cond + ((and (equal-wt? ans (ρ y)) + (equal-wt? grads (ρ g))) (void)) + ((equal-wt? ans (ρ y)) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" + (ρ g) grads))) + (else + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" + (ρ y) ans)))))) + +(define-syntax check-ρ-∇ + (syntax-rules () + [(check-both (fn args ...) ans grads) + (ρ-∇-checker fn (list args ...) ans grads)])) + +(define equal-wt? + (λ (a b) + (cond + ((and (tensor? a) (tensor? b)) + (tensor-equal? a b)) + ((dual? a) (equal-wt? (ρ a) b)) + ((dual? b) (equal-wt? a (ρ b))) + ((and (vector? a) (vector? b) + (= (vector-length a) (vector-length b))) + (vector-andmap equal-wt? a b)) + ((and (pair? a) (pair? b) + (= (length a) (length b))) + (andmap equal-wt? a b)) + (else (equal? a b))))) + +(define vector-andmap + (λ (f v1 v2) + (for/fold ([s #t]) ([v1 v1][v2 v2]) + (and s (f v1 v2))))) + +(provide check-dual-equal? 
check-ρ-∇) diff --git a/uniform-tensors/autodiff/E-print.rkt b/uniform-tensors/autodiff/E-print.rkt new file mode 100644 index 0000000..7f4090e --- /dev/null +++ b/uniform-tensors/autodiff/E-print.rkt @@ -0,0 +1,87 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "A-autodiff.rkt") +(require "../tensors.rkt") + +(define max-tensor-print-length (make-parameter 5)) + +(struct fake-tensor (members) + #:transparent + #:methods gen:custom-write + ((define write-proc + (λ (fake-tensor port mode) + (let ((n (length (fake-tensor-members fake-tensor)))) + (case mode + ((#t) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (write m port)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + ((#f) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (display m port) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + (else + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (print m port mode)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)))))))) + +(define make-printable + (λ (y [max-length (max-tensor-print-length)]) + (cond + ((dual? y) (make-printable (ρ y))) + ((flat? y) (make-printable-flat y max-length)) + ((list? y) + (map (λ (le) (make-printable le max-length)) y)) + ((vector? y) + (vector-map (λ (ve) (make-printable ve max-length)) y)) + (else y)))) + +(define make-printable-flat + (λ (y max-length) + (flat->tensor-list + (flat-store y) (flat-offset y) (flat-shape y) + (strides (flat-shape y)) max-length))) + +(define flat->tensor-list + (λ (store offset shape strides max-length) + (cond + ((null? 
shape) (vref store offset)) + (else + (let ((top-len (car shape)) + (stride (car strides))) + (fake-tensor + (reverse + (call/cc + (λ (return) + (for/fold ((lst '())) ((i (in-range offset (+ offset (* top-len stride)) stride)) + (count (in-naturals 0))) + (cond + ((and (> max-length 0) (= count max-length)) (return (cons '... lst))) + (else + (cons (flat->tensor-list store i (cdr shape) (cdr strides) max-length) + lst))))))))))))) + +(include "test/test-E-print.rkt") + +(provide max-tensor-print-length + make-printable + ;; This is used in ext-impl.rkt + make-printable-flat + fake-tensor) diff --git a/uniform-tensors/autodiff/test/test-A-autodiff.rkt b/uniform-tensors/autodiff/test/test-A-autodiff.rkt new file mode 100644 index 0000000..b453b3e --- /dev/null +++ b/uniform-tensors/autodiff/test/test-A-autodiff.rkt @@ -0,0 +1,15 @@ +(module+ test + (require rackunit) + (let ((k0 end-of-chain)) + (let ((dual0 0) + (dual1 (dual 1 k0))) + + (check-equal? dual1 (dual 1 k0)) + (check-true (dual? dual1)) + (check-false (dual? 1)) + (check-equal? (ρ dual1) 1) + (check-equal? (ρ dual0) 0) + (check-equal? (κ dual1) k0) + + (check-equal? 
(map* (λ (d) (ρ d)) (∇-once dual1 (list dual0 dual1))) + '(0.0 1.0))))) diff --git a/uniform-tensors/autodiff/test/test-E-print.rkt b/uniform-tensors/autodiff/test/test-E-print.rkt new file mode 100644 index 0000000..30f51e6 --- /dev/null +++ b/uniform-tensors/autodiff/test/test-E-print.rkt @@ -0,0 +1,71 @@ +(module+ test + (require rackunit) + (require "../tensors.rkt") + + (define long-tensor + (tensor 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15)) + + (define dualized-long-tensor + (dual long-tensor end-of-chain)) + + (define deep-tensor + (tensor long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor)) + + (define deeper-tensor + (tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) + + (check-equal? (make-printable-flat long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? (make-printable-flat deep-tensor 3) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...))) + + (check-equal? (make-printable-flat deeper-tensor 3) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...))) + (parameterize ((max-tensor-print-length 3)) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? 
(make-printable (list long-tensor dualized-long-tensor deeper-tensor)) + (list + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...)))))) diff --git a/uniform-tensors/ext-impl.rkt b/uniform-tensors/ext-impl.rkt new file mode 100644 index 0000000..6eb7d34 --- /dev/null +++ b/uniform-tensors/ext-impl.rkt @@ -0,0 +1,28 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require (only-in "tensors/B-tensor-basics.rkt" + merge-flats)) +(require (only-in "tensors/D-extend.rkt" + merge-shapes + min-shape + ext2-shapes + flat-ext1-∇ + flat-ext1-ρ + flat-ext2-ρ + functional->preallocated-1-ρ + functional->preallocated-1-∇ + functional->preallocated-2-ρ + functional->preallocated-2-∇ + idxs + scalarize + ensure-flat)) +(require (only-in "autodiff/E-print.rkt" + make-printable-flat + fake-tensor)) + +(provide (all-from-out "tensors/0-vectors.rkt")) +(provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/B-tensor-basics.rkt")) +(provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/E-print.rkt")) diff --git a/uniform-tensors/ext-ops.rkt b/uniform-tensors/ext-ops.rkt new file mode 100644 index 0000000..fc223f5 --- /dev/null +++ b/uniform-tensors/ext-ops.rkt @@ -0,0 +1,40 @@ +#lang racket + +(require "ext-ops/A-scalar-ops.rkt") +(require "ext-ops/B-comparators.rkt") +(require "ext-ops/C-star-2-1.rkt") +(require "ext-ops/D-sum.rkt") +(require "ext-ops/E-argmax.rkt") +(require "ext-ops/F-max.rkt") +(require "ext-ops/G-correlate.rkt") +(require 
"ext-ops/I-flatten.rkt") +(require "ext-ops/K-concat.rkt") + +(provide d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) + +(provide d*-2-1 *-2-1-ρ) + +(provide sum-1 d-sum sum-ρ d-sum-cols sum-cols-ρ) + +(provide argmax-1 d-argmax argmax-ρ) + +(provide max-1 d-max max-ρ) + +(provide correlate-ρ d-correlate) + +(provide flatten-2 d-flatten flatten-ρ) + +(provide concat-1-1 d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/uniform-tensors/ext-ops/A-scalar-ops.rkt b/uniform-tensors/ext-ops/A-scalar-ops.rkt new file mode 100644 index 0000000..9096209 --- /dev/null +++ b/uniform-tensors/ext-ops/A-scalar-ops.rkt @@ -0,0 +1,138 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) +(require "../autodiff.rkt") + +(define +-0-0 + (prim2 + + (λ (a b z) + (values z z)))) + +(define --0-0 + (prim2 - + (λ (a b z) + (values z (- z))))) + +(define *-0-0 + (prim2 * + (λ (a b z) + (values (* b z) (* a z))))) + +(define /-0-0 + (prim2 / + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))))) + +(define expt-0-0 + (prim2 expt + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))))) + +(define exp-0 + (prim1 exp + (λ (a z) + (* z (exp a))))) + +(define log-0 + (prim1 log + (λ (a z) + (* z (/ 1 a))))) + +(define sqrt-0 + (prim1 sqrt + (λ (x z) + (/ z (* 2 (sqrt x)))))) + +(define abs-0-ρ + (λ (x) + (cond + ((< x 0) (* -1 x)) + (else x)))) + +(define abs-0-∇ + (λ (x z) + (cond + ((< x 0) (- z)) + (else z)))) + +(define abs-0 + (prim1 abs-0-ρ abs-0-∇)) + +(define rectify-0-ρ + (λ (s) + (cond + ((< s 0.0) 0.0) + (else s)))) + +(define rectify-0-∇ + (λ (s z) + (cond + ((< s 0.0) 0.0) + (else z)))) + +(define rectify-shape + (λ (s) s)) + +(define rectify-0 + (prim1 rectify-0-ρ 
rectify-0-∇ rectify-shape)) + +;;------------------------------------ +;; differentiable extended functions. +;;------------------------------------ + +(define d* (ext2 *-0-0 0 0)) +(define d+ (ext2 +-0-0 0 0)) +(define d- (ext2 --0-0 0 0)) +(define d/ (ext2 /-0-0 0 0)) +(define d-expt (ext2 expt-0-0 0 0)) + +(define d-exp (ext1 exp-0 0)) +(define d-log (ext1 log-0 0)) +(define d-abs (ext1 abs-0 0)) +(define d-rectify (ext1 rectify-0 0)) +(define d-sqrt (ext1 sqrt-0 0)) + +(define d-sqr + (λ (x) + (d* x x))) + +;;------------------------------------ +;; non-differentiable extended functions. +;;------------------------------------ + +(define *-ρ (ext2-ρ * 0 0)) +(define +-ρ (ext2-ρ + 0 0)) +(define --ρ (ext2-ρ - 0 0)) +(define /-ρ (ext2-ρ / 0 0)) +(define expt-ρ (ext2-ρ expt 0 0)) + +(define exp-ρ (ext1-ρ exp 0)) +(define log-ρ (ext1-ρ log 0)) +(define abs-ρ (ext1-ρ abs-0-ρ 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) + +(define sqrt-ρ + (λ (a) + (expt-ρ a 1/2))) + +(define sqr-ρ + (λ (x) + (*-ρ x x))) + +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + +(include "test/test-A-scalar-ops.rkt") + +(provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/uniform-tensors/ext-ops/B-comparators.rkt b/uniform-tensors/ext-ops/B-comparators.rkt new file mode 100644 index 0000000..c42a2cf --- /dev/null +++ b/uniform-tensors/ext-ops/B-comparators.rkt @@ -0,0 +1,85 @@ +#lang racket + +(require "../autodiff.rkt") + +;;---------------------------- +;; Boolean comparators +;;---------------------------- + +(define comparator + (λ (f) + (λ (da db) + (f (ρ da) (ρ db))))) + +(define =-0-0 + (comparator =)) + +(define <-0-0 + (comparator <)) + +(define <=-0-0 + (comparator <=)) + +(define >-0-0 + (comparator >)) + +(define >=-0-0 + (comparator >=)) ; non-strict: equal operands must compare true + +;;---------------------------- +;; Tensorized 
comparators +;;---------------------------- + +(define comparator-ρ + (λ (f) + (λ (da db) + (cond + ((f (ρ da) (ρ db)) 1.0) + (else 0.0))))) + +(define comparator-∇ + (λ (f) + (λ (da db z) + (cond + ((f (ρ da) (ρ db)) (values z z)) + (else (values 0.0 0.0)))))) + +(define comparator-shape + (λ (f) + (λ (sa sb) + sa))) + +(define comparator-prim + (λ (f) + (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + +(define extended-comparator + (λ (f) + (ext2 (comparator-prim f) 0 0))) + +(define =-1 + (extended-comparator =)) + +(define <-1 + (extended-comparator <)) + +(define >-1 + (extended-comparator >)) + +(define <=-1 + (extended-comparator <=)) + +(define >=-1 + (extended-comparator >=)) + +(define != + (λ (a b) + (not (= a b)))) + +(define !=-1 + (extended-comparator !=)) + +(include "test/test-B-comparators.rkt") + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/ext-ops/C-star-2-1.rkt b/uniform-tensors/ext-ops/C-star-2-1.rkt new file mode 100644 index 0000000..5eb0d63 --- /dev/null +++ b/uniform-tensors/ext-ops/C-star-2-1.rkt @@ -0,0 +1,44 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ)) +(require "../autodiff.rkt") + +(define *-2-1-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (vset! v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (let ((a (vref v0 (+ i0 i))) + (b (vref v1 (+ i1 (modulo i stride1)))) + (z (vref vz (+ iz i)))) + (vset! g0 (+ i0 i) + (+ (vref g0 (+ i0 i)) (* z b))) + (vset! 
g1 (+ i1 (modulo i stride1)) + (+ (vref g1 (+ i1 (modulo i stride1))) (* z a))))))) + +(define *-2-1-shape + (λ (s t) + s)) + +(define *-2-1 + (prim2 *-2-1-base-ρ *-2-1-base-∇ *-2-1-shape)) + +(define d*-2-1 + (ext2 *-2-1 2 1)) + +(define *-2-1-ρ + (ext2-ρ *-2-1-base-ρ 2 1 *-2-1-shape)) + +(include "test/test-C-star-2-1.rkt") + +(provide *-2-1-ρ d*-2-1) diff --git a/uniform-tensors/ext-ops/D-sum.rkt b/uniform-tensors/ext-ops/D-sum.rkt new file mode 100644 index 0000000..44c6b8e --- /dev/null +++ b/uniform-tensors/ext-ops/D-sum.rkt @@ -0,0 +1,68 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define sum-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([sum 0.0]) ([i (in-range i0 (+ i0 stride0))]) + (+ sum (vref v0 i)))))) + +(define sum-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i + (+ (vref g0 i) z)))))) + +(define sum-shape + (λ (st) + (refr st 1))) + +(define sum-1 + (prim1 sum-1-ρ sum-1-∇ sum-shape)) + +(define d-sum + (ext1 sum-1 1)) + +(define sum-ρ + (ext1-ρ sum-1-ρ 1 sum-shape)) + +(provide d-sum sum-ρ) + +(define sum-cols-2-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (for ((i (in-range 0 stride-out))) + (vset! v-out (+ i i-out) + (for/fold ([sum 0.0]) ([j (in-range i0 (+ i0 stride0) stride-out)]) + (+ sum (vref v0 (+ j i)))))))) + +(define sum-cols-2-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (for ((i (in-range 0 stride-z))) + (for ([j (in-range i0 (+ i0 stride0) stride-z)]) + (vset! 
g0 (+ i j) + (+ (vref g0 (+ i j)) (vref vz (+ i iz)))))))) + +(define sum-cols-shape + (λ (s) + (refr s 1))) + +(define sum-cols-2 + (prim1 sum-cols-2-ρ sum-cols-2-∇ sum-cols-shape)) + +(define d-sum-cols + (ext1 sum-cols-2 2)) + +(define sum-cols-ρ + (ext1-ρ sum-cols-2-ρ 2 sum-cols-shape)) + +(include "test/test-D-sum.rkt") + +(provide sum-1 d-sum-cols sum-cols-ρ) diff --git a/uniform-tensors/ext-ops/E-argmax.rkt b/uniform-tensors/ext-ops/E-argmax.rkt new file mode 100644 index 0000000..1a2a285 --- /dev/null +++ b/uniform-tensors/ext-ops/E-argmax.rkt @@ -0,0 +1,41 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define argmax-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([max -inf.0] + [max-i -1] #:result max-i) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) (values v (+ (- i i0) 0.0))) + (else (values max max-i)))))))) + +(define argmax-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i 0.0))))) + +(define argmax-shape + (λ (st) + '())) + +(define argmax-1 + (prim1 argmax-1-ρ argmax-1-∇ argmax-shape)) + +(define d-argmax + (ext1 argmax-1 1)) + +(define argmax-ρ + (ext1-ρ argmax-1-ρ 1 argmax-shape)) + +(include "test/test-E-argmax.rkt") + +(provide argmax-1 d-argmax argmax-ρ) diff --git a/uniform-tensors/ext-ops/F-max.rkt b/uniform-tensors/ext-ops/F-max.rkt new file mode 100644 index 0000000..101ffe2 --- /dev/null +++ b/uniform-tensors/ext-ops/F-max.rkt @@ -0,0 +1,49 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define max-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! 
v-out i-out + (for/fold ([max -inf.0]) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) v) + (else max))))))) + +(define max-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for/fold ([max -inf.0] + [max-i -1] #:result + (for ([i (in-range i0 (+ i0 stride0))]) + (cond + ((= i (+ i0 max-i)) (vset! g0 i z)) + (else (vset! g0 i 0.0))))) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) (values v (- i i0))) + (else (values max max-i)))))))) + +(define max-shape + (λ (st) + (cdr st))) + +(define max-1 + (prim1 max-1-ρ max-1-∇ max-shape)) + +(define d-max + (ext1 max-1 1)) + +(define max-ρ + (ext1-ρ max-1-ρ 1 max-shape)) + +(include "test/test-F-max.rkt") + +(provide max-1 d-max max-ρ) diff --git a/uniform-tensors/ext-ops/G-correlate.rkt b/uniform-tensors/ext-ops/G-correlate.rkt new file mode 100644 index 0000000..9db2109 --- /dev/null +++ b/uniform-tensors/ext-ops/G-correlate.rkt @@ -0,0 +1,95 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ len)) +(require "../autodiff.rkt") + +;; Correlation is written taking into account how ext2 works +;; Ext2 is responsible for producing the i-out'th output from +;; v0[i0] and v1[i1], we take advantage of this. The shape constants +;; n b m d are pre-calculated the striding constants nd md and qd +;; are calculated. + +(define correlate-3-1-ρ + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (vset! 
v-out (+ i-out i) + (for/fold ([sum 0.0]) ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (cond + ((and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (+ sum (* a b)))) + (else sum)))))))))) + +(define correlate-3-1-∇ + (λ (nd md qd) + (λ (g0 g1 + v0 i0 bmd + v1 i1 d + vz iz b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (let ((z (vref vz (+ iz i)))) + (for ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (when (and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-shape + (λ (bmd nd) + (list (car bmd)))) + +(define correlate-3-1 + (λ (nd md qd) + (prim2 + (correlate-3-1-ρ nd md qd) + (correlate-3-1-∇ nd md qd) + correlate-shape))) + +(define d-correlate + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2 (correlate-3-1 nd md qd) 3 1) bank signal)))) + +(define correlate-ρ + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. 
+ (qd (* q d)) + (md (* m d))) + ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape) + bank signal)))) + +(define last + (λ (n s) + (refr s (- (len s) n)))) + +(include "test/test-G-correlate.rkt") + +(provide d-correlate correlate-ρ) diff --git a/uniform-tensors/ext-ops/I-flatten.rkt b/uniform-tensors/ext-ops/I-flatten.rkt new file mode 100644 index 0000000..bf24773 --- /dev/null +++ b/uniform-tensors/ext-ops/I-flatten.rkt @@ -0,0 +1,31 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) +(require (only-in "../autodiff.rkt" prim1 ext1)) + +(define flatten-2-ρ + (λ (t) + (reshape (flatten-shape (shape t)) t))) + +(define flatten-2-∇ + (λ (t z) + (reshape (shape t) z))) + +(define flatten-shape + (λ (s) + (let ((rows (ref s 0)) + (cols (ref s 1))) + (list (* rows cols))))) + +(define flatten-2 + (prim1 flatten-2-ρ flatten-2-∇ flatten-shape)) + +(define d-flatten + (ext1 flatten-2 2)) + +(define flatten-ρ + (ext1-ρ flatten-2-ρ 2)) + +(include "test/test-I-flatten.rkt") + +(provide flatten-2 d-flatten flatten-ρ) diff --git a/uniform-tensors/ext-ops/K-concat.rkt b/uniform-tensors/ext-ops/K-concat.rkt new file mode 100644 index 0000000..5cc1ee5 --- /dev/null +++ b/uniform-tensors/ext-ops/K-concat.rkt @@ -0,0 +1,76 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (rename-in (only-in "../tensors.rkt" ext2-ρ tref tlen shape len ref) + (shape shape-ρ))) +(require (only-in "../autodiff.rkt" prim2 ext2 shape)) + +(define concat-shape + (λ (st su) + (cons (+ (ref st 0) (ref su 0)) + (cdr st)))) + +(define concat-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (cond + ((< i stride0) + (vset! v-out (+ i-out i) (vref v0 (+ i0 i)))) + (else + (vset! v-out (+ i-out i) (vref v1 (+ i1 (- i stride0))))))))) + +(define concat-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (cond + ((< i stride0) + (vset! 
g0 (+ i0 i) + (+ (vref g0 (+ i0 i)) + (vref vz (+ iz i))))) + (else + (vset! g1 (+ i1 (- i stride0)) + (+ (vref g1 (+ i1 (- i stride0))) + (vref vz (+ iz i))))))))) + +(define concat-base + (prim2 concat-base-ρ concat-base-∇ concat-shape)) + +(define d-concat-n + (λ (n) + (λ (t u) + (let ((st (shape t)) + (su (shape u))) + (ensure-compatible-shapes n st su) + ((ext2 concat-base n n) t u))))) + +(define concat-n-ρ + (λ (n) + (λ (t u) + (let ((st (shape-ρ t)) + (su (shape-ρ u))) + (ensure-compatible-shapes n st su) + ((ext2-ρ concat-base-ρ n n concat-shape) t u))))) + +(define ensure-compatible-shapes + (λ (n st su) + (let ((rt (len st)) + (ru (len su))) + ;; The shape of the tensor of rank r at rank n-1 + ;; is given by (drop st (+ 1 (- r n))) + (when (not (equal? (drop st (+ 1 (- rt n))) (drop su (+ 1 (- ru n))))) + (error 'concat "Incompatible concat shapes: ~a and ~a at last ~a dimensions" + st su n))))) + +(define d-concat (d-concat-n 1)) +(define concat-ρ (concat-n-ρ 1)) +(define concat-1-1 concat-base) + +(include "test/test-K-concat.rkt") + +(provide concat-1-1 + d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt b/uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt new file mode 100644 index 0000000..2c13e39 --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt @@ -0,0 +1,122 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + ;; Check basic numericals + (let ((a 2) + (b 3)) + (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) + (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0)) + (check-ρ-∇ (d* a b) 6 (list 3.0 2.0)) + (check-ρ-∇ (d/ a b) + 2/3 + '(0.3333333333333333 -0.2222222222222222)) + (check-ρ-∇ (d-exp a) (exp 2) (list (exp 2))) + (check-ρ-∇ (d-log a) (log 2) (list 0.5)) + (check-ρ-∇ (d-expt a b) 8 + (list 12.0 5.545177444479562)) + (check-ρ-∇ (d-sqrt a) + (sqrt 2) + (list 0.3535533905932738)) + + (let* ((first-derivative ((∇¹ d-sqr) b))) + (check-dual-equal? 
first-derivative (list 6.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* 2)) + (d (dual* 2))) + (check-dual-equal? + ((∇¹ z) c d) + (list 2.5 2.0))) + + (check-dual-equal? ((∇¹ z) a a) (list 2.5 2.0))) + + ;; Check numericals with vector-duals + + (let ((a (tensor 2.0 3.0 4.0)) + (b (tensor 3.0 8.0 9.0))) + (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) + (check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) + (check-ρ-∇ (d/ a b) + (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) + (list (tensor 0.3333333333333333 0.125 0.1111111111111111) + (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) + (check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) + (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) + (check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) + (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) + (check-ρ-∇ (d-expt a b) + (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) + (list (tensor 12.0 17496.0 589824.0) + (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) + + (check-ρ-∇ (d-sqrt a) + (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) + (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) + + (check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* (tensor 2.0 3.0 4.0))) + (d (dual* (tensor 2.0 3.0 4.0)))) + (check-dual-equal? + ((∇¹ z) c d) + (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + (check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + ;; Check numericals with lists + (let ((x (dual* 3)) + (y (dual* 2))) + (let ((f d*)) + (check-ρ-∇ (d* x y) 6 '(2.0 3.0)) + (check-dual-equal? 
+ ((∇¹ (λ (m n) + (list (f m m) (f n n)))) + x y) + '(6.0 4.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list + (list (f m m) (f n n)) + (list (f m m) (f n n))))) + x y) + '(12.0 8.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (map (λ (m) (f m m)) + (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list n m)))) + x y) + '(4.0 6.0)))) + + (let ((a 7) + (b (tensor 13))) + (check-ρ-∇ (d+ a b) (tensor 20) (list 1.0 (tensor 1.0))) + (check-ρ-∇ (d* a b) (tensor 91) (list 13.0 (tensor 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13) (list 0.07692 (tensor -0.04142)))) + + (let ((a 7) + (b (tensor 13 15))) + (check-ρ-∇ (d+ a b) (tensor 20 22) (list 2.0 (tensor 1.0 1.0))) + (check-ρ-∇ (d* a b) (tensor 91 105) (list 28.0 (tensor 7.0 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13 7/15) + (list 0.14358 (tensor -0.04142 -0.03111))))) diff --git a/uniform-tensors/ext-ops/test/test-B-comparators.rkt b/uniform-tensors/ext-ops/test/test-B-comparators.rkt new file mode 100644 index 0000000..9f3fdf5 --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-B-comparators.rkt @@ -0,0 +1,12 @@ +(module+ test + (require rackunit) + (let ((a 2) + (b 3)) + (check-true (<-0-0 a b)) + (check-false (>-0-0 a b)) + (check-true (<=-0-0 a b)) + (check-false (>=-0-0 a b)) + (check-false (=-0-0 a b)) + (check-true (=-0-0 a a)) + (check-true (zero? 0)) + (check-false (zero? 
a)))) diff --git a/uniform-tensors/ext-ops/test/test-C-star-2-1.rkt b/uniform-tensors/ext-ops/test/test-C-star-2-1.rkt new file mode 100644 index 0000000..bbb2c8b --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-C-star-2-1.rkt @@ -0,0 +1,24 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (tensor (tensor 36 52 70 90) (tensor 84 104 126 150))) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/uniform-tensors/ext-ops/test/test-D-sum.rkt b/uniform-tensors/ext-ops/test/test-D-sum.rkt new file mode 100644 index 0000000..7b07d0d --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-D-sum.rkt @@ -0,0 +1,58 @@ +(module+ test + (require rackunit) + (require "C-star-2-1.ss") + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) + + (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d-sum a) 12 + (list (tensor 1.0 1.0 1.0)))) + + (let ((a (tensor (tensor 3 4 5) + (tensor 6 7 8)))) + (check-dual-equal? (d-sum a) (tensor 12 21)) + (check-dual-equal? 
((∇¹ (λ (b) (d-sum (d* b b)))) a) + (list (tensor (tensor 6.0 8.0 10.0) + (tensor 12.0 14.0 16.0))))) + + (define dot-product + (λ (a b) + (d-sum (d*-2-1 a b)))) + + (define sse + (λ (a b) + (d-sum (d-sqr (d- a b))))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + + (check-ρ-∇ (sum-1 b) 14 + (list (tensor 1.0 1.0 1.0 1.0))) + + (check-ρ-∇ (dot-product a b) + (tensor 68 124) + (list (tensor (tensor 2.0 3.0 4.0 5.0) + (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + + (check-ρ-∇ (sse a b) + (tensor 4 100) + (list (tensor (tensor 2.0 2.0 2.0 2.0) + (tensor 10.0 10.0 10.0 10.0)) + (tensor -12.0 -12.0 -12.0 -12.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (dot-product a b) + (tensor (tensor 68 124) + (tensor 248 464)) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/uniform-tensors/ext-ops/test/test-E-argmax.rkt b/uniform-tensors/ext-ops/test/test-E-argmax.rkt new file mode 100644 index 0000000..f72819e --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-E-argmax.rkt @@ -0,0 +1,17 @@ +(module+ test + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-argmax y) (tensor 2.0 1.0 0.0 3.0) + (list + (tensor (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0)))))) diff --git a/uniform-tensors/ext-ops/test/test-F-max.rkt b/uniform-tensors/ext-ops/test/test-F-max.rkt new file mode 100644 index 0000000..01ab1a5 --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-F-max.rkt @@ -0,0 +1,10 @@ +(module+ test + (require rackunit) + 
(require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-max y) (tensor 1.0 1.0 1.0 1.0) + (list y)))) diff --git a/uniform-tensors/ext-ops/test/test-G-correlate.rkt b/uniform-tensors/ext-ops/test/test-G-correlate.rkt new file mode 100644 index 0000000..417723c --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-G-correlate.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor ext2-∇ check-tensor-equal?)) + + ;; for testing b = 4 + ;; m = 3 + ;; d = 2 + + ;; signal length n = 6 + + ;; (1 2) (3 4) (5 6) (7 8) (9 10) (11 12) + ;; (1 2) (3 4) (5 6) + ;; (7 8) (9 10) (11 12) + ;; (13 14) (15 16) (17 18) + ;; (19 20) (21 22) (23 24) + + ;; Signal is (n d) + (define signal (tensor (tensor 1 2) + (tensor 3 4) + (tensor 5 6) + (tensor 7 8) + (tensor 9 10) + (tensor 11 12))) + + (define bank (tensor (tensor + (tensor 1 2) + (tensor 3 4) + (tensor 5 6)) + (tensor + (tensor 7 8) + (tensor 9 10) + (tensor 11 12)) + (tensor + (tensor 13 14) + (tensor 15 16) + (tensor 17 18)) + (tensor + (tensor 19 20) + (tensor 21 22) + (tensor 23 24)))) + + (define corr-ρ + (ext2-ρ (correlate-3-1-ρ 12 6 2) 3 1 correlate-shape)) + + (define corr-∇ + (ext2-∇ (correlate-3-1-∇ 12 6 2) 3 1 correlate-shape)) + + (check-tensor-equal? (corr-ρ bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let-values (((filter-∇ signal-∇) + (corr-∇ bank signal (tensor (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0))))) + (check-tensor-equal? 
filter-∇ + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-tensor-equal? signal-∇ + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0)))) + + (check-dual-equal? (d-correlate bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let ((gs ((∇¹ d-correlate) bank signal))) + (check-dual-equal? (car gs) + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-dual-equal? (cadr gs) + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0))))) diff --git a/uniform-tensors/ext-ops/test/test-I-flatten.rkt b/uniform-tensors/ext-ops/test/test-I-flatten.rkt new file mode 100644 index 0000000..f7740cf --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-I-flatten.rkt @@ -0,0 +1,13 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0)) + + (check-dual-equal? 
(flatten-2 r2-t1) r1-t1) + (check-ρ-∇ ((λ (t1 t2) (d* t1 (flatten-2 t2))) r1-t1 r2-t1) + (tensor 9.0 16.0 25.0 36.0) + (list (tensor 3.0 4.0 5.0 6.0) (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))))) diff --git a/uniform-tensors/ext-ops/test/test-K-concat.rkt b/uniform-tensors/ext-ops/test/test-K-concat.rkt new file mode 100644 index 0000000..b427ced --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-K-concat.rkt @@ -0,0 +1,126 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t2 (tensor 5.0 6.0 7.0)) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + + (check-dual-equal? + (d-concat r2-t1 r1-t2) + (tensor (tensor 3.0 4.0 5.0 6.0 7.0) + (tensor 5.0 6.0 5.0 6.0 7.0))) + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (d-concat t1 t2))) r2-t1 r1-t2 r1-t1) + (tensor (tensor 9.0 16.0 25.0 36.0 49.0) + (tensor 15.0 24.0 25.0 36.0 49.0)) + (list (tensor (tensor 3.0 4.0) (tensor 3.0 4.0)) + (tensor 10.0 12.0 14.0) + (tensor 8.0 10.0 10.0 12.0 14.0))) + (define r3-t1 + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0)))) + + + (define r2-t2 + (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0))) + + (define r1-t3 + (tensor 0.5 0.5)) + + (define concat-2 (d-concat-n 2)) + + (check-dual-equal? 
+ (concat-2 r3-t1 r2-t2) + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)))) + + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (concat-2 t1 t2))) r3-t1 r2-t2 r1-t3) + (tensor (tensor (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 4.5 5.0) + (tensor 5.5 6.0) + (tensor 6.5 7.0) + (tensor 7.5 8.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 8.5 9.0) + (tensor 9.5 10.0) + (tensor 10.5 11.0) + (tensor 11.5 12.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0))) + (list + (tensor (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5))) + + (tensor (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5)) + + (tensor 192.0 216.0)))) diff --git a/uniform-tensors/no-duals-no-overrides.rkt b/uniform-tensors/no-duals-no-overrides.rkt new file mode 100644 index 0000000..07ca22e --- /dev/null +++ b/uniform-tensors/no-duals-no-overrides.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? 
rank shape reshape trefs + + ;; From ext-ops + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + concat-ρ concat-n-ρ flatten-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/no-duals.rkt b/uniform-tensors/no-duals.rkt new file mode 100644 index 0000000..cd1bcaf --- /dev/null +++ b/uniform-tensors/no-duals.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) + (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) + (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/no-overrides.rkt b/uniform-tensors/no-overrides.rkt new file mode 100644 index 0000000..05844b7 --- /dev/null +++ b/uniform-tensors/no-overrides.rkt @@ -0,0 +1,43 @@ +#lang racket/base + +(require + (except-in "tensors.rkt" + rank shape reshape trefs tref tensor? tlen ref refr)) + +(require "autodiff.rkt") +(require "ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? 
check-ρ-∇ + make-printable + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + flatten-2 concat-1-1 + + d+ d- d* d/ d-rectify + d-exp d-log d-expt d-sqrt d-sqr + d-sum d-abs d*-2-1 d-argmax + d-max d-sum-cols d-correlate + d-flatten d-concat d-concat-n + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ concat-n-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/tensors.rkt b/uniform-tensors/tensors.rkt new file mode 100644 index 0000000..0daca26 --- /dev/null +++ b/uniform-tensors/tensors.rkt @@ -0,0 +1,22 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require "tensors/A-equality.rkt") +(require "tensors/B-tensor-basics.rkt") +(require "tensors/C-tensor-ops.rkt") +(require "tensors/D-extend.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide tolerance tensor-equal? check-tensor-equal?) + +(provide len ref refr) +(provide tref tlen list->tensor tensor build-tensor trefs) + +(provide ext1-ρ ext2-ρ ext1-∇ ext2-∇ expects-preallocated?) + +(provide flat flat? flat-shape flat-store flat-offset size-of strides) + +;; These will get overriden by duals +(provide tensor?) +(provide rank shape reshape size-of) diff --git a/uniform-tensors/tensors/0-vectors.rkt b/uniform-tensors/tensors/0-vectors.rkt new file mode 100644 index 0000000..7b73c29 --- /dev/null +++ b/uniform-tensors/tensors/0-vectors.rkt @@ -0,0 +1,101 @@ +#lang racket +(require ffi/vector) +(require ffi/unsafe) + +;;------------------------------------------------ +;; Raw representation of vectors +;;------------------------------------------------ + +(define vec? f32vector?) +(define vec f32vector) +(define make-vec make-f32vector) +(define vref f32vector-ref) +(define vset! f32vector-set!) 
+(define vlen f32vector-length) +(define list->vec list->f32vector) +(define build-vec + (λ (n proc) + (list->vec (map (compose exact->inexact proc) (range n))))) +(define vec->cpointer f32vector->cpointer) +(define vref-cpointer + (λ (v i) + (unless (< -1 i (vlen v)) ; reject negative indices as well as i >= length + (error 'vref-cpointer + "Index ~a out of range [0, ~a)" + i (vlen v))) + (ptr-add (vec->cpointer v) i _float))) + +(define-for-syntax debug-leaks? #f) ; compile-time switch: #t keeps when-debug-leaks bodies +(define-syntax when-debug-leaks + (λ (x) + (syntax-case x () + ((when-debug-leaks expr) + debug-leaks? + #'expr) + ((when-debug-leaks expr) + #'(void))))) + +(define new-vec + (λ (size initial-value [context 'new-vec]) + (let ((m (make-vec size initial-value))) + (when-debug-leaks (manage-flat-vector! m context)) + m))) + +(define vcopy + (λ (dest idest src isrc n) + (for ([id (in-range idest (+ n idest))] + [is (in-range isrc (+ n isrc))]) + (vset! dest id (vref src is))))) + +(define print-vec + (λ (v (off 0) (port (current-output-port))) + (fprintf port "#(") + (for ((i (in-range off (vlen v)))) + (fprintf port "~a " (vref v i))) + (fprintf port ")"))) + +(provide vec? vec vref vset! vlen vcopy print-vec + list->vec build-vec vec->cpointer vref-cpointer new-vec) + +;;------------------------------------------------ +;; Memory management for flat-vectors +;;------------------------------------------------ + +(define flat-vector-manager + (make-will-executor)) + +(define manage-flat-vector! + (λ (m context) + (set-count! context (add1 (count context))) + (will-register flat-vector-manager m (flat-vector-collector context)))) + +(define flat-vector-collector + (λ (context) + (λ (v) + (cond + ((vec? v) ; managed values are f32vectors, which are never vector? + (set-count! context (sub1 (count context)))) + (else (fprintf (current-error-port) "?? 
...")))))) + +(define start-vector-manager + (λ () + (when-debug-leaks + (void + (thread + (λ () + (let loop () + (will-execute flat-vector-manager) + (loop)))))))) + +(define counts (make-hash)) +(define count (λ (context) (dict-ref counts context 0))) +(define set-count! (λ (context v) (dict-set! counts context v))) +(define vector-manager-report + (λ () + (fprintf (current-error-port) "----------------------------------------------~%") + (fprintf (current-error-port) "context\t\t\tcount~%") + (for ([(context count) (in-hash counts)]) + (fprintf (current-error-port) "~a\t\t\t~a~%" context count)) + (fprintf (current-error-port) "----------------------------------------------~%"))) + +(provide start-vector-manager vector-manager-report) diff --git a/uniform-tensors/tensors/1-flats.rkt b/uniform-tensors/tensors/1-flats.rkt new file mode 100644 index 0000000..23aeb79 --- /dev/null +++ b/uniform-tensors/tensors/1-flats.rkt @@ -0,0 +1,73 @@ +#lang racket + +;-------------------------------------------------------- +; Representation of tensors +;-------------------------------------------------------- + +;; A flat tensor representation is for a contiguous slice in the backing store. +;; The fields we need: +;; shape : list +;; store : vector +;; offset : start of the contiguous slice. +;; size : number of elements in the contiguous slice +;; strides : Number of elements in each dimension of the tensor. +;; rank: Number of dimensions in the tensor + + + +(define flat + (λ (shape store offset) + (vector flat shape store offset + (size-of shape) + (strides shape) + (length shape)))) + +(define flat? + (λ (v) + (and (vector? v) + (eq? 
(vector-ref v 0) flat)))) + +(define flat-shape + (λ (f) + (vector-ref f 1))) ; slot 1: shape list + +(define flat-store + (λ (f) + (vector-ref f 2))) ; slot 2: backing store + +(define flat-offset + (λ (f) + (vector-ref f 3))) ; slot 3: start of the slice in the store + +(define flat-size + (λ (f) + (vector-ref f 4))) ; slot 4: (size-of shape) + +(define flat-strides + (λ (f) + (vector-ref f 5))) ; slot 5: (strides shape) + +(define flat-rank + (λ (f) + (vector-ref f 6))) ; slot 6: (length shape) + +(define size-of + (λ (shape) + (product shape 1))) ; product of all dimensions; 1 for the empty shape + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(provide flat flat? flat-shape flat-store + flat-offset flat-rank flat-strides flat-size + size-of strides) diff --git a/uniform-tensors/tensors/A-equality.rkt b/uniform-tensors/tensors/A-equality.rkt new file mode 100644 index 0000000..639533f --- /dev/null +++ b/uniform-tensors/tensors/A-equality.rkt @@ -0,0 +1,66 @@ +#lang racket + +;;—————————————————–—————————————————–—————————————————– +;; Equality checks, mostly for testing. +;;—————————————————–—————————————————–—————————————————– + +(require "0-vectors.ss") +(require "1-flats.ss") +(require rackunit) + +;;—————————————————–—————————————————–—————————————————– +;; These parameters can be overridden to account for +;; different types of numbers used inside tensors. +;;—————————————————–—————————————————–—————————————————– + +(define tolerance (make-parameter 0.0001)) + +(define equal-within-tolerance? + (make-parameter + (λ (actual expected) + (< (abs (- actual expected)) (tolerance))))) + +;;—————————————————–—————————————————–—————————————————– +;; These are representation specific, but part of the +;; exported interface of the module +;;—————————————————–—————————————————–—————————————————– + +(define tensor-equal? + (λ (actual expected) + (or (equal? actual expected) + (and (real? actual) + (real? expected) + ((equal-within-tolerance?) actual expected)) + (and (flat? actual) + (flat? 
expected) + (equal? (flat-shape actual) + (flat-shape expected)) + (equal-elements? actual expected))))) + +(define (equal-elements? actual expected) + (let ((actual-offset (flat-offset actual)) + (expected-offset (flat-offset expected)) + (actual-size (flat-size actual)) + (expected-size (flat-size expected)) + (actual-store (flat-store actual)) + (expected-store (flat-store expected))) + (and (equal? actual-size expected-size) + (call/cc (λ (return) + (for/fold ([check #t]) + ([i-actual (in-range actual-offset + (+ actual-offset + actual-size))] + [i-expected (in-range expected-offset + (+ expected-offset + expected-size))]) + (cond + (((equal-within-tolerance?) + (vref actual-store i-actual) + (vref expected-store i-expected)) check) + (else (return #f))))))))) + +(define-binary-check (check-tensor-equal? tensor-equal? actual expected)) + +(include "test/test-A-equality.rkt") + +(provide tolerance equal-within-tolerance? tensor-equal? check-tensor-equal? equal-elements?) diff --git a/uniform-tensors/tensors/B-tensor-basics.rkt b/uniform-tensors/tensors/B-tensor-basics.rkt new file mode 100644 index 0000000..63e9a71 --- /dev/null +++ b/uniform-tensors/tensors/B-tensor-basics.rkt @@ -0,0 +1,184 @@ +#lang racket + +;-------------------------------------------------------- +; Memory management tools for vectors +;-------------------------------------------------------- + +(require "0-vectors.ss") +(require "1-flats.ss") + +;-------------------------------------------------------- +; Lists +;-------------------------------------------------------- +(define ref list-ref) +(define refr drop) +(define len length) + +(provide ref refr len) + +;-------------------------------------------------------- +; Tensor basics +;-------------------------------------------------------- + +(define tref + (λ (t i) + (cond + ((= 1 (flat-rank t)) + (vref (flat-store t) (+ (flat-offset t) i))) + (else + (flat (cdr (flat-shape t)) + (flat-store t) + (+ (flat-offset t) (* i (car 
(flat-strides t))))))))) + +(define tlen + (λ (t) + (car (flat-shape t)))) + +(define flat-ref-idx + (λ (v indices) + (flat-ref-idx* (flat-offset v) (flat-strides v) indices))) + +(define flat-ref-idx* + (λ (current-idx strides indices) + (cond + ((null? indices) current-idx) + (else + (flat-ref-idx* + (+ current-idx + (* (car indices) (car strides))) + (cdr strides) + (cdr indices)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define list->tensor + (λ (lst) + (cond + ((null? lst) (error 'list->flat-tensor "No elements found")) + ((number? (car lst)) + (flat (list (length lst)) (list->vec lst) 0)) + (else + (flat-tensor-from-list lst))))) + +(define flat-tensor-from-list + (λ (lst) + (let* ([inner-shape (flat-shape (car lst))] + [inner-size (size-of inner-shape)] + [outer-shape (cons (length lst) inner-shape)] + [size (size-of outer-shape)] + [v (new-vec size 0.0 'from-list)]) + (for ([fl lst] + [i (in-naturals 0)]) + (vcopy v (* i inner-size) + (flat-store fl) (flat-offset fl) + inner-size)) + (flat outer-shape v 0)))) + +(define tensor? + (λ (t) + (or (number? t) + (flat? t)))) + +(define tensor + (λ args + (ensure-shape args) + (cond + ((number? (car args)) (flat (list (length args)) + (list->vec (map exact->inexact args)) + 0)) + (else (merge-flats args))))) + +(define merge-flats + (λ (args) + (let* ((inner-shape (flat-shape (car args))) + (outer (length args)) + + (new-shape (cons outer inner-shape)) + (stride (size-of inner-shape)) + + (new-size (size-of new-shape)) + + (v-out (new-vec new-size +nan.0 'tensor))) + (for ([i-out (in-range outer)] + [arg args]) + (vcopy v-out (* i-out stride) (flat-store arg) (flat-offset arg) stride)) + (flat new-shape v-out 0)))) + +(define ensure-shape + (λ (args) + (when (null? 
args) + (error 'tensor "Tensors cannot be empty")) + (let ((checked-shape + (λ (x) (if (flat? x) + (flat-shape x) + '())))) + (unless (and (not (null? args)) + (cond + ((number? (car args)) + (andmap number? (cdr args))) + ((flat? (car args)) + (let ((s (checked-shape (car args)))) + (andmap (λ (t) + (and (flat? t) + (equal? (checked-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Cannot construct a tensor out of these elements: ~a~%" + args))))) + +(define build-tensor + (λ (shape f) + (let* ((size (size-of shape)) + (v (new-vec size 0.0 'build-tensor)) + (strides (strides shape))) + (fill-flat-tensor v shape strides f 0 '()) + (flat shape v 0)))) + +(define fill-flat-tensor + (λ (dest shape strides f offset tidx) + (cond + ((null? (cdr shape)) + (for ([i (in-range 0 (car shape))]) + (vset! dest (+ offset i) + (exact->inexact (f (append tidx (list i))))))) + (else + (let ((stride (car strides))) + (for ([i (in-range 0 (car shape))]) + (fill-flat-tensor dest + (cdr shape) (cdr strides) f + (+ offset (* i stride)) (append tidx (list i))))))))) + +(define trefs + (λ (t b) + (let* ([st (flat-shape t)] + [est (cdr st)] + [estride (size-of est)] ; number of scalars in one element of t + [nshape (cons (length b) (cdr st))] + [size-out (size-of nshape)] + [v-out (new-vec size-out 0.0 'flat-refs)] + [vt (flat-store t)]) + (for ([ib b] + [i-out (in-range 0 size-out estride)]) + (vcopy v-out i-out vt (+ (flat-offset t) (* ib estride)) estride)) ; honor t's offset so slices of a larger store index correctly + (flat nshape v-out 0)))) + +(include "test/test-B-tensor-basics.rkt") + +(provide tref tlen list->tensor number? + tensor? tensor build-tensor trefs merge-flats) diff --git a/uniform-tensors/tensors/C-tensor-ops.rkt b/uniform-tensors/tensors/C-tensor-ops.rkt new file mode 100644 index 0000000..cf32d23 --- /dev/null +++ b/uniform-tensors/tensors/C-tensor-ops.rkt @@ -0,0 +1,35 @@ +#lang racket + +(require "1-flats.ss") +(require "B-tensor-basics.ss") + +;;—————————————————– +;; Shape, rank, size-of +;;—————————————————– + +(define shape + (λ (t) + (cond + ((number? 
t) '()) + (else (flat-shape t))))) + +(define rank + (λ (t) + (len (shape t)))) + +;;—————————————————– +;; Reshape a tensor +;;—————————————————– + +(define reshape + (λ (s t) + (cond + ((= (size-of s) (flat-size t)) + (flat s (flat-store t) (flat-offset t))) + (else (error 'tensor-reshape "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + + +(include "test/test-C-tensor-ops.rkt") + +(provide rank shape reshape) +(provide size-of strides) diff --git a/uniform-tensors/tensors/D-extend.rkt b/uniform-tensors/tensors/D-extend.rkt new file mode 100644 index 0000000..7ff2046 --- /dev/null +++ b/uniform-tensors/tensors/D-extend.rkt @@ -0,0 +1,394 @@ +#lang racket + +(require "0-vectors.ss") +(require "1-flats.ss") +(require "B-tensor-basics.ss") +(require "C-tensor-ops.ss") + +;;—————————————————–—————————————————–—————————————————– +;; Unary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext1-ρ + (λ (f m [shape-fn scalar-shape]) + (λ (t) + (cond + ((number? t) (f t)) + ((expects-preallocated? f) + (scalarize + (flat-ext1-ρ f m shape-fn t))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f m shape-fn t)))))))) + +(define ext1-∇ + (λ (f m [shape-fn scalar-shape]) + (λ (t z) + (cond + ((number? t) (f t z)) + ((expects-preallocated? f) + (scalarize (flat-ext1-∇ f m shape-fn t (ensure-flat z)))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f m shape-fn t (ensure-flat z))))))))) + +(define functional->preallocated-1-ρ + (λ (f base-shape out-shape) + (λ (v0 i0 stride0 v-out i-out stride-out) + (set-prealloc-ρ! 
v-out i-out out-shape + (f (arg-value base-shape v0 i0)))))) + +(define functional->preallocated-1-∇ + (λ (f base-shape out-shape) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value base-shape v0 i0))) + (set-prealloc-∇! g0 i0 base-shape (f a z)))))) + +(define set-prealloc-ρ! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out a)) + (else (v-copy-flat! v-out i-out a))))) + +(define set-prealloc-∇! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out (+ (vref v-out i-out) a))) + (else (v-add-flat! v-out i-out a))))) + +(define arg-value + (λ (v-shape v i) + (cond + ((null? v-shape) (vref v i)) + (else (flat v-shape v i))))) + + +(define invoke-functional-∇ + (λ (f base-shape v0 i0) + (cond + ((null? base-shape) (f (vref v0 i0))) + (else (f (flat base-shape v0 i0)))))) + +;;—————————————————–—————————————————–—————————————————– +;; Binary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext2-ρ + (λ (f m n [shape-fn scalar-shape]) + (λ (t u) + (cond + ((and (number? t) (number? u)) (f t u)) + ((expects-preallocated? f) + (scalarize + (flat-ext2-ρ f m n shape-fn t u))) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn (ensure-flat t) u)))) + ((number? 
u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t (ensure-flat u))))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t u)))))))) + +(define ext2-∇ + (λ (f m n [shape-fn scalar-shape]) + (λ (t u z) + (let ((invoke-flat-ext2-∇ + (λ (f m n shape-fn t u z) + (let-values (((da db) (flat-ext2-∇ f m n shape-fn t u z))) + (values (scalarize da) (scalarize db)))))) + (cond + ((and (number? t) (number? u)) (f t u z)) + ((expects-preallocated? f) + (invoke-flat-ext2-∇ f m n shape-fn t u z)) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn (ensure-flat t) u z))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn t (ensure-flat u) z))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn t u z)))))))) + +(define functional->preallocated-2-ρ + (λ (f t-shape u-shape out-shape) + (λ (v0 i0 stride0 v1 i1 stride1 v-out i-out stride-out) + (set-prealloc-ρ! 
v-out i-out out-shape + (f (arg-value t-shape v0 i0) + (arg-value u-shape v1 i1)))))) + +(define functional->preallocated-2-∇ + (λ (f t-shape u-shape out-shape) + (λ (g0 g1 v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value t-shape v0 i0)) + (b (arg-value u-shape v1 i1))) + (let-values (((da db) (f a b z))) + (set-prealloc-∇! g0 i0 t-shape da) + (set-prealloc-∇! g1 i1 u-shape db)))))) + +(define idxs + (λ (strides out-i i0 i1) + (for/fold ([i0 i0] + [i1 i1] + [x out-i] #:result (values i0 i1)) + ([stride strides]) + (let ((idx (quotient x (vector-ref stride 0))) + (next-x (remainder x (vector-ref stride 0)))) + (values (+ i0 (* idx (vector-ref stride 1))) + (+ i1 (* idx (vector-ref stride 2))) + next-x))))) + +(define merge-shapes + (λ (in-shape min-rank out-f-shape) + (append (take in-shape (- (length in-shape) min-rank)) + out-f-shape))) + +(define flat-ext1-ρ + (λ (f min-rank shape-fn t0) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sf-out (shape-fn sf0)) + (stride-out (size-of sf-out)) + (s-out (merge-shapes s0 min-rank sf-out)) + (size-out (size-of s-out)) + (v-out (new-vec size-out 0.0))) + (for ([i-out (in-range 0 size-out stride-out)] + #;[i0 (in-range off0 (+ off0 size0) stride0)]) + (define i0 (+ off0 (* (/ i-out stride-out) stride0))) + (f v0 i0 stride0 v-out i-out stride-out)) + (flat s-out v-out 0)))) + +(define flat-ext1-∇ + (λ (fᵈ min-rank shape-fn t0 z) + ;; z has the same shape as the output + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sz (flat-shape z)) + (size-z (size-of sz)) + (sf-z (shape-fn sf0)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + (offz (flat-offset z)) + + (g0 (new-vec size0 0.0))) + (for ([iz (in-range 0 size-z stride-z)] + #;[i0 (in-range 
off0 (+ off0 size0) stride0)]) + (define i0 (+ off0 (* (/ iz stride-z) stride0))) + (fᵈ g0 v0 i0 stride0 vz (+ offz iz) stride-z)) + (flat s0 g0 0)))) + +(define flat-ext2-ρ + (λ (f r0 r1 shape-fn t0 t1) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + + (sf-out (shape-fn sf0 sf1)) + (stride0 (size-of sf0)) + (stride1 (size-of sf1)) + (stride-out (size-of sf-out))) + (ext2-shapes s0 s1 r0 r1 sf-out + (λ (s-out size-out q0 q1 strides) + (let ((out-v (new-vec size-out 0.0))) + (for ([out-i (in-range 0 size-out stride-out)]) + (let-values (((i0 i1) + (idxs strides out-i off0 off1))) + (f v0 i0 stride0 v1 i1 stride1 out-v (+ 0 out-i) stride-out))) + (flat s-out out-v 0))))))) + +(define flat-ext2-∇ + (λ (fᵈ r0 r1 shape-fn t0 t1 z) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + (stride0 (size-of sf0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + (stride1 (size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + (offz (flat-offset z))) + (ext2-shapes s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (size-of s0) 0.0)) + (g1 (new-vec (size-of s1) 0.0))) + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) + (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))) + (values (flat s0 g0 0) + (flat s1 g1 0)))))))) + +(define ext2-shapes + (λ (s0 s1 r0 r1 sf-out k) + (let ((l0 (length s0)) + (l1 (length s1))) + (cond + ((and (= r0 l0) (= r1 l1)) + (k sf-out + (size-of sf-out) + (size-of s0) + (size-of s1) + '())) + + ((= r0 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((= r1 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + ((and (not (null? 
s0)) + (not (null? s1)) + (= (car s0) (car s1))) + (ext2-shapes (cdr s0) (cdr s1) r0 r1 sf-out + (desc-both (car s0) k))) + + ((> l1 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((> l0 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + (else (error 'ext + "Shapes are incompatible for ext2: ~a, and ~a for min ranks ~a, and ~a~%" + s0 s1 r0 r1)))))) + +(define desc-both + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + (* q1 d) + (cons (vector qout q0 q1) strides))))) + +(define desc-left + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + q1 + (cons (vector qout q0 0) strides))))) + +(define desc-right + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + q0 + (* q1 d) + (cons (vector qout 0 q1) strides))))) + +(define v-copy-flat! + (λ (vg ig a) + ;; copy elements from a to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (vref va (+ a-offset i))))))) + +(define v-add-flat! + (λ (vg ig a) + ;; copy elements to a to vg while adding them to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (+ (vref vg (+ ig i)) + (vref va (+ a-offset i)))))))) + +(define expects-preallocated? + (λ (f) + (let ((a (procedure-arity f))) + (and (integer? a) + (>= a 6))))) + +(define ensure-flat + (λ (z) + (cond + ((number? z) + (flat '() (new-vec 1 (exact->inexact z)) 0)) + (else z)))) + +(define scalarize + (λ (t) + (cond + ((null? 
(flat-shape t)) (vref (flat-store t) 0)) + (else t)))) + +(define min-shape + (λ (min-rank in-shape) + (drop in-shape (- (length in-shape) min-rank)))) + +(define scalar-shape + (λ (s0 [s1 '()]) '())) + +(include "test/test-D-extend.rkt") + +(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? + functional->preallocated-1-ρ functional->preallocated-1-∇ + functional->preallocated-2-ρ functional->preallocated-2-∇ + merge-shapes min-shape ext2-shapes idxs + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/uniform-tensors/tensors/test/test-A-equality.rkt b/uniform-tensors/tensors/test/test-A-equality.rkt new file mode 100644 index 0000000..9488d8c --- /dev/null +++ b/uniform-tensors/tensors/test/test-A-equality.rkt @@ -0,0 +1,84 @@ +(module+ test + (require rackunit) + + (check-true ((equal-within-tolerance?) 1.00001 1.0001)) + (check-true ((equal-within-tolerance?) 1.0002 1.0001)) + (check-false ((equal-within-tolerance?) 1.0003 1.0001)) + + (define t0 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.0 i))) + 0)) + + + (define t1 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t2 + (flat '(1 2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t3 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (* (quotient i 24) i))) + 0)) + + (define t4 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (- (* 2.000001 (* (quotient i 24) i)) 48.0))) + 0)) + + (check-true (equal-elements? t0 t1)) + + (check-true (equal-elements? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (equal-elements? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (equal-elements? t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (equal-elements? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-true (tensor-equal? t0 t1)) + + (check-false (tensor-equal? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (tensor-equal? 
t0 (flat '(2 3 4)
+                                      (flat-store t2)
+                                      0))) ;; t2's store re-wrapped with t0's shape
+
+  (check-false (tensor-equal? t1 (flat '(2 3 4)
+                                       (flat-store t3)
+                                       24)))
+
+  (check-true (tensor-equal? t1 (flat '(2 3 4)
+                                      (flat-store t4)
+                                      24)))
+
+  (check-tensor-equal? t0 t1) ;; the same three positive cases again, via the check-tensor-equal? form
+
+  (check-tensor-equal? t0 (flat '(2 3 4)
+                                (flat-store t2)
+                                0))
+
+  (check-tensor-equal? t1 (flat '(2 3 4)
+                                (flat-store t4)
+                                24)))
diff --git a/uniform-tensors/tensors/test/test-B-tensor-basics.rkt b/uniform-tensors/tensors/test/test-B-tensor-basics.rkt
new file mode 100644
index 0000000..903220b
--- /dev/null
+++ b/uniform-tensors/tensors/test/test-B-tensor-basics.rkt
@@ -0,0 +1,33 @@
+(module+ test
+  (require rackunit)
+  (require "A-equality.ss") ;; NOTE(review): .ss suffix — presumably resolved to A-equality.rkt by Racket; confirm
+
+  (define r0-td 3.0) ;; a plain number, treated as a tensor by the tensor? check below
+  (define r1-td (tensor 3.0 4.0 5.0))
+  (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0)))
+  (define r3-td ;; nested 4 x 3 x 2 literal holding 0..23 in order
+    (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5))
+            (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11))
+            (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17))
+            (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23))))
+
+  (check-tensor-equal? (tref r1-td 2) 5.0) ;; index 2 is the third element
+  (check-equal? (tlen r1-td) 3)
+  (check-tensor-equal? (list->tensor (list 3.0 4.0 5.0)) r1-td)
+
+  (check-true (and (tensor? r0-td) (tensor? r1-td)))
+  (check-false (tensor? '(a b c))) ;; a list of symbols is not a tensor
+
+  (check-tensor-equal? (build-tensor '(4 3 2)
+                         (λ (idx)
+                           (+ (* 6 (ref idx 0)) ;; idx is read with ref here and with list-ref in the next check
+                              (* 2 (ref idx 1))
+                              (ref idx 2))))
+                       r3-td)
+
+  (check-tensor-equal? (build-tensor '(1 2 3) (λ (idx) (+ (list-ref idx 0) (list-ref idx 1) (list-ref idx 2))))
+                       (tensor (tensor (tensor 0 1 2) (tensor 1 2 3)))) ;; each element is the sum of its three indices
+
+  (check-tensor-equal?
(trefs r1-td '(0 2)) (tensor 3.0 5.0)) ;; trefs picks the elements at the listed indices
+
+  )
diff --git a/uniform-tensors/tensors/test/test-C-tensor-ops.rkt b/uniform-tensors/tensors/test/test-C-tensor-ops.rkt
new file mode 100644
index 0000000..4509a31
--- /dev/null
+++ b/uniform-tensors/tensors/test/test-C-tensor-ops.rkt
@@ -0,0 +1,63 @@
+(module+ test
+  (require rackunit)
+  (require "A-equality.ss") ;; NOTE(review): .ss suffix — presumably resolved to A-equality.rkt by Racket; confirm
+
+  (define r0-td 3.0)
+  (define r1-td (tensor 3.0 4.0 5.0))
+  (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0)))
+  (define r3-td ;; nested 4 x 3 x 2 literal holding 0..23 in order
+    (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5))
+            (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11))
+            (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17))
+            (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23))))
+
+  (define test-shape (list 2 2 3))
+
+  (check-equal? (shape r0-td) (list)) ;; a plain number has the empty shape
+  (check-equal? (shape r1-td) (list 3))
+  (check-equal? (shape r2-td) (list 2 3))
+
+  (check-equal? (rank r0-td) 0)
+  (check-equal? (rank r1-td) 1)
+  (check-equal? (rank r2-td) 2)
+
+  (check-equal? (size-of '()) 1) ;; the empty shape has size 1
+  (check-equal? (size-of test-shape) 12) ;; 2 * 2 * 3
+
+
+  (check-equal? (size-of '(4 3 2)) 24)
+
+  (check-tensor-equal? (reshape '(24) r3-td) ;; flattening keeps the elements in their nested order
+                       (tensor 0 1 2 3 4 5
+                               6 7 8 9 10 11
+                               12 13 14 15 16 17
+                               18 19 20 21 22 23))
+
+  (check-tensor-equal? (reshape '(4 1) (tensor 0 1 2 3))
+                       (tensor (tensor 0) (tensor 1) (tensor 2) (tensor 3)))
+
+  (check-tensor-equal? (reshape '(6) r2-td)
+                       (tensor 3.0 4.0 5.0 7.0 8.0 9.0))
+
+  (check-tensor-equal? (reshape '(3 2) r2-td) ;; same six elements regrouped in pairs
+                       (tensor (tensor 3.0 4.0)
+                               (tensor 5.0 7.0)
+                               (tensor 8.0 9.0)))
+
+
+  (check-exn exn:fail?
+             (λ ()
+               (tensor "1 2" 1 2))) ;; a string argument must raise
+
+  (check-exn exn:fail?
+             (λ ()
+               (tensor))) ;; calling tensor with no arguments must raise
+
+  (check-exn exn:fail?
+             (λ ()
+               (tensor 1 (tensor 2 3)))) ;; mixing a number with a nested tensor must raise
+
+  (check-exn exn:fail?
+             (λ ()
+               (tensor tensor (tensor 2 3)))) ;; passing the tensor procedure itself as an element must raise
+)
diff --git a/uniform-tensors/tensors/test/test-D-extend.rkt b/uniform-tensors/tensors/test/test-D-extend.rkt
new file mode 100644
index 0000000..0eb7ec4
--- /dev/null
+++ b/uniform-tensors/tensors/test/test-D-extend.rkt
@@ -0,0 +1,236 @@
+(module+ test
+  (require rackunit)
+  (require "A-equality.rkt")
+  (require "B-tensor-basics.rkt")
+
+  (define sum-f ;; six-argument kernel: writes the sum of sᵢ input elements starting at iᵢ into out-v at iₒ
+    (λ (in-v iᵢ sᵢ out-v iₒ sₒ)
+      (vset! out-v iₒ
+        (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))])
+          (+ sum (vref in-v i))))))
+
+  (define sum-shape-f ;; output shape of sum-f is always the empty shape
+    (λ (in-f-shape)
+      '()))
+
+  (define sum (ext1-ρ sum-f 1 sum-shape-f)) ;; sum-f extended with 1 as the middle argument (presumably its minimum rank)
+
+  (check-equal? (min-shape 2 '(3 4 5 6)) '(5 6)) ;; keeps the last two dimensions
+
+  (check-equal? (min-shape 0 '(3 4 5 6)) '())
+
+  (check-equal? (merge-shapes '(3 4 5 6) 1 '())
+                '(3 4 5))
+
+  (define t0 ;; shape (2 3 4), element i is 2*i
+    (flat '(2 3 4)
+          (build-vec 24
+            (λ (i)
+              (* 2 i)))
+          0))
+
+  (check-true (equal-elements? (sum t0)
+                               (tensor 12.0 44.0 76.0 108.0 140.0 172.0))) ;; one sum per run of 4 elements
+
+
+  (define dup-f ;; six-argument kernel: fills sₒ output slots by cycling through the sᵢ input elements
+    (λ (in-v iᵢ sᵢ out-v iₒ sₒ)
+      (for ([i (in-range 0 sₒ)])
+        (vset! out-v (+ iₒ i)
+          (vref in-v (+ iᵢ (modulo i sᵢ)))))))
+
+  (define dup-shape-f ;; output is twice as long as the input's first dimension
+    (λ (in-f-shape)
+      (list (* 2 (car in-f-shape)))))
+
+  (define dup (ext1-ρ dup-f 1 dup-shape-f))
+  (check-true (equal-elements? (dup t0) ;; every run of 4 elements appears twice in a row
+                               (tensor 0 2 4 6 0 2 4 6
+                                       8 10 12 14 8 10 12 14
+                                       16 18 20 22 16 18 20 22
+                                       24 26 28 30 24 26 28 30
+                                       32 34 36 38 32 34 36 38
+                                       40 42 44 46 40 42 44 46)))
+
+  (define s0 '(3 4 5 6))
+  (define s1 '(3 7 6))
+  (define r0 2)
+  (define r1 1)
+
+  (ext2-shapes s0 s1 r0 r1 '(5 6) ;; the checks run inside the continuation that ext2-shapes calls
+    (λ (s-out size-out q0 q1 strides)
+      (check-equal? s-out '(3 4 7 5 6))
+      (check-equal? size-out 2520) ;; 3 * 4 * 7 * 5 * 6
+      (check-equal? strides '(#(840 120 42) #(210 30 0) #(30 0 6))) ;; one triple per leading output dimension
+      (let-values (((i0 i1) (idxs strides 0 0 0))) ;; idxs maps an output index to a pair of input indices
+        (check-equal? i0 0)
+        (check-equal? i1 0))
+
+      (let-values (((i0 i1) (idxs strides 30 0 0)))
+        (check-equal? i0 0)
+        (check-equal? i1 6))
+
+      (let-values (((i0 i1) (idxs strides 210 0 0)))
+        (check-equal? i0 30)
+        (check-equal?
i1 0))
+
+      (let-values (((i0 i1) (idxs strides 240 0 0))) ;; 240 = 210 + 30, so both inputs advance
+        (check-equal? i0 30)
+        (check-equal? i1 6))
+
+      (let-values (((i0 i1) (idxs strides 420 0 0))) ;; 420 = 2 * 210
+        (check-equal? i0 60)
+        (check-equal? i1 0))
+
+      (let-values (((i0 i1) (idxs strides 840 0 0)))
+        (check-equal? i0 120)
+        (check-equal? i1 42))
+      ))
+
+
+  (define *-ρ (ext2-ρ * 0 0)) ;; plain three-argument-free * passed directly, with 0 and 0 as the trailing arguments
+  (define t0sqr (*-ρ t0 t0))
+
+  (check-true (equal-elements?
+               t0sqr
+               (tensor 0 4 16 36 ;; squares of t0's elements
+                       64 100 144 196
+                       256 324 400 484
+                       576 676 784 900
+                       1024 1156 1296 1444
+                       1600 1764 1936 2116)))
+
+  (define *-2-1-f ;; nine-argument kernel: multiplies s0 elements of v0 by v1's s1 elements, cycling through v1
+    (λ (v0 i0 s0 v1 i1 s1 vout iout sout)
+      (for ([j0 (in-range 0 s0)])
+        (vset! vout (+ iout j0)
+          (* (vref v0 (+ i0 j0))
+             (vref v1 (+ i1 (modulo j0 s1))))))))
+
+  (define t1 ;; shape (5 6), element i is 2.0*i
+    (flat '(5 6)
+          (build-vec 30
+            (λ (i) (* 2.0 i)))
+          0))
+
+  (define t2 ;; shape (6), element i is 3.0*i
+    (flat '(6)
+          (build-vec 6
+            (λ (i) (* 3.0 i)))
+          0))
+
+  (define *-2-1
+    (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) ;; the shape function returns the first operand's shape
+
+  (define r-1-2
+    (*-2-1 t1 t2))
+
+  (check-tensor-equal? r-1-2
+                       (reshape
+                        '(5 6)
+                        (tensor 0 6.0 24.0 54.0 96.0 150.0
+                                0 42.0 96.0 162.0 240.0 330.0
+                                0 78.0 168.0 270.0 384.0 510.0
+                                0 114.0 240.0 378.0 528.0 690.0
+                                0 150.0 312.0 486.0 672.0 870.0)))
+
+  (define t3 ;; shape (3 5 6), element i is 2.0*i
+    (flat '(3 5 6)
+          (build-vec 90
+            (λ (i) (* 2.0 i)))
+          0))
+
+  (define t4 ;; shape (3 6), element i is 3.0*i
+    (flat '(3 6)
+          (build-vec 18
+            (λ (i) (* 3.0 i)))
+          0))
+
+  (define r-3-4
+    (*-2-1 t3 t4))
+
+  (check-tensor-equal?
r-3-4
+                       (reshape
+                        '(3 5 6)
+                        (tensor
+                         0 6.0 24.0 54.0 96.0 150.0
+                         0 42.0 96.0 162.0 240.0 330.0
+                         0 78.0 168.0 270.0 384.0 510.0
+                         0 114.0 240.0 378.0 528.0 690.0
+                         0 150.0 312.0 486.0 672.0 870.0
+
+                         1080.0 1302.0 1536.0 1782.0 2040.0 2310.0
+                         1296.0 1554.0 1824.0 2106.0 2400.0 2706.0
+                         1512.0 1806.0 2112.0 2430.0 2760.0 3102.0
+                         1728.0 2058.0 2400.0 2754.0 3120.0 3498.0
+                         1944.0 2310.0 2688.0 3078.0 3480.0 3894.0
+
+                         4320.0 4758.0 5208.0 5670.0 6144.0 6630.0
+                         4752.0 5226.0 5712.0 6210.0 6720.0 7242.0
+                         5184.0 5694.0 6216.0 6750.0 7296.0 7854.0
+                         5616.0 6162.0 6720.0 7290.0 7872.0 8466.0
+                         6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))))
+
+(module+ test
+  (require rackunit)
+
+  (define r0-td 3.0)
+  (define r1-td (flat '(3) (list->vec '(3.0 4.0 5.0)) 0))
+  (define r2-td (flat '(2 3) (list->vec '(3.0 4.0 5.0 7.0 8.0 9.0)) 0))
+
+  (define +ᶠ +)
+  (define +ᵈ (λ (a b z) (values z z))) ;; returns z for both operands
+
+  (define sqrᶠ (λ (a) (* a a)))
+  (define sqrᵈ
+    (λ (a z) (* z 2 a))) ;; z * 2a
+
+  (define d-sqr (ext1-∇ sqrᵈ 0 scalar-shape))
+
+  (define one-like ;; builds a flat with t's shape whose elements are all 1.0
+    (λ (t)
+      (let* ((st (flat-shape t))
+             (size-t (size-of st)))
+        (flat st
+              (new-vec size-t 1.0)
+              0))))
+
+  (check-true (equal-elements? (d-sqr r1-td (one-like r1-td)) (tensor 6.0 8.0 10.0))) ;; 2 * each element of r1-td
+
+  (let ((gsqr (d-sqr r2-td (one-like r2-td))))
+    (check-tensor-equal? gsqr (reshape '(2 3) (tensor 6.0 8.0 10.0 14.0 16.0 18.0))))
+
+  (define d+ (ext2-∇ +ᵈ 0 0 scalar-shape))
+
+  (let-values (((da db) (d+ r1-td r1-td (one-like r1-td))))
+    (check-tensor-equal? da (tensor 1.0 1.0 1.0))
+    (check-tensor-equal? db (tensor 1.0 1.0 1.0)))
+
+  (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) ;; operands of shape (3) and (2 3)
+    (check-tensor-equal? da (tensor 2.0 2.0 2.0)) ;; the shape-(3) operand's result has 2.0 in every slot
+    (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0))))
+
+  (define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) ;; called here without a shape-function argument
+                     0
+                     0))
+
+  (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0))))
+    (check-tensor-equal? gt (tensor 1.0 2.0 3.0)) ;; equals the second operand
+    (check-tensor-equal?
gu (tensor 2.0 3.0 4.0))) ;; equals the first operand
+
+  (define sum-1-∇ ;; seven-argument kernel: writes the single value vz[iz] into st slots of g starting at it
+    (λ (g t it st vz iz sz)
+      (for* ([i (in-range it (+ it st))])
+        (vset! g i (vref vz iz)))))
+
+  (define sum-∇ (ext1-∇ sum-1-∇ 1 (λ (s) '()))) ;; the shape function always returns the empty shape
+
+  (let ((gt (sum-∇ (tensor 2.0 3.0 4.0)
+                   1.0))) ;; second argument is a plain number
+    (check-tensor-equal? gt (tensor 1.0 1.0 1.0)))
+
+  (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0)
+                           (tensor 2.0 3.0 4.0))
+                   (tensor 2.0 1.0)))) ;; one value per row of the first argument
+    (check-tensor-equal? gt (tensor (tensor 2.0 2.0 2.0) ;; each row is filled with its own value
+                                    (tensor 1.0 1.0 1.0)))))