Skip to content

Commit 67588be

Browse files
authored
Merge pull request #177 from simpeg/feat/tree_functions
Feat/tree functions
2 parents 9ffbca3 + 8608d36 commit 67588be

11 files changed

+909
-167
lines changed

discretize/TensorMesh.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,18 @@ def _repr_html_(self):
9191
fmt = "<table>\n"
9292
fmt += " <tr>\n"
9393
fmt += " <td style='font-weight: bold; font-size: 1.2em; text-align"
94-
fmt += ": center;' colspan='3'>{}</td\n>".format(type(self).__name__)
94+
fmt += ": center;' colspan='3'>{}</td>\n".format(type(self).__name__)
9595
fmt += " <td style='font-size: 1.2em; text-align: center;'"
9696
fmt += "colspan='4'>{:,} cells</td>\n".format(self.nC)
9797
fmt += " </tr>\n"
9898

9999
fmt += " <tr>\n"
100-
fmt += " <th></th\n>"
101-
fmt += " <th></th\n>"
102-
fmt += " <th colspan='2'"+style+">MESH EXTENT</th\n>"
103-
fmt += " <th colspan='2'"+style+">CELL WIDTH</th\n>"
104-
fmt += " <th"+style+">FACTOR</th\n>"
105-
fmt += " </tr\n>"
100+
fmt += " <th></th>\n"
101+
fmt += " <th></th>\n"
102+
fmt += " <th colspan='2'"+style+">MESH EXTENT</th>\n"
103+
fmt += " <th colspan='2'"+style+">CELL WIDTH</th>\n"
104+
fmt += " <th"+style+">FACTOR</th>\n"
105+
fmt += " </tr>\n"
106106

107107
fmt += " <tr>\n"
108108
fmt += " <th"+style+">dir</th>\n"

discretize/TreeMesh.py

+205-48
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
from .InnerProducts import InnerProducts
8989
from .MeshIO import TreeMeshIO
9090
from . import utils
91-
from .tree_ext import _TreeMesh
91+
from .tree_ext import _TreeMesh, TreeCell
9292
import numpy as np
9393
from scipy.spatial import Delaunay
9494
import scipy.sparse as sp
@@ -113,55 +113,132 @@ def is_pow2(num): return ((num & (num - 1)) == 0) and num != 0
113113
# Now can initialize cpp tree parent
114114
_TreeMesh.__init__(self, self.h, self.x0)
115115

116-
def __str__(self):
117-
outStr = ' ---- {0!s}TreeMesh ---- '.format(
118-
('Oc' if self.dim == 3 else 'Quad')
119-
)
120-
121-
def printH(hx, outStr=''):
122-
i = -1
123-
while True:
124-
i = i + 1
125-
if i > hx.size:
126-
break
127-
elif i == hx.size:
128-
break
129-
h = hx[i]
130-
n = 1
131-
for j in range(i+1, hx.size):
132-
if hx[j] == h:
133-
n = n + 1
134-
i = i + 1
135-
else:
136-
break
137-
if n == 1:
138-
outStr += ' {0:.2f}, '.format(h)
139-
else:
140-
outStr += ' {0:d}*{1:.2f}, '.format(n, h)
141-
return outStr[:-1]
142-
143-
if self.dim == 2:
144-
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
145-
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
146-
outStr += printH(self.hx, outStr='\n hx:')
147-
outStr += printH(self.hy, outStr='\n hy:')
148-
elif self.dim == 3:
149-
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
150-
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
151-
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
152-
outStr += printH(self.hx, outStr='\n hx:')
153-
outStr += printH(self.hy, outStr='\n hy:')
154-
outStr += printH(self.hz, outStr='\n hz:')
155-
outStr += '\n nC: {0:d}'.format(self.nC)
156-
outStr += '\n Fill: {0:2.2f}%'.format((self.fill*100))
157-
return outStr
116+
def __repr__(self):
117+
"""Plain text representation."""
118+
mesh_name = '{0!s}TreeMesh'.format(('Oc' if self.dim==3 else 'Quad'))
119+
120+
top = "\n"+mesh_name+": {0:2.2f}% filled\n\n".format(self.fill*100)
121+
122+
# Number of cells per level
123+
level_count = self._count_cells_per_index()
124+
non_zero_levels = np.nonzero(level_count)[0]
125+
cell_display = ["Level : Number of cells"]
126+
cell_display.append("-----------------------")
127+
for level in non_zero_levels:
128+
cell_display.append("{:^5} : {:^15}".format(level, level_count[level]))
129+
cell_display.append("-----------------------")
130+
cell_display.append("Total : {:^15}".format(self.nC))
131+
132+
extent_display = [" Mesh Extent "]
133+
extent_display.append(" min , max ")
134+
extent_display.append(" ---------------------------")
135+
dim_label = {0:'x',1:'y',2:'z'}
136+
for dim in range(self.dim):
137+
n_vector = getattr(self, 'vectorN'+dim_label[dim])
138+
extent_display.append("{}: {:^13},{:^13}".format(dim_label[dim], n_vector[0], n_vector[-1]))
139+
140+
for i, line in enumerate(extent_display):
141+
if i==len(cell_display):
142+
cell_display.append(" "*(len(cell_display[0])-3-len(line)))
143+
cell_display[i] += 3*" " + line
144+
145+
h_display = [' Cell Widths ']
146+
h_display.append(" min , max ")
147+
h_display.append("-"*(len(h_display[0])))
148+
h_gridded = self.h_gridded
149+
mins = np.min(h_gridded,axis=0)
150+
maxs = np.max(h_gridded,axis=0)
151+
for dim in range(self.dim):
152+
h_display.append("{:^10}, {:^10}".format(mins[dim], maxs[dim]))
153+
154+
for i, line in enumerate(h_display):
155+
if i==len(cell_display):
156+
cell_display.append(" "*len(cell_display[0]))
157+
cell_display[i] += 3*" " + line
158+
159+
return top+"\n".join(cell_display)
160+
161+
def _repr_html_(self):
162+
"""html representation"""
163+
mesh_name = '{0!s}TreeMesh'.format(('Oc' if self.dim==3 else 'Quad'))
164+
level_count = self._count_cells_per_index()
165+
non_zero_levels = np.nonzero(level_count)[0]
166+
dim_label = {0:'x',1:'y',2:'z'}
167+
h_gridded = self.h_gridded
168+
mins = np.min(h_gridded,axis=0)
169+
maxs = np.max(h_gridded,axis=0)
170+
171+
style = " style='padding: 5px 20px 5px 20px;'"
172+
#Cell level table:
173+
cel_tbl = "<table>\n"
174+
cel_tbl += "<tr>\n"
175+
cel_tbl += "<th"+style+">Level</th>\n"
176+
cel_tbl += "<th"+style+">Number of cells</th>\n"
177+
cel_tbl += "</tr>\n"
178+
for level in non_zero_levels:
179+
cel_tbl += "<tr>\n"
180+
cel_tbl += "<td"+style+">{}</td>\n".format(level)
181+
cel_tbl += "<td"+style+">{}</td>\n".format(level_count[level])
182+
cel_tbl += "</tr>\n"
183+
cel_tbl += "<tr>\n"
184+
cel_tbl += "<td style='font-weight: bold; padding: 5px 20px 5px 20px;'> Total </td>\n"
185+
cel_tbl += "<td"+style+"> {} </td>\n".format(self.nC)
186+
cel_tbl += "</tr>\n"
187+
cel_tbl += "</table>\n"
188+
189+
det_tbl = "<table>\n"
190+
det_tbl += "<tr>\n"
191+
det_tbl += "<th></th>\n"
192+
det_tbl += "<th"+style+" colspan='2'>Mesh extent</th>\n"
193+
det_tbl += "<th"+style+" colspan='2'>Cell widths</th>\n"
194+
det_tbl += "</tr>\n"
195+
196+
det_tbl += "<tr>\n"
197+
det_tbl += "<th></th>\n"
198+
det_tbl += "<th"+style+">min</th>\n"
199+
det_tbl += "<th"+style+">max</th>\n"
200+
det_tbl += "<th"+style+">min</th>\n"
201+
det_tbl += "<th"+style+">max</th>\n"
202+
det_tbl += "</tr>\n"
203+
for dim in range(self.dim):
204+
n_vector = getattr(self, 'vectorN'+dim_label[dim])
205+
det_tbl += "<tr>\n"
206+
det_tbl += "<td"+style+">{}</td>\n".format(dim_label[dim])
207+
det_tbl += "<td"+style+">{}</td>\n".format(n_vector[0])
208+
det_tbl += "<td"+style+">{}</td>\n".format(n_vector[-1])
209+
det_tbl += "<td"+style+">{}</td>\n".format(mins[dim])
210+
det_tbl += "<td"+style+">{}</td>\n".format(maxs[dim])
211+
det_tbl += "</tr>\n"
212+
det_tbl += "</table>\n"
213+
214+
full_tbl = "<table>\n"
215+
full_tbl += "<tr>\n"
216+
full_tbl += "<td style='font-weight: bold; font-size: 1.2em; text-align: center;'>{}</td>\n".format(mesh_name)
217+
full_tbl += "<td style='font-size: 1.2em; text-align: center;' colspan='2'>{0:2.2f}% filled</td>\n".format(100*self.fill)
218+
full_tbl += "</tr>\n"
219+
full_tbl += "<tr>\n"
220+
221+
full_tbl += "<td>\n"
222+
full_tbl += cel_tbl
223+
full_tbl += "</td>\n"
224+
225+
full_tbl += "<td>\n"
226+
full_tbl += det_tbl
227+
full_tbl += "</td>\n"
228+
229+
full_tbl += "</tr>\n"
230+
full_tbl += "</table>\n"
231+
232+
return full_tbl
158233

159234
@property
160235
def vntF(self):
236+
"""Total number of hanging and non-hanging faces in a [nx,ny,nz] form"""
161237
return [self.ntFx, self.ntFy] + ([] if self.dim == 2 else [self.ntFz])
162238

163239
@property
164240
def vntE(self):
241+
"""Total number of hanging and non-hanging edges in a [nx,ny,nz] form"""
165242
return [self.ntEx, self.ntEy] + ([] if self.dim == 2 else [self.ntEz])
166243

167244
@property
@@ -242,7 +319,7 @@ def cellGradx(self):
242319
@property
243320
def cellGrady(self):
244321
"""
245-
Cell centered Gradient operator in y-direction (Gradx)
322+
Cell centered Gradient operator in y-direction (Grady)
246323
Grad = sp.vstack((Gradx, Grady, Gradz))
247324
"""
248325
if getattr(self, '_cellGrady', None) is None:
@@ -271,6 +348,8 @@ def cellGradz(self):
271348
Cell centered Gradient operator in z-direction (Gradz)
272349
Grad = sp.vstack((Gradx, Grady, Gradz))
273350
"""
351+
if self.dim == 2:
352+
raise TypeError("z derivative not defined in 2D")
274353
if getattr(self, '_cellGradz', None) is None:
275354

276355
nFx = self.nFx
@@ -310,21 +389,98 @@ def faceDivz(self):
310389
return self._faceDivz
311390

312391
def point2index(self, locs):
392+
"""Finds cells that contain the given points.
393+
Returns an array of index values of the cells that contain the given
394+
points
395+
396+
Parameters
397+
----------
398+
locs: array_like of shape (N, dim)
399+
points to search for the location of
400+
401+
Returns
402+
-------
403+
numpy.array of integers of length(N)
404+
Cell indices that contain the points
405+
"""
313406
locs = utils.asArray_N_x_Dim(locs, self.dim)
314-
315-
inds = np.empty(locs.shape[0], dtype=np.int64)
316-
for ind, loc in enumerate(locs):
317-
inds[ind] = self._get_containing_cell_index(loc)
407+
inds = self._get_containing_cell_indexes(locs)
318408
return inds
319409

410+
def cell_levels_by_index(self, indices):
411+
"""Fast function to return a list of levels for the given cell indices
412+
413+
Parameters
414+
----------
415+
index: array_like of length (N)
416+
Cell indexes to query
417+
418+
Returns
419+
-------
420+
numpy.array of length (N)
421+
Levels for the cells.
422+
"""
423+
424+
return self._cell_levels_by_indexes(indices)
425+
426+
427+
def getInterpolationMat(self, locs, locType, zerosOutside=False):
428+
""" Produces interpolation matrix
429+
430+
Parameters
431+
----------
432+
loc : numpy.ndarray
433+
Location of points to interpolate to
434+
435+
locType: str
436+
What to interpolate
437+
438+
locType can be::
439+
440+
'Ex' -> x-component of field defined on edges
441+
'Ey' -> y-component of field defined on edges
442+
'Ez' -> z-component of field defined on edges
443+
'Fx' -> x-component of field defined on faces
444+
'Fy' -> y-component of field defined on faces
445+
'Fz' -> z-component of field defined on faces
446+
'N' -> scalar field defined on nodes
447+
'CC' -> scalar field defined on cell centers
448+
449+
Returns
450+
-------
451+
scipy.sparse.csr_matrix
452+
M, the interpolation matrix
453+
454+
"""
455+
locs = utils.asArray_N_x_Dim(locs, self.dim)
456+
if locType not in ['N', 'CC', "Ex", "Ey", "Ez", "Fx", "Fy", "Fz"]:
457+
raise Exception('locType must be one of N, CC, Ex, Ey, Ez, Fx, Fy, or Fz')
458+
459+
if self.dim == 2 and locType in ['Ez', 'Fz']:
460+
raise Exception('Unable to interpolate from Z edges/face in 2D')
461+
462+
locs = np.require(np.atleast_2d(locs), dtype=np.float64, requirements='C')
463+
464+
if locType == 'N':
465+
Av = self._getNodeIntMat(locs, zerosOutside)
466+
elif locType in ['Ex', 'Ey', 'Ez']:
467+
Av = self._getEdgeIntMat(locs, zerosOutside, locType[1])
468+
elif locType in ['Fx', 'Fy', 'Fz']:
469+
Av = self._getFaceIntMat(locs, zerosOutside, locType[1])
470+
elif locType in ['CC']:
471+
Av = self._getCellIntMat(locs, zerosOutside)
472+
return Av
473+
320474
@property
321475
def permuteCC(self):
476+
"""Permutation matrix re-ordering of cells sorted by x, then y, then z"""
322477
# TODO: cache these?
323478
P = np.lexsort(self.gridCC.T) # sort by x, then y, then z
324479
return sp.identity(self.nC).tocsr()[P]
325480

326481
@property
327482
def permuteF(self):
483+
"""Permutation matrix re-ordering of faces sorted by x, then y, then z"""
328484
# TODO: cache these?
329485
Px = np.lexsort(self.gridFx.T)
330486
Py = np.lexsort(self.gridFy.T)+self.nFx
@@ -337,6 +493,7 @@ def permuteF(self):
337493

338494
@property
339495
def permuteE(self):
496+
"""Permutation matrix re-ordering of edges sorted by x, then y, then z"""
340497
# TODO: cache these?
341498
Px = np.lexsort(self.gridEx.T)
342499
Py = np.lexsort(self.gridEy.T) + self.nEx

discretize/View.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,9 @@ def plotGrid(
613613
else:
614614
if not isinstance(ax, matplotlib.axes.Axes):
615615
raise AssertionError("ax must be an matplotlib.axes.Axes")
616+
if lines:
617+
color = kwargs.get('color', 'C0')
618+
linewidth = kwargs.get('linewidth', 1.)
616619

617620
if self.dim == 1:
618621
if nodes:
@@ -660,8 +663,7 @@ def plotGrid(
660663
marker="^", linestyle=""
661664
)
662665

663-
color = kwargs.get('color', 'C0')
664-
linewidth = kwargs.get('linewidth', 1.)
666+
665667
# Plot the grid lines
666668
if lines:
667669
NN = self.r(self.gridN, 'N', 'N', 'M')
@@ -728,7 +730,7 @@ def plotGrid(
728730
X = np.r_[X1, X2, X3]
729731
Y = np.r_[Y1, Y2, Y3]
730732
Z = np.r_[Z1, Z2, Z3]
731-
ax.plot(X, Y, color="C0", linestyle="-", zs=Z)
733+
ax.plot(X, Y, color=color, linestyle="-", lw=linewidth, zs=Z)
732734
ax.set_xlabel('x1')
733735
ax.set_ylabel('x2')
734736
ax.set_zlabel('x3')

discretize/tree.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ Cell::Cell(Node *pts[8], int_t ndim, int_t maxlevel, function func){
185185
int_t n_points = 1<<n_dim;
186186
for(int_t i = 0; i < n_points; ++i)
187187
points[i] = pts[i];
188+
index = -1;
188189
level = 0;
189190
max_level = maxlevel;
190191
parent = NULL;
@@ -212,6 +213,7 @@ Cell::Cell(Node *pts[8], Cell *parent){
212213
int_t n_points = 1<<n_dim;
213214
for(int_t i = 0; i < n_points; ++i)
214215
points[i] = pts[i];
216+
index = -1;
215217
level = parent->level + 1;
216218
max_level = parent->max_level;
217219
test_func = parent->test_func;

discretize/tree.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ class Cell{
106106
Edge *edges[12];
107107
Face *faces[6];
108108

109-
int_t location_ind[3], index, key, level, max_level;
109+
int_t location_ind[3], key, level, max_level;
110+
long long int index; // non root parents will have a -1 value
110111
double location[3];
111112
double volume;
112113
function test_func;

0 commit comments

Comments
 (0)