-
Notifications
You must be signed in to change notification settings - Fork 40
/
riemann_tools.py
332 lines (270 loc) · 10.9 KB
/
riemann_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
This version of riemann_tools.py was adapted from the
version in clawpack/riemann/src from Clawpack V5.3.1 to
add some new features and improve the plots.
This may be modified further as the notebooks in this
directory are improved and expanded. Eventually a
stable version of this should be moved back to
clawpack/riemann/src in a future release.
"""
from __future__ import absolute_import
from __future__ import print_function
from matplotlib import animation
from clawpack.visclaw.JSAnimation import IPython_display
from IPython.display import display
import ipywidgets
import sympy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from six.moves import range
import copy
# Module-level side effect: configure sympy pretty-printing to use MathJax
# for LaTeX output (used by the verbose display in riemann_solution).
sympy.init_printing(use_latex='mathjax')
def riemann_solution(solver, q_l, q_r, aux_l=None, aux_r=None, t=0.2, problem_data=None, verbose=False):
    r"""
    Compute the (approximate) solution of the Riemann problem with states (q_l, q_r)
    based on the (approximate) Riemann solver `solver`.  The solver should be
    a pointwise solver provided as a Python function.

    Parameters:
        solver: callable invoked as
            ``solver(q_l, q_r, aux_l, aux_r, problem_data)`` with column
            arrays of shape (n, 1); must return ``(wave, s, amdq, apdq)``
            where ``wave`` is indexable as ``wave[:, :, 0]``.
        q_l, q_r: left and right states; scalars, sequences or arrays.
        aux_l, aux_r: optional auxiliary values for each side.  Default to
            zeros of length ``num_eqn``.  They may have any length.
        t: accepted for interface compatibility; not used by this function.
        problem_data: passed through to `solver` unchanged.
        verbose: if True, display states, waves, speeds and fluctuations
            using sympy matrices.

    Returns:
        (states, s, riemann_eval) where ``states`` has one column per
        constant state (``q_l`` first, then the running sum of the waves),
        ``s`` is the speed array exactly as returned by the solver, and
        ``riemann_eval(xi)`` evaluates the piecewise-constant solution at
        the similarity coordinates ``xi = x/t``, returning an array of
        shape (num_eqn, len(xi)).

    **Example**::

        # Call solver
        >>> from clawpack import riemann
        >>> gamma = 1.4
        >>> problem_data = { 'gamma' : gamma, 'gamma1' : gamma - 1 }
        >>> solver = riemann.euler_1D_py.euler_hll_1D
        >>> q_l = (3., -0.5, 2.); q_r = (1., 0., 1.)
        >>> states, speeds, reval = riemann_solution(solver, q_l, q_r, problem_data=problem_data)

        # Check output
        >>> q_m = np.array([1.686068, 0.053321, 1.202282])
        >>> assert np.allclose(q_m, states[:,1])
    """
    # Convert each state independently so that mixing an ndarray with a
    # tuple, or passing plain scalars, works.
    q_l = np.atleast_1d(np.asarray(q_l, dtype=float))
    q_r = np.atleast_1d(np.asarray(q_r, dtype=float))
    if q_l.shape != q_r.shape:
        raise ValueError("q_l and q_r must have the same shape, got %s and %s"
                         % (q_l.shape, q_r.shape))
    num_eqn = len(q_l)

    if aux_l is None:
        aux_l = np.zeros(num_eqn)
    else:
        aux_l = np.atleast_1d(np.asarray(aux_l, dtype=float))
    if aux_r is None:
        aux_r = np.zeros(num_eqn)
    else:
        aux_r = np.atleast_1d(np.asarray(aux_r, dtype=float))

    # The solver works on arrays with a trailing "Riemann problem" axis; we
    # solve a single problem, hence the column shape.  The aux arrays are
    # reshaped by their own length, which need not equal num_eqn.
    wave, s, amdq, apdq = solver(q_l.reshape((num_eqn, 1)), q_r.reshape((num_eqn, 1)),
                                 aux_l.reshape((-1, 1)), aux_r.reshape((-1, 1)),
                                 problem_data)

    wave0 = wave[:, :, 0]
    num_waves = wave.shape[1]
    # Sum the waves, starting from q_l, to get the states:
    states = np.cumsum(np.column_stack((q_l, wave0)), axis=1)
    num_states = num_waves + 1

    if verbose:
        print('States in Riemann solution:')
        states_sym = sympy.Matrix(states)
        display([states_sym[:, k] for k in range(num_states)])

        print('Waves (jumps between states):')
        wave_sym = sympy.Matrix(wave0)
        display([wave_sym[:, k] for k in range(num_waves)])

        print("Speeds: ")
        s_sym = sympy.Matrix(s)
        display(s_sym.T)

        print("Fluctuations amdq, apdq: ")
        amdq_sym = sympy.Matrix(amdq).T
        apdq_sym = sympy.Matrix(apdq).T
        display([amdq_sym, apdq_sym])

    # Flat view of the speeds for the interval tests below; the value
    # returned to the caller keeps the solver's original shape.
    flat_speeds = np.ravel(np.asarray(s, dtype=float))

    def riemann_eval(xi):
        "Return Riemann solution as function of xi = x/t."
        # Force a 1-d float array: np.piecewise returns the dtype of its
        # input, so integer xi would otherwise truncate the states.
        xi = np.ravel(np.asarray(xi, dtype=float))
        intervals = [xi < flat_speeds[0]]
        intervals.extend((xi >= flat_speeds[i]) * (xi <= flat_speeds[i + 1])
                         for i in range(len(flat_speeds) - 1))
        # Listed last so that points lying exactly on a wave take the state
        # to the right of it (later conditions win in np.piecewise).
        intervals.append(xi >= flat_speeds[-1])
        qout = np.zeros((num_eqn, len(xi)))
        for m in range(num_eqn):
            qout[m, :] = np.piecewise(xi, intervals, states[m, :])
        return qout

    return states, s, riemann_eval
def plot_phase(states, i_h=0, i_v=1, ax=None, label_h=None, label_v=None):
    """
    Plot 2d phase space plot.

    If num_eqns > 2, can specify which component of q to put on horizontal
    and vertical axes via i_h and i_v.

    Parameters:
        states: 2-d array-like with one row per component of q and one
            column per state.
        i_h, i_v: row indices plotted on the horizontal / vertical axis.
        ax: axes object to draw on; a new figure is created if None.
        label_h, label_v: axis labels; default to 'q[i_h]' and 'q[i_v]'.
    """
    states = np.asarray(states)
    if label_h is None:
        label_h = 'q[%s]' % i_h
    if label_v is None:
        label_v = 'q[%s]' % i_v

    q0 = states[i_h, :]
    q1 = states[i_v, :]

    if ax is None:
        fig, ax = plt.subplots()
    ax.plot(q0, q1, 'o-k')
    ax.set_title('States in phase space')
    ax.axis('equal')

    dq0 = q0.max() - q0.min()
    dq1 = q1.max() - q1.min()
    # A component that is identical in every state would give zero-width
    # axis limits and zero label offsets; fall back to a unit range.
    if dq0 == 0:
        dq0 = 1.0
    if dq1 == 0:
        dq1 = 1.0

    # Offset the labels slightly so they do not sit on top of the markers.
    ax.text(q0[0] + 0.05*dq0, q1[0] + 0.05*dq1, 'q_left')
    ax.text(q0[-1] + 0.05*dq0, q1[-1] + 0.05*dq1, 'q_right')
    ax.axis([q0.min() - 0.1*dq0, q0.max() + 0.1*dq0,
             q1.min() - 0.1*dq1, q1.max() + 0.1*dq1])
    ax.set_xlabel(label_h)
    ax.set_ylabel(label_v)
def plot_phase_3d(states):
    """
    3d phase space plot

    Plots the first three rows of `states` against each other and labels
    the first and last columns as 'q_left' and 'q_right'.
    """
    # Imported for its side effect: older Matplotlib releases only register
    # the '3d' projection once mplot3d has been imported.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    states = np.asarray(states)
    fig = plt.figure()
    # Figure.gca() no longer accepts a projection keyword in current
    # Matplotlib; add_subplot works on both old and new releases.
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(states[0, :], states[1, :], states[2, :], 'ko-')
    ax.set_xlabel('q[0]')
    ax.set_ylabel('q[1]')
    ax.set_zlabel('q[2]')
    ax.set_title('phase space')
    ax.text(states[0, 0] + 0.05, states[1, 0], states[2, 0], 'q_left')
    ax.text(states[0, -1] + 0.05, states[1, -1], states[2, -1], 'q_right')
def plot_riemann(states, s, riemann_eval, t, fig=None, color='b', layout='horizontal', conserved_variables=None):
    """
    Take an array of states and speeds s and plot the solution at time t.
    For rarefaction waves, the corresponding entry in s should be tuple of two values,
    which are the wave speeds that bound the rarefaction fan.

    Plots in the x-t plane and also produces a separate plot for each component of q.

    Parameters:
        states: 2-d array, one row per component of q, one column per state.
        s: sequence of wave speeds.  A tuple/list entry is drawn as a
            rarefaction fan; any other entry must hold a single number
            (a scalar or a size-1 array) and is drawn as a jump.
        riemann_eval: callable mapping an array of xi = x/t values to an
            array indexable as q[i] for each component i.
        t: time at which the solution is sampled.
        fig: existing figure to draw into; its ``axes`` list must hold the
            x-t axes first followed by one axes per component.  A new
            figure is created if None.
        color: line color for the waves and the solution.
        layout: 'horizontal' or 'vertical' arrangement of the subplots.
        conserved_variables: names of the components; default 'q[i]'.

    Returns:
        The figure that was drawn into.

    Raises:
        ValueError: if a new figure is requested with an unknown layout.
    """
    states = np.asarray(states)
    num_eqn = states.shape[0]

    if fig is None:
        if layout == 'horizontal':
            fig, ax = plt.subplots(1, num_eqn + 1, figsize=(4*(num_eqn + 1), 4))
        elif layout == 'vertical':
            fig, ax = plt.subplots(num_eqn + 1, 1, figsize=(9, 4*num_eqn), sharex=True)
            plt.subplots_adjust(hspace=0)
            # NOTE(review): the indentation of this file was lost; these two
            # label calls are assumed to belong to the vertical layout only.
            # Confirm against the upstream source.
            ax[-1].set_xlabel('x')
            ax[0].set_ylabel('t')
        else:
            raise ValueError("layout must be 'horizontal' or 'vertical', got %r"
                             % (layout,))
    else:
        ax = fig.axes

    # Grey out the frame of every subplot.
    for axis in ax:
        for child in axis.get_children():
            if isinstance(child, matplotlib.spines.Spine):
                child.set_color('#dddddd')

    tmax = 1.0
    xmax = 0.
    for speed in s:
        if isinstance(speed, (tuple, list)):
            # Rarefaction fan: a few thin rays between the bounding speeds.
            for ss in np.linspace(float(speed[0]), float(speed[1]), 5):
                x1 = tmax * ss
                ax[0].plot([0, x1], [0, tmax], color=color, lw=0.3)
                xmax = max(xmax, abs(x1))
        else:
            # Jump.  Speeds may arrive as size-1 arrays; reduce them to a
            # plain float so the plot arguments and the x grid stay 1-d.
            x1 = tmax * np.asarray(speed, dtype=float).item()
            ax[0].plot([0, x1], [0, tmax], color=color)
            xmax = max(xmax, abs(x1))

    x = np.linspace(-xmax, xmax, 1000)
    ax[0].set_xlim(-xmax, xmax)
    # Horizontal marker showing the time at which the solution is sampled.
    ax[0].plot([-xmax, xmax], [t, t], 'k', linewidth=2)
    ax[0].text(-1.8*xmax, t, 't = %4.2f -->' % t)
    ax[0].set_title('Waves in x-t plane')

    if conserved_variables is None:
        conserved_variables = ['q[%s]' % i for i in range(num_eqn)]

    for i in range(num_eqn):
        qmax = states[i, :].max()
        qmin = states[i, :].min()
        qdiff = qmax - qmin
        ax[i+1].set_xlim(-xmax, xmax)
        ax[i+1].set_ylim((qmin - 0.1*qdiff, qmax + 0.1*qdiff))
        if layout == 'horizontal':
            ax[i+1].set_title(conserved_variables[i] + ' at t = %6.3f' % t)
        elif layout == 'vertical':
            ax[i+1].set_ylabel(conserved_variables[i])

    # xi = x/t is undefined at t = 0; use a tiny time so every x != 0 maps
    # far to the left or right of all the waves.
    if t == 0:
        q = riemann_eval(x/1e-10)
    else:
        q = riemann_eval(x/t)

    for i in range(num_eqn):
        ax[i+1].plot(x, q[i][:], color=color, lw=2)

    return fig
def make_plot_function(states_list, speeds_list, riemann_eval_list, names=None, layout='horizontal', conserved_variables=None):
    """
    Utility function to create a plot_function that takes a single argument t.
    This function can then be used in a StaticInteract widget.

    Version that takes an arbitrary list of sets of states and speeds in order to make a comparison.

    Parameters:
        states_list, speeds_list, riemann_eval_list: either one set of
            arguments for plot_riemann, or parallel lists holding one entry
            per solution to overlay.  Whether the arguments are lists is
            decided from `states_list` alone.
        names: optional legend entries, one per solution.
        layout, conserved_variables: forwarded to plot_riemann.

    Returns:
        A function ``plot_function(t)`` that draws every solution at time t
        into one shared figure and returns that figure (None if the lists
        are empty).
    """
    colors = "kbrg"
    if not isinstance(states_list, list):
        states_list = [states_list]
        speeds_list = [speeds_list]
        riemann_eval_list = [riemann_eval_list]

    def plot_function(t):
        fig = None
        solutions = zip(states_list, speeds_list, riemann_eval_list)
        for i, (states, speeds, riemann_eval) in enumerate(solutions):
            # The first call creates the figure (fig is None); later calls
            # draw into it.  Colors cycle so that more than len(colors)
            # solutions do not run off the end of the palette.
            fig = plot_riemann(states, speeds, riemann_eval, t, fig=fig,
                               color=colors[i % len(colors)], layout=layout,
                               conserved_variables=conserved_variables)

        if names is not None and fig is not None:
            # We could use fig.legend here if we had the line plot handles
            fig.axes[1].legend(names, loc='best')

        return fig

    return plot_function
def JSAnimate_plot_riemann(states,speeds,riemann_eval, times=None, **kwargs):
    """
    Build an animation of the Riemann solution from one plot_riemann
    figure per time.

    `times` defaults to ten equally spaced values from 0 to 0.9.  Extra
    keyword arguments are forwarded to plot_riemann.  The collected figures
    are passed to animation_tools.make_images, and the result is passed to
    animation_tools.JSAnimate_images, whose return value is returned.
    """
    from clawpack.visclaw import animation_tools

    if times is None:
        times = np.linspace(0, 0.9, 10)

    frames = []  # one figure per requested time
    for time in times:
        frame = plot_riemann(states, speeds, riemann_eval, time, **kwargs)
        frames.append(frame)
        # Close each figure as soon as it has been collected.
        plt.close(frame)

    frame_images = animation_tools.make_images(frames)
    return animation_tools.JSAnimate_images(frame_images, figsize=(10,5))
def plot_riemann_trajectories(states, s, riemann_eval, i_vel=1,
                              fig=None, color='b', num_left=10, num_right=10):
    """
    Take an array of states and speeds s and plot the solution in the x-t plane,
    along with particle trajectories.

    Only useful for systems where one component is velocity.
    i_vel should be the index of this component.

    For rarefaction waves, the corresponding entry in s should be tuple of two values,
    which are the wave speeds that bound the rarefaction fan.

    Parameters:
        states: accepted for interface compatibility; not used here.
        s: sequence of wave speeds.  A tuple/list entry is drawn as a
            rarefaction fan; any other entry must hold a single number
            (a scalar or a size-1 array) and is drawn as a jump.
        riemann_eval: callable mapping an array of xi = x/t values to an
            array whose row `i_vel` is the velocity.
        i_vel: row index of the velocity in the output of `riemann_eval`.
        fig: figure to draw into (its current axes are used); a new figure
            is created if None.
        color: line color for the waves.
        num_left, num_right: number of trajectories started at t = 0 to
            the left / right of x = 0.
    """
    if fig is None:
        fig, ax = plt.subplots()
    else:
        ax = fig.gca()

    tmax = 1.0
    xmax = 0.
    for speed in s:
        if isinstance(speed, (tuple, list)):
            # Rarefaction fan: a few rays between the bounding speeds.
            for ss in np.linspace(float(speed[0]), float(speed[1]), 5):
                x1 = tmax * ss
                ax.plot([0, x1], [0, tmax], color=color, lw=1)
                xmax = max(xmax, abs(x1))
        else:
            # Jump.  Speeds may arrive as size-1 arrays; reduce them to a
            # plain float so the plot arguments and the grids stay 1-d.
            x1 = tmax * np.asarray(speed, dtype=float).item()
            ax.plot([0, x1], [0, tmax], color=color, lw=2)
            xmax = max(xmax, abs(x1))

    ax.set_xlim(-xmax, xmax)

    # Starting positions of the particles at t = 0.
    xx = np.hstack((np.linspace(-xmax, 0, num_left),
                    np.linspace(0, xmax, num_right)))
    xtraj = [xx]

    nsteps = 200  # must be an int: np.linspace rejects a float sample count
    dt = 1. / nsteps
    tt = np.linspace(0, 1, nsteps + 1)
    # xi = x/t is undefined at t = 0, so use a tiny time for the initial data.
    q_old = riemann_eval(xx/1e-15)
    for n in range(1, len(tt)):
        q_new = riemann_eval(xx/tt[n])
        # Advance each particle with the average of the velocities sampled
        # at the previous and the current time level.
        v_mid = 0.5*(q_old[i_vel, :] + q_new[i_vel, :])
        xx = xx + dt*v_mid
        xtraj.append(xx)
        q_old = q_new

    xtraj = np.array(xtraj)
    for j in range(xtraj.shape[1]):
        # Draw on the same axes as the waves, not on pyplot's current axes.
        ax.plot(xtraj[:, j], tt, 'k')

    ax.set_title('Waves and particle trajectories in x-t plane')
# When executed as a script, run the doctests embedded in this module's
# docstrings.
if __name__ == '__main__':
    import doctest
    doctest.testmod()