Skip to content

Commit

Permalink
Merge pull request #177 from simpeg/feat/tree_functions
Browse files Browse the repository at this point in the history
Feat/tree functions
  • Loading branch information
lheagy authored Jul 8, 2019
2 parents 9ffbca3 + 8608d36 commit 67588be
Show file tree
Hide file tree
Showing 11 changed files with 909 additions and 167 deletions.
14 changes: 7 additions & 7 deletions discretize/TensorMesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,18 @@ def _repr_html_(self):
fmt = "<table>\n"
fmt += " <tr>\n"
fmt += " <td style='font-weight: bold; font-size: 1.2em; text-align"
fmt += ": center;' colspan='3'>{}</td\n>".format(type(self).__name__)
fmt += ": center;' colspan='3'>{}</td>\n".format(type(self).__name__)
fmt += " <td style='font-size: 1.2em; text-align: center;'"
fmt += "colspan='4'>{:,} cells</td>\n".format(self.nC)
fmt += " </tr>\n"

fmt += " <tr>\n"
fmt += " <th></th\n>"
fmt += " <th></th\n>"
fmt += " <th colspan='2'"+style+">MESH EXTENT</th\n>"
fmt += " <th colspan='2'"+style+">CELL WIDTH</th\n>"
fmt += " <th"+style+">FACTOR</th\n>"
fmt += " </tr\n>"
fmt += " <th></th>\n"
fmt += " <th></th>\n"
fmt += " <th colspan='2'"+style+">MESH EXTENT</th>\n"
fmt += " <th colspan='2'"+style+">CELL WIDTH</th>\n"
fmt += " <th"+style+">FACTOR</th>\n"
fmt += " </tr>\n"

fmt += " <tr>\n"
fmt += " <th"+style+">dir</th>\n"
Expand Down
253 changes: 205 additions & 48 deletions discretize/TreeMesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from .InnerProducts import InnerProducts
from .MeshIO import TreeMeshIO
from . import utils
from .tree_ext import _TreeMesh
from .tree_ext import _TreeMesh, TreeCell
import numpy as np
from scipy.spatial import Delaunay
import scipy.sparse as sp
Expand All @@ -113,55 +113,132 @@ def is_pow2(num): return ((num & (num - 1)) == 0) and num != 0
# Now can initialize cpp tree parent
_TreeMesh.__init__(self, self.h, self.x0)

def __str__(self):
outStr = ' ---- {0!s}TreeMesh ---- '.format(
('Oc' if self.dim == 3 else 'Quad')
)

def printH(hx, outStr=''):
i = -1
while True:
i = i + 1
if i > hx.size:
break
elif i == hx.size:
break
h = hx[i]
n = 1
for j in range(i+1, hx.size):
if hx[j] == h:
n = n + 1
i = i + 1
else:
break
if n == 1:
outStr += ' {0:.2f}, '.format(h)
else:
outStr += ' {0:d}*{1:.2f}, '.format(n, h)
return outStr[:-1]

if self.dim == 2:
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
elif self.dim == 3:
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
outStr += printH(self.hz, outStr='\n hz:')
outStr += '\n nC: {0:d}'.format(self.nC)
outStr += '\n Fill: {0:2.2f}%'.format((self.fill*100))
return outStr
def __repr__(self):
"""Plain text representation."""
mesh_name = '{0!s}TreeMesh'.format(('Oc' if self.dim==3 else 'Quad'))

top = "\n"+mesh_name+": {0:2.2f}% filled\n\n".format(self.fill*100)

# Number of cells per level
level_count = self._count_cells_per_index()
non_zero_levels = np.nonzero(level_count)[0]
cell_display = ["Level : Number of cells"]
cell_display.append("-----------------------")
for level in non_zero_levels:
cell_display.append("{:^5} : {:^15}".format(level, level_count[level]))
cell_display.append("-----------------------")
cell_display.append("Total : {:^15}".format(self.nC))

extent_display = [" Mesh Extent "]
extent_display.append(" min , max ")
extent_display.append(" ---------------------------")
dim_label = {0:'x',1:'y',2:'z'}
for dim in range(self.dim):
n_vector = getattr(self, 'vectorN'+dim_label[dim])
extent_display.append("{}: {:^13},{:^13}".format(dim_label[dim], n_vector[0], n_vector[-1]))

for i, line in enumerate(extent_display):
if i==len(cell_display):
cell_display.append(" "*(len(cell_display[0])-3-len(line)))
cell_display[i] += 3*" " + line

h_display = [' Cell Widths ']
h_display.append(" min , max ")
h_display.append("-"*(len(h_display[0])))
h_gridded = self.h_gridded
mins = np.min(h_gridded,axis=0)
maxs = np.max(h_gridded,axis=0)
for dim in range(self.dim):
h_display.append("{:^10}, {:^10}".format(mins[dim], maxs[dim]))

for i, line in enumerate(h_display):
if i==len(cell_display):
cell_display.append(" "*len(cell_display[0]))
cell_display[i] += 3*" " + line

return top+"\n".join(cell_display)

def _repr_html_(self):
"""html representation"""
mesh_name = '{0!s}TreeMesh'.format(('Oc' if self.dim==3 else 'Quad'))
level_count = self._count_cells_per_index()
non_zero_levels = np.nonzero(level_count)[0]
dim_label = {0:'x',1:'y',2:'z'}
h_gridded = self.h_gridded
mins = np.min(h_gridded,axis=0)
maxs = np.max(h_gridded,axis=0)

style = " style='padding: 5px 20px 5px 20px;'"
#Cell level table:
cel_tbl = "<table>\n"
cel_tbl += "<tr>\n"
cel_tbl += "<th"+style+">Level</th>\n"
cel_tbl += "<th"+style+">Number of cells</th>\n"
cel_tbl += "</tr>\n"
for level in non_zero_levels:
cel_tbl += "<tr>\n"
cel_tbl += "<td"+style+">{}</td>\n".format(level)
cel_tbl += "<td"+style+">{}</td>\n".format(level_count[level])
cel_tbl += "</tr>\n"
cel_tbl += "<tr>\n"
cel_tbl += "<td style='font-weight: bold; padding: 5px 20px 5px 20px;'> Total </td>\n"
cel_tbl += "<td"+style+"> {} </td>\n".format(self.nC)
cel_tbl += "</tr>\n"
cel_tbl += "</table>\n"

det_tbl = "<table>\n"
det_tbl += "<tr>\n"
det_tbl += "<th></th>\n"
det_tbl += "<th"+style+" colspan='2'>Mesh extent</th>\n"
det_tbl += "<th"+style+" colspan='2'>Cell widths</th>\n"
det_tbl += "</tr>\n"

det_tbl += "<tr>\n"
det_tbl += "<th></th>\n"
det_tbl += "<th"+style+">min</th>\n"
det_tbl += "<th"+style+">max</th>\n"
det_tbl += "<th"+style+">min</th>\n"
det_tbl += "<th"+style+">max</th>\n"
det_tbl += "</tr>\n"
for dim in range(self.dim):
n_vector = getattr(self, 'vectorN'+dim_label[dim])
det_tbl += "<tr>\n"
det_tbl += "<td"+style+">{}</td>\n".format(dim_label[dim])
det_tbl += "<td"+style+">{}</td>\n".format(n_vector[0])
det_tbl += "<td"+style+">{}</td>\n".format(n_vector[-1])
det_tbl += "<td"+style+">{}</td>\n".format(mins[dim])
det_tbl += "<td"+style+">{}</td>\n".format(maxs[dim])
det_tbl += "</tr>\n"
det_tbl += "</table>\n"

full_tbl = "<table>\n"
full_tbl += "<tr>\n"
full_tbl += "<td style='font-weight: bold; font-size: 1.2em; text-align: center;'>{}</td>\n".format(mesh_name)
full_tbl += "<td style='font-size: 1.2em; text-align: center;' colspan='2'>{0:2.2f}% filled</td>\n".format(100*self.fill)
full_tbl += "</tr>\n"
full_tbl += "<tr>\n"

full_tbl += "<td>\n"
full_tbl += cel_tbl
full_tbl += "</td>\n"

full_tbl += "<td>\n"
full_tbl += det_tbl
full_tbl += "</td>\n"

full_tbl += "</tr>\n"
full_tbl += "</table>\n"

return full_tbl

@property
def vntF(self):
"""Total number of hanging and non-hanging faces in a [nx,ny,nz] form"""
return [self.ntFx, self.ntFy] + ([] if self.dim == 2 else [self.ntFz])

@property
def vntE(self):
"""Total number of hanging and non-hanging edges in a [nx,ny,nz] form"""
return [self.ntEx, self.ntEy] + ([] if self.dim == 2 else [self.ntEz])

@property
Expand Down Expand Up @@ -242,7 +319,7 @@ def cellGradx(self):
@property
def cellGrady(self):
"""
Cell centered Gradient operator in y-direction (Gradx)
Cell centered Gradient operator in y-direction (Grady)
Grad = sp.vstack((Gradx, Grady, Gradz))
"""
if getattr(self, '_cellGrady', None) is None:
Expand Down Expand Up @@ -271,6 +348,8 @@ def cellGradz(self):
Cell centered Gradient operator in z-direction (Gradz)
Grad = sp.vstack((Gradx, Grady, Gradz))
"""
if self.dim == 2:
raise TypeError("z derivative not defined in 2D")
if getattr(self, '_cellGradz', None) is None:

nFx = self.nFx
Expand Down Expand Up @@ -310,21 +389,98 @@ def faceDivz(self):
return self._faceDivz

def point2index(self, locs):
"""Finds cells that contain the given points.
Returns an array of index values of the cells that contain the given
points
Parameters
----------
locs: array_like of shape (N, dim)
points to search for the location of
Returns
-------
numpy.array of integers of length(N)
Cell indices that contain the points
"""
locs = utils.asArray_N_x_Dim(locs, self.dim)

inds = np.empty(locs.shape[0], dtype=np.int64)
for ind, loc in enumerate(locs):
inds[ind] = self._get_containing_cell_index(loc)
inds = self._get_containing_cell_indexes(locs)
return inds

def cell_levels_by_index(self, indices):
"""Fast function to return a list of levels for the given cell indices
Parameters
----------
index: array_like of length (N)
Cell indexes to query
Returns
-------
numpy.array of length (N)
Levels for the cells.
"""

return self._cell_levels_by_indexes(indices)


def getInterpolationMat(self, locs, locType, zerosOutside=False):
""" Produces interpolation matrix
Parameters
----------
loc : numpy.ndarray
Location of points to interpolate to
locType: str
What to interpolate
locType can be::
'Ex' -> x-component of field defined on edges
'Ey' -> y-component of field defined on edges
'Ez' -> z-component of field defined on edges
'Fx' -> x-component of field defined on faces
'Fy' -> y-component of field defined on faces
'Fz' -> z-component of field defined on faces
'N' -> scalar field defined on nodes
'CC' -> scalar field defined on cell centers
Returns
-------
scipy.sparse.csr_matrix
M, the interpolation matrix
"""
locs = utils.asArray_N_x_Dim(locs, self.dim)
if locType not in ['N', 'CC', "Ex", "Ey", "Ez", "Fx", "Fy", "Fz"]:
raise Exception('locType must be one of N, CC, Ex, Ey, Ez, Fx, Fy, or Fz')

if self.dim == 2 and locType in ['Ez', 'Fz']:
raise Exception('Unable to interpolate from Z edges/face in 2D')

locs = np.require(np.atleast_2d(locs), dtype=np.float64, requirements='C')

if locType == 'N':
Av = self._getNodeIntMat(locs, zerosOutside)
elif locType in ['Ex', 'Ey', 'Ez']:
Av = self._getEdgeIntMat(locs, zerosOutside, locType[1])
elif locType in ['Fx', 'Fy', 'Fz']:
Av = self._getFaceIntMat(locs, zerosOutside, locType[1])
elif locType in ['CC']:
Av = self._getCellIntMat(locs, zerosOutside)
return Av

@property
def permuteCC(self):
"""Permutation matrix re-ordering of cells sorted by x, then y, then z"""
# TODO: cache these?
P = np.lexsort(self.gridCC.T) # sort by x, then y, then z
return sp.identity(self.nC).tocsr()[P]

@property
def permuteF(self):
"""Permutation matrix re-ordering of faces sorted by x, then y, then z"""
# TODO: cache these?
Px = np.lexsort(self.gridFx.T)
Py = np.lexsort(self.gridFy.T)+self.nFx
Expand All @@ -337,6 +493,7 @@ def permuteF(self):

@property
def permuteE(self):
"""Permutation matrix re-ordering of edges sorted by x, then y, then z"""
# TODO: cache these?
Px = np.lexsort(self.gridEx.T)
Py = np.lexsort(self.gridEy.T) + self.nEx
Expand Down
8 changes: 5 additions & 3 deletions discretize/View.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ def plotGrid(
else:
if not isinstance(ax, matplotlib.axes.Axes):
raise AssertionError("ax must be an matplotlib.axes.Axes")
if lines:
color = kwargs.get('color', 'C0')
linewidth = kwargs.get('linewidth', 1.)

if self.dim == 1:
if nodes:
Expand Down Expand Up @@ -660,8 +663,7 @@ def plotGrid(
marker="^", linestyle=""
)

color = kwargs.get('color', 'C0')
linewidth = kwargs.get('linewidth', 1.)

# Plot the grid lines
if lines:
NN = self.r(self.gridN, 'N', 'N', 'M')
Expand Down Expand Up @@ -728,7 +730,7 @@ def plotGrid(
X = np.r_[X1, X2, X3]
Y = np.r_[Y1, Y2, Y3]
Z = np.r_[Z1, Z2, Z3]
ax.plot(X, Y, color="C0", linestyle="-", zs=Z)
ax.plot(X, Y, color=color, linestyle="-", lw=linewidth, zs=Z)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_zlabel('x3')
Expand Down
2 changes: 2 additions & 0 deletions discretize/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ Cell::Cell(Node *pts[8], int_t ndim, int_t maxlevel, function func){
int_t n_points = 1<<n_dim;
for(int_t i = 0; i < n_points; ++i)
points[i] = pts[i];
index = -1;
level = 0;
max_level = maxlevel;
parent = NULL;
Expand Down Expand Up @@ -212,6 +213,7 @@ Cell::Cell(Node *pts[8], Cell *parent){
int_t n_points = 1<<n_dim;
for(int_t i = 0; i < n_points; ++i)
points[i] = pts[i];
index = -1;
level = parent->level + 1;
max_level = parent->max_level;
test_func = parent->test_func;
Expand Down
3 changes: 2 additions & 1 deletion discretize/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ class Cell{
Edge *edges[12];
Face *faces[6];

int_t location_ind[3], index, key, level, max_level;
int_t location_ind[3], key, level, max_level;
long long int index; // non root parents will have a -1 value
double location[3];
double volume;
function test_func;
Expand Down
Loading

0 comments on commit 67588be

Please sign in to comment.