diff --git a/sympy/intro/README.md b/sympy/intro/README.md index bc40cce..3364a12 100644 --- a/sympy/intro/README.md +++ b/sympy/intro/README.md @@ -170,70 +170,6 @@ result = sympy.dsolve(diffeq, f(x)) print(result) # -> Eq(f(x), (C1 + C2*x)*exp(x) + cos(x)/2) ``` -Sometimes it is necessary to simplyify the problem for the solver: - -```python -import sympy - -u = sympy.symbols("u", cls=sympy.Function, real=True) -t, tau, a0, urest, uc, r, i, uthr = sympy.symbols( - "t tau a0 urest uc r i uthr", real=True -) -c0, c1, c2 = sympy.symbols("c0 c1 c2", real=True) - -diffeq_rhs = (a0 * (u(t) - urest) * (u(t) - uc) + r * i) / tau - -print("Original") -print(diffeq_rhs) -print() - -diffeq_rhs = sympy.expand(diffeq_rhs) -diffeq_rhs = sympy.collect(diffeq_rhs, u(t)) - -c0_coeffs = diffeq_rhs.coeff(u(t), 0) -c1_coeffs = diffeq_rhs.coeff(u(t), 1) -c2_coeffs = diffeq_rhs.coeff(u(t), 2) - -diffeq_rhs = diffeq_rhs.subs(c0_coeffs, c0) -diffeq_rhs = diffeq_rhs.subs(c1_coeffs, c1) -diffeq_rhs = diffeq_rhs.subs(c2_coeffs, c2) - -print("Simplified") -print(diffeq_rhs) -print() - -diffeq = sympy.Eq(u(t).diff(t), diffeq_rhs) - -print("Full equation") -print(diffeq) -print() - -solved_diffeq_rhs = sympy.dsolve(diffeq, u(t), ics={u(0): uc}, simplify=False).rhs - -solved_diffeq_rhs = solved_diffeq_rhs.subs(c0, c0_coeffs) -solved_diffeq_rhs = solved_diffeq_rhs.subs(c1, c1_coeffs) -solved_diffeq_rhs = solved_diffeq_rhs.subs(c2, c2_coeffs) - -print("Result") -print(solved_diffeq_rhs) -print() -``` - -Output: -```python -Original -(a0*(-uc + u(t))*(-urest + u(t)) + i*r)/tau - -Simplified -c0 + c1*u(t) + c2*u(t)**2 - -Full equation -Eq(Derivative(u(t), t), c0 + c1*u(t) + c2*u(t)**2) - -Result --t + sqrt(-1/(4*a0*(a0*uc*urest/tau + i*r/tau)/tau - (-a0*uc/tau - a0*urest/tau)**2))*log(uc - 2*sqrt(-1/(4*a0*(a0*uc*urest/tau + i*r/tau)/tau - (-a0*uc/tau - a0*urest/tau)**2))*(a0*uc*urest/tau + i*r/tau) + tau*sqrt(-1/(4*a0*(a0*uc*urest/tau + i*r/tau)/tau - (-a0*uc/tau - a0*urest/tau)**2))*(-a0*uc/tau - a0*urest/tau)**2/(2*a0) + tau*(-a0*uc/tau - a0*urest/tau)/(2*a0)) - sqrt(-1/(4*a0*(a0*uc*urest/tau + i*r/tau)/tau - (-a0*uc/tau - a0*urest/tau)**2))*log(uc + 2*sqrt(-1/(4*a0*(a0*uc*urest/tau + i*r/tau)/tau - (-a0*uc/tau - a0*urest/tau)**2))*(a0*uc*urest/tau + i*r/tau) - tau*sqrt(-1/(4*a0*(a0*uc*urest/tau + i*r/tau)/tau - (-a0*uc/tau - a0*urest/tau)**2))*(-a0*uc/tau - a0*urest/tau)**2/(2*a0) + tau*(-a0*uc/tau - a0*urest/tau)/(2*a0)) -``` - ## [Numerical Evaluation](https://docs.sympy.org/latest/modules/evalf.html) ```python @@ -307,3 +243,78 @@ import sympy x = sympy.symbols("x") sympy.solve([x**2 - 1, x >= 0.5, x <= 3], x) ``` + + +## Example: Gain function of the quadratic integrate and fire neuron + +```python +import sympy +import numpy as np +import matplotlib.pyplot as plt + +tau_value: float = 10e-3 # s +a0_value: float = 0.1e3 # V^-1 +uc_value: float = -55.0e-3 # V +urest_value: float = -70.0e-3 # V +r_value: float = 10e6 # Ohm +uthr_value: float = -40e-3 # V + +u = sympy.symbols("u", cls=sympy.Function, real=True) +urest, uc, uthr = sympy.symbols("urest uc uthr", real=True) +t, tau, a0, r, i = sympy.symbols("t tau a0 r i", real=True, positive=True) + +c0, c1, c2 = sympy.symbols("c0 c1 c2", real=True) + +diffeq_rhs = (a0 * (u(t) - urest) * (u(t) - uc) + r * i) / tau +diffeq_rhs = sympy.expand(diffeq_rhs) +diffeq_rhs = sympy.collect(diffeq_rhs, u(t)) + +c0_coeffs = diffeq_rhs.coeff(u(t), 0) +c1_coeffs = diffeq_rhs.coeff(u(t), 1) +c2_coeffs = diffeq_rhs.coeff(u(t), 2) + +diffeq_rhs = diffeq_rhs.subs(c0_coeffs, c0) +diffeq_rhs = diffeq_rhs.subs(c1_coeffs, c1) +diffeq_rhs = diffeq_rhs.subs(c2_coeffs, c2) + +diffeq = sympy.Eq(u(t).diff(t), diffeq_rhs) + +solved_diffeq = sympy.dsolve(diffeq, u(t), ics={u(0): urest}, simplify=False) +solved_diffeq = solved_diffeq.subs(u(t), uthr) +solved_diffeq = sympy.simplify(solved_diffeq) + +solved2_diffeq = sympy.solve(solved_diffeq, t)[0] + +solved2_diffeq = solved2_diffeq.subs(c0, c0_coeffs) +solved2_diffeq = solved2_diffeq.subs(c1, c1_coeffs) +solved2_diffeq = solved2_diffeq.subs(c2, c2_coeffs) +solved2_diffeq = sympy.simplify(solved2_diffeq) + +solved2_diffeq = solved2_diffeq.subs(tau, tau_value) +solved2_diffeq = solved2_diffeq.subs(a0, a0_value) +solved2_diffeq = solved2_diffeq.subs(uc, uc_value) +solved2_diffeq = solved2_diffeq.subs(urest, urest_value) +solved2_diffeq = solved2_diffeq.subs(r, r_value) +solved2_diffeq = solved2_diffeq.subs(uthr, uthr_value) +solved2_diffeq = sympy.simplify(solved2_diffeq) + + +print("The final function") +print(solved2_diffeq) +print() + + +thr_func = sympy.lambdify(i, solved2_diffeq, "numpy") + +i = 4 * 1e-9 * np.arange(0, 1000, dtype=np.complex128) / 1000 +z = thr_func(i) +z = 1.0 / z +z[z < 0] = 0 +z = z.astype(dtype=np.float32) + +plt.plot(i * 1e9, z) +plt.xlabel("Input [nA]") +plt.ylabel("Rate [Hz]") +plt.show() +``` +