Skip to content

Commit 3e85d6d

Browse files
authored
Add check_eq for StateDump in Python (#1372)
This adds a utility to the `StateDump` object in Python to help with writing tests that verify quantum state. The check ignores global phase, so allows for passing in any dictionary where the states differ from the dump by a constant factor, including unnormalized states.
1 parent 4d891c1 commit 3e85d6d

2 files changed

Lines changed: 62 additions & 1 deletion

File tree

pip/qsharp/_qsharp.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
Circuit,
1111
)
1212
from warnings import warn
13-
from typing import Any, Callable, Dict, Optional, TypedDict, Union, List
13+
from typing import Any, Callable, Dict, Optional, Tuple, TypedDict, Union, List
1414
from .estimator._estimator import EstimatorResult, EstimatorParams
1515
import json
1616

@@ -349,6 +349,39 @@ def __str__(self) -> str:
349349
def _repr_html_(self) -> str:
350350
return self.__data._repr_html_()
351351

352+
def check_eq(
353+
self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10
354+
) -> bool:
355+
"""
356+
Checks if the state dump is equal to the given state. This is not mathematical equality,
357+
as the check ignores global phase.
358+
359+
:param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes,
360+
or as a list of real amplitudes.
361+
:param tolerance: The tolerance for the check. Defaults to 1e-10.
362+
"""
363+
phase = None
364+
# Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes
365+
if isinstance(state, list):
366+
state = {i: state[i] for i in range(len(state))}
367+
# Filter out zero states from the state dump and the given state based on tolerance
368+
state = {k: v for k, v in state.items() if abs(v) > tolerance}
369+
inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance}
370+
if len(state) != len(inner_state):
371+
return False
372+
for key in state:
373+
if key not in inner_state:
374+
return False
375+
if phase is None:
376+
# Calculate the phase based on the first state pair encountered.
377+
# Every pair of states after this must have the same phase for the states to be equivalent.
378+
phase = inner_state[key] / state[key]
379+
elif abs(phase - inner_state[key] / state[key]) > tolerance:
380+
# This pair of states does not have the same phase,
381+
# within tolerance, so the equivalence check fails.
382+
return False
383+
return True
384+
352385

353386
def dump_machine() -> StateDump:
354387
"""

pip/tests/test_qsharp.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,34 @@ def test_dump_machine() -> None:
9898
# Check that the state dump correctly supports iteration and membership checks
9999
for idx in state_dump:
100100
assert idx in state_dump
101+
# Check that the state dump is correct and equivalence check ignores global phase, allowing passing
102+
# in of different, potentially unnormalized states. The state should be
103+
# |01⟩: 0.7071+0.0000𝑖, |11⟩: −0.7071+0.0000𝑖
104+
assert state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0)})
105+
assert state_dump.check_eq({1: complex(0.0, 0.7071), 3: complex(0.0, -0.7071)})
106+
assert state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(-0.5, 0.0)})
107+
assert state_dump.check_eq(
108+
{1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0), 0: complex(0.0, 0.0)}
109+
)
110+
assert state_dump.check_eq([0.0, 0.5, 0.0, -0.5])
111+
assert state_dump.check_eq([0.0, 0.5001, 0.00001, -0.5], tolerance=1e-3)
112+
assert state_dump.check_eq(
113+
[complex(0.0, 0.0), complex(0.0, -0.5), complex(0.0, 0.0), complex(0.0, 0.5)]
114+
)
115+
assert not state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(0.7071, 0.0)})
116+
assert not state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(0.0, 0.5)})
117+
assert not state_dump.check_eq({2: complex(0.5, 0.0), 3: complex(-0.5, 0.0)})
118+
assert not state_dump.check_eq([0.0, 0.5001, 0.0, -0.5], tolerance=1e-6)
119+
# Reset the qubits and apply a small rotation to q1, to confirm that tolerance applies to the dump
120+
# itself and not just the state.
121+
qsharp.eval("ResetAll([q1, q2]);")
122+
qsharp.eval("Ry(0.0001, q1);")
123+
state_dump = qsharp.dump_machine()
124+
assert state_dump.qubit_count == 2
125+
assert len(state_dump) == 2
126+
assert not state_dump.check_eq([1.0])
127+
assert state_dump.check_eq([0.99999999875, 0.0, 4.999999997916667e-05])
128+
assert state_dump.check_eq([1.0], tolerance=1e-4)
101129

102130

103131
def test_dump_operation() -> None:

0 commit comments

Comments
 (0)