Skip to content

Commit

Permalink
reinitialize pc_model for splits
Browse files Browse the repository at this point in the history
  • Loading branch information
Caitlin Curry committed Jun 8, 2023
1 parent 8b61940 commit 9ca2fc6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 58 deletions.
22 changes: 11 additions & 11 deletions PyUQTk/PyPCE/pce_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def UQTkRegression(pc_model,f_evaluations, samplepts):
# Return numpy array of PC coefficients
return c_k
################################################################################
def UQTkBCS(pc_model, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
def UQTkBCS(pc_begin, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
mindex_growth='nonconservative', regparams=None, sigma2=1e-8, trval_frac=None,\
npccut=None, pcf_thr=None, verbose=0, eta_plot=False):
"""
Expand All @@ -381,7 +381,7 @@ def UQTkBCS(pc_model, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
ToDo: add documentation in UQTk manual on what BCS is and the basis growth schemes
Input:
pc_model: PC object with information about the starting basis
pc_begin: PC object with information about the starting basis
xdata: N-dimensional NumPy array with sample points [#samples,
#dimensions]
ydata: 1D numpy array (vector) with function, evaluated at the
Expand Down Expand Up @@ -428,7 +428,7 @@ def UQTkBCS(pc_model, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
eta_opt = eta
elif (type(eta)==np.ndarray or type(eta)==list):
# the eta with the lowest RMSE is selected from etas
eta_opt = UQTkOptimizeEta(pc_model, ydata, xdata, eta, niter, eta_folds, mindex_growth, verbose, eta_plot)
eta_opt = UQTkOptimizeEta(pc_begin, ydata, xdata, eta, niter, eta_folds, mindex_growth, verbose, eta_plot)
if verbose:
print("Optimal eta is", eta_opt)
else:
Expand All @@ -448,12 +448,12 @@ def UQTkBCS(pc_model, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
if regparams is None:
regparams = np.array([])
elif type(regparams)==int or type(regparams)==float:
regparams = regparams*np.ones((pc_model.GetNumberPCTerms(),))
regparams = regparams*np.ones((pc_begin.GetNumberPCTerms(),))

if mindex_growth == None:
full_basis_size = pc_model.GetNumberPCTerms()
full_basis_size = pc_begin.GetNumberPCTerms()
else:
full_basis_size = uqtkpce.PCSet("NISPnoq", pc_model.GetOrder() + niter -1, pc_model.GetNDim(), pc_model.GetPCType(), pc_model.GetAlpha(), pc_model.GetBeta()).GetNumberPCTerms()
full_basis_size = uqtkpce.PCSet("NISPnoq", pc_begin.GetOrder() + niter -1, pc_begin.GetNDim(), pc_begin.GetPCType(), pc_begin.GetAlpha(), pc_begin.GetBeta()).GetNumberPCTerms()

# loop through iterations with different splits of the data
for i in range(ntry):
Expand All @@ -468,16 +468,16 @@ def UQTkBCS(pc_model, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
# Iterations of multiindex growth
for j in range(niter):
# Retrieve multiindex
mi_uqtk = uqtkarray.intArray2D(pc_model.GetNumberPCTerms(), nsam)
pc_model.GetMultiIndex(mi_uqtk)
mi_uqtk = uqtkarray.intArray2D(pc_begin.GetNumberPCTerms(), nsam)
pc_begin.GetMultiIndex(mi_uqtk)
mindex=uqtkarray.uqtk2numpy(mi_uqtk)
if verbose>0:
print("==== BCS with multiindex of size %d ====" % (mindex.shape[0],))
if verbose>1:
print(mindex)

# One run of BCS to obtain an array of coefficients and a new multiindex
c_k, used_mi_np = UQTkEvalBCS(pc_model, y_split, x_split, sigma2, eta_opt, regparams, verbose)
c_k, used_mi_np = UQTkEvalBCS(pc_begin, y_split, x_split, sigma2, eta_opt, regparams, verbose)

# Custom 'cuts' by number of PC terms or by value of PC coefficients
npcall = c_k.shape[0] # number of PC terms
Expand Down Expand Up @@ -524,8 +524,8 @@ def UQTkBCS(pc_model, xdata, ydata, niter, eta, ntry=1, eta_folds=5,\
mindex_uq.assign(i2,j2, mindex[i2][j2])

# create a pc object with the new multiindex
pc_model=uqtkpce.PCSet("NISPnoq", mindex_uq, pc_model.GetPCType(),\
pc_model.GetAlpha(), pc_model.GetBeta())
pc_model=uqtkpce.PCSet("NISPnoq", mindex_uq, pc_begin.GetPCType(),\
pc_begin.GetAlpha(), pc_begin.GetBeta())

# Save for this trial
mi_selected.append(mindex)
Expand Down
Loading

0 comments on commit 9ca2fc6

Please sign in to comment.