Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
antemons committed May 27, 2024
1 parent 0710b06 commit 4d7fec7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
32 changes: 32 additions & 0 deletions geometricalgebra/cga3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,35 @@ def get_origin_pose(weight: float = 1.0):
return Vector.concatenate(
[e_0[None], weight * TETRAHEDRON / 2, weight * Vector.from_scalar(np.sqrt(3), pseudo=True)[None]]
)


def classify(vector: Vector, eps=1e-8):
"""Classify the vector
Args:
vector: the vector to classify
Returns:
the classification of the vector
"""
if len(vector.grades) == 1:
if vector.grade == 0:
return "scalar"
if vector.grade == 1:
pass

if vector.grade == 5:
return "pseudo-scalar"


if vector.grades == {1}:
if vector.square_norm() < eps:
return "flat point"
else:
raise NotImplementedError()

if vector.is_flat():
pass

raise NotImplementedError("multi-grade vector")

7 changes: 6 additions & 1 deletion geometricalgebra/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Collection of utilities"""

from geometricalgebra.vector import FRAMEWORK

from typing import Tuple

def solve_gauss_normal_equations(a, b):
xnp = FRAMEWORK.numpy
at_a_inv = xnp.linalg.pinv(xnp.einsum("...jk,...ji->...ki", a, a))
return xnp.einsum("...ij,...kj,...k->...i", at_a_inv, a, b)


def reshape_last_dimensions(array, num_of_dim: int, new_shape: Tuple[int, ...]):
"""Reshape a tensor with shape (..., 4, 3) to shape (..., 12)"""
return array.reshape([*array.shape[:-num_of_dim], *new_shape])
5 changes: 1 addition & 4 deletions geometricalgebra/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Framework(NamedTuple):
softmax: Any


ga_numpy = os.environ.get("GEOMETRICALGEBRA_NUMPY", "numpy")
ga_numpy = os.environ.get("GEOMETRICALGEBRA_NUMPY", "tensorflow")
if ga_numpy == "jax":
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -117,9 +117,6 @@ def algebra(cls):
def _algebra(self):
return self.algebra

# def framework(self):
# return jnp

def numpy(self):
return self

Expand Down

0 comments on commit 4d7fec7

Please sign in to comment.