Skip to content

Commit

Permalink
Merge pull request #49 from daducci/fix_segfault
Browse files Browse the repository at this point in the history
Fix segfault when using new regularization scheme
  • Loading branch information
daducci authored Jun 1, 2018
2 parents b058584 + 1b69c3f commit c7bdc81
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
35 changes: 16 additions & 19 deletions commit/proximals.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!python
#cython: boundscheck=False, wraparound=False
"""
Author: Matteo Frigo - lts5 @ EPFL and Dep. of CS @ Univ. of Verona
Expand All @@ -10,9 +12,6 @@ cimport numpy as np
from math import sqrt
import sys

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.profile(False)

cpdef non_negativity(np.ndarray[np.float64_t] x, int compartment_start, int compartment_size):
"""
Expand All @@ -22,11 +21,12 @@ cpdef non_negativity(np.ndarray[np.float64_t] x, int compartment_start, int comp
np.ndarray[np.float64_t] v
size_t i
v = x.copy()
for i in range(compartment_start, compartment_size):
for i in range(compartment_start, compartment_start+compartment_size):
if v[i] < 0.0:
v[i] = 0.0
return v


cpdef soft_thresholding(np.ndarray[np.float64_t] x, double lam, int compartment_start, int compartment_size) :
"""
Proximal of L1 norm
Expand All @@ -36,13 +36,14 @@ cpdef soft_thresholding(np.ndarray[np.float64_t] x, double lam, int compartment_
np.ndarray[np.float64_t] v
size_t i
v = x.copy()
for i in range(compartment_start, compartment_size):
for i in range(compartment_start, compartment_start+compartment_size):
if v[i] <= lam:
v[i] = 0.0
else:
v[i] -= lam
return v


cpdef projection_onto_l2_ball(np.ndarray[np.float64_t] x, double lam, int compartment_start, int compartment_size) :
"""
Proximal of L2 norm
Expand All @@ -53,46 +54,45 @@ cpdef projection_onto_l2_ball(np.ndarray[np.float64_t] x, double lam, int compar
np.ndarray[np.float64_t] v
size_t i
v = x.copy()
xn = sqrt(sum(v[compartment_start:compartment_size]**2))
xn = sqrt(sum(v[compartment_start:compartment_start+compartment_size]**2))
if xn > lam:
for i in range(compartment_start, compartment_size):
for i in range(compartment_start, compartment_start+compartment_size):
v[i] = v[i]/xn*lam
return v


cpdef omega_group_sparsity(np.ndarray[np.float64_t] v, np.ndarray[object] subtree, np.ndarray[np.float64_t] weight, double lam, double n) :
"""
References:
[1] Jenatton et al. - `Proximal Methods for Hierarchical Sparse Coding`
"""
cdef:
int nG = weight.size
size_t k, i
double xn, tmp = 0.0
size_t k
double tmp = 0.0

if lam != 0:
if n == 2:
for k in range(nG):
idx = subtree[k]
xn = 0.0
for i in idx:
xn += v[i]*v[i]
tmp += weight[k] * sqrt( xn )
tmp += weight[k] * sqrt( sum(v[idx]**2) )
elif n == np.Inf:
for k in range(nG):
idx = subtree[k]
tmp += weight[k] * max( v[idx] )
return lam*tmp


cpdef prox_group_sparsity( np.ndarray[np.float64_t] x, np.ndarray[object] subtree, np.ndarray[np.float64_t] weight, double lam, double n ) :
"""
References:
[1] Jenatton et al. - `Proximal Methods for Hierarchical Sparse Coding`
"""
cdef:
np.ndarray[np.float64_t] v
int nG = weight.size, N, rho
int nG = weight.size
size_t k, i
double r, xn, theta
double r, xn

v = x.copy()
v[v<0] = 0.0
Expand All @@ -111,10 +111,7 @@ cpdef prox_group_sparsity( np.ndarray[np.float64_t] x, np.ndarray[object] subtre
if n == 2:
for k in range(nG):
idx = subtree[k]
xn = 0.0
for i in idx:
xn += v[i]*v[i]
xn = sqrt(xn)
xn = sqrt( sum(v[idx]**2) )
r = weight[k] * lam
if xn > r:
r = (xn-r)/xn
Expand Down
6 changes: 3 additions & 3 deletions commit/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def regularisation2omegaprox(regularisation):
sizeIC = regularisation.get('sizeIC')
if lambdaIC == 0.0:
omegaIC = lambda x: 0.0
proxIC = lambda x: non_negativity(x, startIC, sizeIC)
proxIC = lambda x: x
elif normIC == norm2:
omegaIC = lambda x: lambdaIC * np.linalg.norm(x[startIC:sizeIC])
proxIC = lambda x: projection_onto_l2_ball(x, lambdaIC, startIC, sizeIC)
Expand Down Expand Up @@ -205,7 +205,7 @@ def regularisation2omegaprox(regularisation):
sizeEC = regularisation.get('sizeEC')
if lambdaEC == 0.0:
omegaEC = lambda x: 0.0
proxEC = lambda x: non_negativity(x, startEC, sizeEC)
proxEC = lambda x: x
elif normEC == norm2:
omegaEC = lambda x: lambdaEC * np.linalg.norm(x[startEC:sizeEC])
proxEC = lambda x: projection_onto_l2_ball(x, lambdaEC, startEC, sizeEC)
Expand All @@ -223,7 +223,7 @@ def regularisation2omegaprox(regularisation):
sizeISO = regularisation.get('sizeISO')
if lambdaISO == 0.0:
omegaISO = lambda x: 0.0
proxISO = lambda x: non_negativity(x, startISO, sizeISO)
proxISO = lambda x: x
elif normISO == norm2:
omegaISO = lambda x: lambdaISO * np.linalg.norm(x[startISO:sizeISO])
proxISO = lambda x: projection_onto_l2_ball(x, lambdaISO, startISO, sizeISO)
Expand Down

0 comments on commit c7bdc81

Please sign in to comment.