diff --git a/gtfsdb/helpers.go b/gtfsdb/helpers.go index 618cac07..195514eb 100644 --- a/gtfsdb/helpers.go +++ b/gtfsdb/helpers.go @@ -1415,13 +1415,30 @@ func (c *Client) buildBlockLayoverIndex(ctx context.Context, staticData *gtfs.St continue } + layoverStart := int64(lastStopCurrent.ArrivalTime) + layoverEnd := int64(firstStopNext.DepartureTime) + + // ArrivalTime/DepartureTime are nanoseconds since service-day midnight + // (they come from the go-gtfs library as time.Duration values). + // If the layover appears negative, check if it's a valid midnight wraparound + // (e.g. within a reasonable layover threshold, like 4 hours). + const dayNs = int64(24 * time.Hour) // 86_400_000_000_000 ns + const maxLayoverNs = int64(4 * time.Hour) // 14_400_000_000_000 ns + if layoverStart > layoverEnd { + if (layoverEnd+dayNs)-layoverStart < maxLayoverNs { + layoverEnd += dayNs // It crosses midnight, shift it by 24h + } else { + continue // It's invalid data + } + } + err := qtx.CreateBlockLayover(ctx, CreateBlockLayoverParams{ BlockID: key.blockID, ServiceID: key.serviceID, RouteID: nextTrip.Route.Id, LayoverStopID: lastStopCurrent.Stop.Id, - LayoverStart: int64(lastStopCurrent.DepartureTime), - LayoverEnd: int64(firstStopNext.ArrivalTime), + LayoverStart: layoverStart, + LayoverEnd: layoverEnd, NextTripID: nextTrip.ID, }) if err != nil { diff --git a/gtfsdb/query.sql b/gtfsdb/query.sql index b10da1dd..2a78561c 100644 --- a/gtfsdb/query.sql +++ b/gtfsdb/query.sql @@ -673,7 +673,8 @@ FROM JOIN routes ON trips.route_id = routes.id JOIN agencies a ON routes.agency_id = a.id WHERE - stop_times.stop_id IN (sqlc.slice('stop_ids')); + stop_times.stop_id IN (sqlc.slice('stop_ids')) +ORDER BY a.id ASC, stop_times.stop_id ASC; -- name: GetStopTimesForTrip :many SELECT diff --git a/gtfsdb/query.sql.go b/gtfsdb/query.sql.go index 788eab65..cd4c0b96 100644 --- a/gtfsdb/query.sql.go +++ b/gtfsdb/query.sql.go @@ -1304,6 +1304,7 @@ FROM JOIN agencies a ON routes.agency_id = a.id WHERE stop_times.stop_id IN (/*SLICE:stop_ids*/?) +ORDER BY a.id, stop_times.stop_id ` type GetAgenciesForStopsRow struct { diff --git a/internal/models/arrival_and_departure.go b/internal/models/arrival_and_departure.go index 1e87ffdd..27f92980 100644 --- a/internal/models/arrival_and_departure.go +++ b/internal/models/arrival_and_departure.go @@ -12,7 +12,7 @@ type ArrivalAndDeparture struct { DistanceFromStop float64 `json:"distanceFromStop"` Frequency *Frequency `json:"frequency"` HistoricalOccupancy string `json:"historicalOccupancy"` - LastUpdateTime ModelTime `json:"lastUpdateTime,omitzero"` + LastUpdateTime ModelTime `json:"lastUpdateTime"` NumberOfStopsAway int `json:"numberOfStopsAway"` OccupancyStatus string `json:"occupancyStatus"` Predicted bool `json:"predicted"` diff --git a/internal/models/response.go b/internal/models/response.go index 0f3efb20..9c96c585 100644 --- a/internal/models/response.go +++ b/internal/models/response.go @@ -59,6 +59,38 @@ func NewArrivalsAndDepartureResponse(arrivalsAndDepartures any, references Refer return NewOKResponse(data, c) } +func NewArrivalsAndDeparturesForLocationResponse( + arrivalsAndDepartures []ArrivalAndDeparture, + references ReferencesModel, + nearbyStopIds []StopWithDistance, + situationIds []string, + stopIds []string, + limitExceeded bool, + c clock.Clock, +) ResponseModel { + if nearbyStopIds == nil { + nearbyStopIds = []StopWithDistance{} + } + if situationIds == nil { + situationIds = []string{} + } + if stopIds == nil { + stopIds = []string{} + } + entryData := map[string]any{ + "arrivalsAndDepartures": arrivalsAndDepartures, + "limitExceeded": limitExceeded, + "nearbyStopIds": nearbyStopIds, + "situationIds": situationIds, + "stopIds": stopIds, + } + data := map[string]any{ + "entry": entryData, + "references": references, + } + return NewOKResponse(data, c) +} + // NewResponse creates a standard response using the provided clock. func NewResponse(code int, data any, text string, c clock.Clock) ResponseModel { return ResponseModel{ diff --git a/internal/models/situation.go b/internal/models/situation.go index 7dcf108e..b4e17acf 100644 --- a/internal/models/situation.go +++ b/internal/models/situation.go @@ -6,7 +6,7 @@ type Situation struct { ActiveWindows []ActiveWindow `json:"activeWindows"` AllAffects []AffectedEntity `json:"allAffects"` ConsequenceMessage string `json:"consequenceMessage"` - Consequences []any `json:"consequences"` + Consequences []Consequence `json:"consequences"` PublicationWindows []any `json:"publicationWindows"` Reason string `json:"reason"` Severity string `json:"severity"` @@ -15,6 +15,22 @@ type Situation struct { URL *TranslatedString `json:"url,omitempty"` } +type Consequence struct { + Condition string `json:"condition"` + ConditionDetails ConditionDetails `json:"conditionDetails"` +} + +type ConditionDetails struct { + DiversionPath DiversionPath `json:"diversionPath"` + DiversionStopIDs []string `json:"diversionStopIds"` +} + +type DiversionPath struct { + Length int `json:"length"` + Levels string `json:"levels"` + Points string `json:"points"` +} + type ActiveWindow struct { From int64 `json:"from"` To int64 `json:"to"` diff --git a/internal/models/stops.go b/internal/models/stops.go index 9a1e9073..342fe2a6 100644 --- a/internal/models/stops.go +++ b/internal/models/stops.go @@ -34,3 +34,11 @@ type StopsResponse struct { List []Stop `json:"list"` OutOfRange bool `json:"outOfRange"` } + +// StopWithDistance represents a nearby stop together with its distance from the +// centre of the query bounds. It matches the Java StopWithDistanceV2Bean and is +// used by the arrivals-and-departures-for-location endpoint. +type StopWithDistance struct { + StopID string `json:"stopId"` + DistanceFromQuery float64 `json:"distanceFromQuery"` +} diff --git a/internal/restapi/arrival_and_departure_for_stop_handler.go b/internal/restapi/arrival_and_departure_for_stop_handler.go index 90cfa268..94c2d0a9 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler.go @@ -713,8 +713,24 @@ func (api *RestAPI) getPredictedTimes( func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID string, targetStopSequence int, vehicle *gtfs.Vehicle, serviceDate time.Time) *int { currentVehicleStopSequence := getCurrentVehicleStopSequence(vehicle) + if currentVehicleStopSequence == nil { - return nil + // Fallback: infer the vehicle's current stop from its lat/lon position. + // This handles agencies (e.g. Sound Transit Link light rail) that don't + // publish current_stop_sequence in GTFS-RT vehicle positions. + if vehicle == nil || vehicle.Position == nil || + vehicle.Position.Latitude == nil || vehicle.Position.Longitude == nil { + return nil + } + inferred := api.inferStopSequenceFromPosition( + ctx, targetTripID, + float64(*vehicle.Position.Latitude), + float64(*vehicle.Position.Longitude), + ) + if inferred == nil { + return nil + } + currentVehicleStopSequence = inferred } activeTripID := GetVehicleActiveTripID(vehicle) @@ -728,3 +744,55 @@ func (api *RestAPI) getNumberOfStopsAway(ctx context.Context, targetTripID strin numberOfStopsAway := targetGlobalSeq - vehicleGlobalSeq - 1 return &numberOfStopsAway } + +// inferStopSequenceFromPosition returns the stop_sequence of the stop the vehicle +// is currently at or has most recently passed, determined by projecting the vehicle's +// lat/lon onto the ordered list of stop positions for the trip. +// +// It fetches stop times (ordered by sequence) and stop coordinates in a single batch, +// then finds the last stop that is "behind" the vehicle along the route direction. +// Returns nil when no stop times exist or coordinates cannot be resolved. +func (api *RestAPI) inferStopSequenceFromPosition(ctx context.Context, tripID string, vehLat, vehLon float64) *int { + stopTimes, err := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTrip(ctx, tripID) + if err != nil || len(stopTimes) == 0 { + return nil + } + + stopIDs := make([]string, len(stopTimes)) + for i, st := range stopTimes { + stopIDs[i] = st.StopID + } + + stops, err := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDs) + if err != nil { + return nil + } + + coordMap := make(map[string][2]float64, len(stops)) + for _, s := range stops { + coordMap[s.ID] = [2]float64{s.Lat, s.Lon} + } + + // Find the stop that is geometrically closest to the vehicle's current position. + // OBA Java uses a similar nearest-stop heuristic when stop-sequence is absent. + bestIdx := -1 + bestDist := -1.0 + for i, st := range stopTimes { + coords, ok := coordMap[st.StopID] + if !ok { + continue + } + d := utils.Distance(vehLat, vehLon, coords[0], coords[1]) + if bestDist < 0 || d < bestDist { + bestDist = d + bestIdx = i + } + } + + if bestIdx < 0 { + return nil + } + + seq := int(stopTimes[bestIdx].StopSequence) + return &seq +} diff --git a/internal/restapi/arrivals_and_departures_for_location_handler.go b/internal/restapi/arrivals_and_departures_for_location_handler.go new file mode 100644 index 00000000..d1947c87 --- /dev/null +++ b/internal/restapi/arrivals_and_departures_for_location_handler.go @@ -0,0 +1,1120 @@ +package restapi + +import ( + "context" + "log/slog" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" + + "github.com/OneBusAway/go-gtfs" + "maglev.onebusaway.org/gtfsdb" + internalgtfs "maglev.onebusaway.org/internal/gtfs" + "maglev.onebusaway.org/internal/models" + "maglev.onebusaway.org/internal/nulls" + "maglev.onebusaway.org/internal/utils" +) + +// ArrivalsAndDeparturesForLocationParams holds all parsed and validated query +// parameters for the arrivals-and-departures-for-location endpoint. +type ArrivalsAndDeparturesForLocationParams struct { + Lat float64 + Lon float64 + Radius float64 + LatSpan float64 + LonSpan float64 + + Time time.Time + MinutesBefore int + MinutesAfter int + FrequencyMinutesBefore int + FrequencyMinutesAfter int + + MaxCount int + EmptyReturnsNotFound bool + RouteTypes []int +} + +// activeStopTime pairs a GTFS stop time with the service date it occurs on. +type activeStopTime struct { + gtfsdb.GetStopTimesForStopInWindowRow + ServiceDate time.Time +} + +// stopProcessingContext holds parameters for processing a single stop's arrivals. +type stopProcessingContext struct { + StopCode string + AgencyID string + CombinedStopID string + QueryTime time.Time + Loc *time.Location +} + +// fetchWindow groups parameters for fetching stop times to reduce function arguments. +type fetchWindow struct { + StopCode string + Loc *time.Location + QueryTime time.Time + Start time.Time + End time.Time +} + +// locationArrivalsState holds the shared accumulation state across all stops +// while processing arrivals and departures for a location. +type locationArrivalsState struct { + arrivals []models.ArrivalAndDeparture + tripIDSet map[string]*gtfsdb.Trip + routeIDSet map[string]*gtfsdb.Route + stopIDSet map[string]bool + stopAgencyOverride map[string]string + stopsWithArrivals map[string]bool + collectedAlerts map[string]gtfs.Alert + limitExceeded bool + + stopAgencyMap map[string]string + fallbackAgencyID string + agencyLoc *time.Location +} + +func newLocationArrivalsState() *locationArrivalsState { + return &locationArrivalsState{ + arrivals: make([]models.ArrivalAndDeparture, 0), + tripIDSet: make(map[string]*gtfsdb.Trip), + routeIDSet: make(map[string]*gtfsdb.Route), + stopIDSet: make(map[string]bool), + stopAgencyOverride: make(map[string]string), + stopsWithArrivals: make(map[string]bool), + collectedAlerts: make(map[string]gtfs.Alert), + } +} + +type arrivalContext struct { + st gtfsdb.GetStopTimesForStopInWindowRow + serviceMidnight time.Time + scheduledArrivalTime time.Time + scheduledDepartureTime time.Time + predictedArrivalTime time.Time + predictedDepartureTime time.Time + predicted bool + vehicleID string + tripStatus *models.TripStatus + distanceFromStop float64 + numberOfStopsAway int + lastUpdateTime time.Time + arrivalStatus string + totalStopsInTrip int + blockTripSequence int + situationIDs []string +} + +// Error message constants shared by the parameter-parsing helpers below. +const ( + errMustBeValidInteger = "must be a valid integer" + errMustBeNonNegativeInteger = "must be a non-negative integer" +) + +// parseArrivalsAndDeparturesForLocationParams parses and validates all query +// parameters for this endpoint in one place. +func (api *RestAPI) parseArrivalsAndDeparturesForLocationParams(r *http.Request) (ArrivalsAndDeparturesForLocationParams, map[string][]string) { + const ( + defaultMinutesBefore = 5 + defaultMinutesAfter = 35 + maxMinutesBefore = 60 + maxMinutesAfter = 240 + defaultMaxCount = 250 + ) + + params := ArrivalsAndDeparturesForLocationParams{ + Time: api.Clock.Now(), + MinutesBefore: defaultMinutesBefore, + MinutesAfter: defaultMinutesAfter, + MaxCount: defaultMaxCount, + } + + var fieldErrors map[string][]string + addError := func(field, msg string) { + if fieldErrors == nil { + fieldErrors = make(map[string][]string) + } + fieldErrors[field] = append(fieldErrors[field], msg) + } + + // Spatial params (required) — reuse the shared location parser. + loc, locErrors := api.parseLocationParams(r, nil) + if len(locErrors) > 0 { + mergeFieldErrors(&fieldErrors, locErrors) + } else { + params.Lat = loc.Lat + params.Lon = loc.Lon + params.Radius = loc.Radius + params.LatSpan = loc.LatSpan + params.LonSpan = loc.LonSpan + } + + q := r.URL.Query() + params.Time = parseTimeParam(q, params.Time, addError) + parseMinutesCappedParam(q, "minutesBefore", maxMinutesBefore, ¶ms.MinutesBefore, addError) + parseMinutesCappedParam(q, "minutesAfter", maxMinutesAfter, ¶ms.MinutesAfter, addError) + parseMinutesUncappedParam(q, "frequencyMinutesBefore", ¶ms.FrequencyMinutesBefore, addError) + parseMinutesUncappedParam(q, "frequencyMinutesAfter", ¶ms.FrequencyMinutesAfter, addError) + params.EmptyReturnsNotFound = parseEmptyReturnsNotFoundParam(q, addError) + params.RouteTypes = parseRouteTypesParam(q, addError) + + var maxCountErrors map[string][]string + params.MaxCount, maxCountErrors = utils.ParseMaxCount(q, defaultMaxCount, nil) + mergeFieldErrors(&fieldErrors, maxCountErrors) + + return params, fieldErrors +} + +// parseTimeParam parses the "time" query parameter as a Unix timestamp in +// milliseconds. Returns defaultTime unchanged when the parameter is absent. +func parseTimeParam(q url.Values, defaultTime time.Time, addError func(string, string)) time.Time { + val := q.Get("time") + if val == "" { + return defaultTime + } + ms, err := strconv.ParseInt(val, 10, 64) + if err != nil { + addError("time", "must be a valid Unix timestamp in milliseconds") + return defaultTime + } + return time.Unix(ms/1000, (ms%1000)*1_000_000) +} + +// parseMinutesCappedParam parses an integer minutes query parameter and writes +// the result into dest. Values above maxVal are silently capped; negative +// values and non-integer values are rejected via addError. +func parseMinutesCappedParam(q url.Values, key string, maxVal int, dest *int, addError func(string, string)) { + val := q.Get(key) + if val == "" { + return + } + n, err := strconv.Atoi(val) + if err != nil { + addError(key, errMustBeValidInteger) + return + } + if n < 0 { + addError(key, errMustBeNonNegativeInteger) + return + } + if n > maxVal { + *dest = maxVal + return + } + *dest = n +} + +// parseMinutesUncappedParam parses an integer minutes query parameter with no +// upper bound and writes the result into dest. +func parseMinutesUncappedParam(q url.Values, key string, dest *int, addError func(string, string)) { + val := q.Get(key) + if val == "" { + return + } + n, err := strconv.Atoi(val) + if err != nil { + addError(key, errMustBeValidInteger) + return + } + if n < 0 { + addError(key, errMustBeNonNegativeInteger) + return + } + *dest = n +} + +// parseEmptyReturnsNotFoundParam parses the "emptyReturnsNotFound" boolean +// query parameter. Returns false when absent or invalid. +func parseEmptyReturnsNotFoundParam(q url.Values, addError func(string, string)) bool { + val := q.Get("emptyReturnsNotFound") + if val == "" { + return false + } + b, err := strconv.ParseBool(val) + if err != nil { + addError("emptyReturnsNotFound", "must be true or false") + return false + } + return b +} + +// parseRouteTypesParam parses the "routeType" comma-delimited integer list +// query parameter. Returns nil when absent; stops and errors at the first +// invalid token. +func parseRouteTypesParam(q url.Values, addError func(string, string)) []int { + val := q.Get("routeType") + if val == "" { + return nil + } + var routeTypes []int + for _, p := range strings.Split(val, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + rt, err := strconv.Atoi(p) + if err != nil { + addError("routeType", "must be a comma-delimited list of integers") + return nil + } + routeTypes = append(routeTypes, rt) + } + return routeTypes +} + +// mergeFieldErrors merges src into *dst, initialising *dst lazily if nil. +func mergeFieldErrors(dst *map[string][]string, src map[string][]string) { + if len(src) == 0 { + return + } + if *dst == nil { + *dst = make(map[string][]string) + } + for k, v := range src { + (*dst)[k] = append((*dst)[k], v...) + } +} + +// arrivalStatusFromDeviation derives a human-readable status string from a +// schedule deviation, matching Java's ArrivalAndDepartureBeanServiceImpl logic. +// +// - deviation > 300s (5+ min late) → "LATE" +// - deviation < -180s (3+ min early) → "EARLY" +// - otherwise → "ON_TIME" +// +// When there is no real-time data the caller should pass "default" directly. +func arrivalStatusFromDeviation(deviationSeconds int) string { + switch { + case deviationSeconds > 300: + return "LATE" + case deviationSeconds < -180: + return "EARLY" + default: + return "ON_TIME" + } +} + +// arrivalsAndDeparturesForLocationHandler returns arrivals and departures for all +// stops within a geographic bounding box (lat/lon + latSpan/lonSpan or radius). +func (api *RestAPI) arrivalsAndDeparturesForLocationHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + params, fieldErrors := api.parseArrivalsAndDeparturesForLocationParams(r) + if len(fieldErrors) > 0 { + api.validationErrorResponse(w, r, fieldErrors) + return + } + + stops, limitExceeded := api.GtfsManager.GetStopsForLocation( + ctx, + &internalgtfs.LocationParams{ + Lat: params.Lat, + Lon: params.Lon, + Radius: params.Radius, + LatSpan: params.LatSpan, + LonSpan: params.LonSpan, + }, + "", + params.MaxCount, + params.RouteTypes, + ) + + if len(stops) == 0 { + api.handleEmptyStopsResponseForLocation(w, r, params) + return + } + + state := newLocationArrivalsState() + if limitExceeded { + state.limitExceeded = true + } + + if err := api.resolveAgenciesForStopsLocation(ctx, stops, state); err != nil { + api.serverErrorResponse(w, r, err) + return + } + + // Fan out: collect arrivals across every stop in the bbox. + for _, dbStop := range stops { + if state.limitExceeded || len(state.arrivals) >= params.MaxCount { + state.limitExceeded = true + break + } + if err := api.collectArrivalsForLocationStop(ctx, w, r, dbStop, params, state); err != nil { + return // Context cancellation/error response already handled. + } + } + + api.sortLocationArrivalsByTime(state.arrivals) + + api.collectStopLevelAlerts(stops, state) + + references, topLevelSituationIDs := api.buildLocationReferencesBlock(ctx, state) + queriedStopIDs := api.buildLocationQueriedStopIDs(stops, state) + nearbyStops := getLocationNearbyStops(api, ctx, params.Lat, params.Lon) + + api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( + state.arrivals, + *references, + nearbyStops, + topLevelSituationIDs, + queriedStopIDs, + state.limitExceeded, + api.Clock, + )) +} + +func (api *RestAPI) handleEmptyStopsResponseForLocation(w http.ResponseWriter, r *http.Request, params ArrivalsAndDeparturesForLocationParams) { + if params.EmptyReturnsNotFound { + api.sendNotFound(w, r) + return + } + api.sendResponse(w, r, models.NewArrivalsAndDeparturesForLocationResponse( + []models.ArrivalAndDeparture{}, + *models.NewEmptyReferences(), + []models.StopWithDistance{}, + []string{}, + []string{}, + false, + api.Clock, + )) +} + +func (api *RestAPI) collectStopLevelAlerts(stops []gtfsdb.Stop, state *locationArrivalsState) { + rawStopCodes := make([]string, 0, len(stops)) + for _, s := range stops { + rawStopCodes = append(rawStopCodes, s.ID) + } + for _, sc := range rawStopCodes { + for _, alert := range api.GtfsManager.GetAlertsForStop(sc) { + if alert.ID != "" { + if _, seen := state.collectedAlerts[alert.ID]; !seen { + state.collectedAlerts[alert.ID] = alert + } + } + } + } +} + +func (api *RestAPI) resolveAgenciesForStopsLocation(ctx context.Context, stops []gtfsdb.Stop, state *locationArrivalsState) error { + rawStopCodes := make([]string, 0, len(stops)) + for _, s := range stops { + rawStopCodes = append(rawStopCodes, s.ID) + } + + agencyRows, err := api.GtfsManager.GtfsDB.Queries.GetAgenciesForStops(ctx, rawStopCodes) + if err != nil { + return err + } + + state.stopAgencyMap = make(map[string]string, len(agencyRows)) + for _, row := range agencyRows { + if _, exists := state.stopAgencyMap[row.StopID]; !exists { + state.stopAgencyMap[row.StopID] = row.ID + } + } + + state.fallbackAgencyID = pickPrimaryAgency(state.stopAgencyMap) + state.agencyLoc = time.UTC + if state.fallbackAgencyID != "" { + if ag, tzErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, state.fallbackAgencyID); tzErr == nil { + if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { + state.agencyLoc = parsed + } + } + } + return nil +} + +func (api *RestAPI) collectArrivalsForLocationStop(ctx context.Context, w http.ResponseWriter, r *http.Request, dbStop gtfsdb.Stop, params ArrivalsAndDeparturesForLocationParams, state *locationArrivalsState) error { + stopCode := dbStop.ID + agencyID := state.stopAgencyMap[stopCode] + if agencyID == "" { + agencyID = state.fallbackAgencyID + } + state.stopIDSet[stopCode] = true + + stopLoc := state.agencyLoc + if agencyID != state.fallbackAgencyID { + if ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, agencyID); agErr == nil { + if parsed, parseErr := loadAgencyLocation(ag.ID, ag.Timezone); parseErr == nil { + stopLoc = parsed + } + } + } + + stopQueryTime := params.Time.In(stopLoc) + + spc := stopProcessingContext{ + StopCode: stopCode, + AgencyID: agencyID, + CombinedStopID: utils.FormCombinedID(agencyID, stopCode), + QueryTime: stopQueryTime, + Loc: stopLoc, + } + + allActiveStopTimes, err := api.fetchActiveStopTimesForLocationWindow(ctx, spc, params) + if err != nil { + api.clientCanceledResponse(w, r, err) + return err + } + if len(allActiveStopTimes) == 0 { + return nil + } + + stopProducedArrival, err := api.buildArrivalsFromLocationStopTimes(w, r, spc, allActiveStopTimes, params, state) + if err != nil { + return err + } + if stopProducedArrival { + state.stopsWithArrivals[stopCode] = true + } + return nil +} + +func (api *RestAPI) fetchActiveStopTimesForLocationWindow( + ctx context.Context, spc stopProcessingContext, params ArrivalsAndDeparturesForLocationParams, +) ([]activeStopTime, error) { + maxBefore := params.MinutesBefore + if params.FrequencyMinutesBefore > maxBefore { + maxBefore = params.FrequencyMinutesBefore + } + + maxAfter := params.MinutesAfter + if params.FrequencyMinutesAfter > maxAfter { + maxAfter = params.FrequencyMinutesAfter + } + + stopWindowStart := spc.QueryTime.Add(-time.Duration(maxBefore) * time.Minute) + stopWindowEnd := spc.QueryTime.Add(time.Duration(maxAfter) * time.Minute) + + fw := fetchWindow{ + StopCode: spc.StopCode, + Loc: spc.Loc, + QueryTime: spc.QueryTime, + Start: stopWindowStart, + End: stopWindowEnd, + } + + var allActiveStopTimes []activeStopTime + for dayOffset := -1; dayOffset <= 1; dayOffset++ { + err := api.fetchStopTimesForDayOffset(ctx, fw, dayOffset, &allActiveStopTimes) + if err != nil { + return nil, err + } + } + return allActiveStopTimes, nil +} + +func (api *RestAPI) batchFetchLocationRoutesAndTrips( + ctx context.Context, stopCode string, allActiveStopTimes []activeStopTime, +) (map[string]gtfsdb.Route, map[string]gtfsdb.Trip, map[string]int, error) { + batchRouteIDs := make(map[string]bool) + batchTripIDs := make(map[string]bool) + for _, ast := range allActiveStopTimes { + if ast.RouteID != "" { + batchRouteIDs[ast.RouteID] = true + } + if ast.TripID != "" { + batchTripIDs[ast.TripID] = true + } + } + + uniqueRouteIDs := stringMapKeys(batchRouteIDs) + uniqueTripIDs := stringMapKeys(batchTripIDs) + + fetchedRoutes, rErr := api.GtfsManager.GtfsDB.Queries.GetRoutesByIDs(ctx, uniqueRouteIDs) + if rErr != nil { + api.Logger.Warn("failed to batch fetch routes", slog.String("stopID", stopCode), slog.Any("error", rErr)) + return nil, nil, nil, rErr + } + fetchedTrips, tErr := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, uniqueTripIDs) + if tErr != nil { + api.Logger.Warn("failed to batch fetch trips", slog.String("stopID", stopCode), slog.Any("error", tErr)) + return nil, nil, nil, tErr + } + + routesLookup := make(map[string]gtfsdb.Route, len(fetchedRoutes)) + for _, rt := range fetchedRoutes { + routesLookup[rt.ID] = rt + } + tripsLookup := make(map[string]gtfsdb.Trip, len(fetchedTrips)) + for _, tr := range fetchedTrips { + tripsLookup[tr.ID] = tr + } + + tripStopCountMap := api.buildTripStopCountMap(ctx, uniqueTripIDs) + return routesLookup, tripsLookup, tripStopCountMap, nil +} + +func (api *RestAPI) buildTripStopCountMap(ctx context.Context, uniqueTripIDs []string) map[string]int { + tripStopCountMap := make(map[string]int, len(uniqueTripIDs)) + if len(uniqueTripIDs) > 0 { + allST, countErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForTripIDs(ctx, uniqueTripIDs) + if countErr != nil { + api.Logger.Warn("failed to batch fetch stop times for trips", slog.Any("error", countErr)) + } else { + for _, st := range allST { + tripStopCountMap[st.TripID]++ + } + } + } + return tripStopCountMap +} + +func (api *RestAPI) fetchStopTimesForDayOffset( + ctx context.Context, fw fetchWindow, dayOffset int, allActiveStopTimes *[]activeStopTime, +) error { + if ctx.Err() != nil { + return ctx.Err() + } + targetDate := fw.QueryTime.AddDate(0, 0, dayOffset) + serviceMidnight := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, fw.Loc) + serviceDateStr := targetDate.Format("20060102") + + activeServiceIDs, svcErr := api.GtfsManager.GtfsDB.Queries.GetActiveServiceIDsForDate(ctx, serviceDateStr) + if svcErr != nil { + api.Logger.Warn("failed to query active service IDs", slog.String("date", serviceDateStr), slog.Any("error", svcErr)) + return nil + } + if len(activeServiceIDs) == 0 { + return nil + } + + activeServiceIDSet := make(map[string]bool, len(activeServiceIDs)) + for _, sid := range activeServiceIDs { + activeServiceIDSet[sid] = true + } + + startNanos := fw.Start.Sub(serviceMidnight).Nanoseconds() + endNanos := fw.End.Sub(serviceMidnight).Nanoseconds() + if endNanos < 0 { + return nil + } + + stopTimes, stErr := api.GtfsManager.GtfsDB.Queries.GetStopTimesForStopInWindow(ctx, gtfsdb.GetStopTimesForStopInWindowParams{ + StopID: fw.StopCode, + WindowStartNanos: startNanos, + WindowEndNanos: endNanos, + }) + if stErr != nil { + api.Logger.Warn("failed to query stop times in window", slog.String("stopID", fw.StopCode), slog.Any("error", stErr)) + return nil + } + + for _, st := range stopTimes { + if activeServiceIDSet[st.ServiceID] { + *allActiveStopTimes = append(*allActiveStopTimes, activeStopTime{ + GetStopTimesForStopInWindowRow: st, + ServiceDate: serviceMidnight, + }) + } + } + return nil +} + +// isRouteTypeAllowed checks if a route's type matches any in the requested filter list. +func isRouteTypeAllowed(routeType int64, allowedTypes []int) bool { + if len(allowedTypes) == 0 { + return true + } + for _, rt := range allowedTypes { + if int(routeType) == rt { + return true + } + } + return false +} + +func (api *RestAPI) buildArrivalsFromLocationStopTimes( + w http.ResponseWriter, + r *http.Request, + spc stopProcessingContext, + allActiveStopTimes []activeStopTime, + params ArrivalsAndDeparturesForLocationParams, + state *locationArrivalsState, +) (bool, error) { + ctx := r.Context() + routesLookup, tripsLookup, tripStopCountMap, bErr := api.batchFetchLocationRoutesAndTrips(ctx, spc.StopCode, allActiveStopTimes) + if bErr != nil { + return false, nil + } + + stopProducedArrival := false + + for _, ast := range allActiveStopTimes { + if len(state.arrivals) >= params.MaxCount { + state.limitExceeded = true + break + } + if ctx.Err() != nil { + api.clientCanceledResponse(w, r, ctx.Err()) + return stopProducedArrival, ctx.Err() + } + + st := ast.GetStopTimesForStopInWindowRow + + route, routeOK := routesLookup[st.RouteID] + if !routeOK || !isRouteTypeAllowed(route.Type, params.RouteTypes) { + continue + } + + trip, tripOK := tripsLookup[st.TripID] + if !tripOK { + continue + } + + rCopy := route + state.routeIDSet[route.ID] = &rCopy + tCopy := trip + state.tripIDSet[trip.ID] = &tCopy + + api.buildSingleArrival(ctx, spc, ast, state, route, tripStopCountMap[st.TripID]) + stopProducedArrival = true + } + + return stopProducedArrival, nil +} + +func (api *RestAPI) buildSingleArrival( + ctx context.Context, + spc stopProcessingContext, + ast activeStopTime, + state *locationArrivalsState, + route gtfsdb.Route, + totalStopsInTrip int, +) { + st := ast.GetStopTimesForStopInWindowRow + ac := &arrivalContext{ + st: st, + serviceMidnight: ast.ServiceDate, + totalStopsInTrip: totalStopsInTrip, + arrivalStatus: "default", + } + + ac.scheduledArrivalTime = ac.serviceMidnight.Add(time.Duration(ac.st.ArrivalTime)) + ac.scheduledDepartureTime = ac.serviceMidnight.Add(time.Duration(ac.st.DepartureTime)) + + vehicle := api.GtfsManager.GetVehicleForTrip(ctx, ac.st.TripID) + if vehicle != nil && vehicle.Trip != nil && vehicle.ID != nil { + ac.vehicleID = vehicle.ID.ID + } + + api.applyPredictedTimes(ac, spc.StopCode) + + if vehicle != nil { + api.applyTripStatus(ctx, ac, route, vehicle, spc.QueryTime, spc.StopCode, state) + } + + ac.blockTripSequence = api.calculateBlockTripSequence(ctx, ac.st.TripID, ac.serviceMidnight) + api.applyAlerts(ctx, ac, state) + + formattedVehicleID := "" + if ac.vehicleID != "" { + formattedVehicleID = utils.FormCombinedID(route.AgencyID, ac.vehicleID) + } + + rawStopSequence := int(ac.st.StopSequence) - 1 + + state.arrivals = append(state.arrivals, *models.NewArrivalAndDeparture( + utils.FormCombinedID(route.AgencyID, route.ID), + route.ShortName.String, + route.LongName.String, + utils.FormCombinedID(route.AgencyID, ac.st.TripID), + ac.st.TripHeadsign.String, + spc.CombinedStopID, + formattedVehicleID, + ac.serviceMidnight, + ac.scheduledArrivalTime, + ac.scheduledDepartureTime, + ac.predictedArrivalTime, + ac.predictedDepartureTime, + ac.lastUpdateTime, + ac.predicted, + true, + true, + rawStopSequence, + ac.totalStopsInTrip, + ac.numberOfStopsAway, + ac.blockTripSequence, + ac.distanceFromStop, + ac.arrivalStatus, + "", "", "", + ac.tripStatus, + ac.situationIDs, + )) +} + +func (api *RestAPI) applyPredictedTimes(ac *arrivalContext, stopCode string) { + predArr, predDep, isPredicted := api.getPredictedTimes( + ac.st.TripID, stopCode, int64(ac.st.StopSequence), + ac.scheduledArrivalTime, ac.scheduledDepartureTime, + ) + if isPredicted { + ac.predicted = true + ac.predictedArrivalTime = predArr + ac.predictedDepartureTime = predDep + } +} + +func (api *RestAPI) applyTripStatus(ctx context.Context, ac *arrivalContext, route gtfsdb.Route, vehicle *gtfs.Vehicle, stopQueryTime time.Time, stopCode string, state *locationArrivalsState) { + status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, ac.st.TripID, vehicle, ac.serviceMidnight, stopQueryTime) + if statusErr != nil { + api.Logger.Warn("BuildTripStatus failed", "tripID", ac.st.TripID, "error", statusErr) + } + if status != nil { + ac.tripStatus = status + + if !ac.predicted && status.Predicted { + dev := time.Duration(status.ScheduleDeviation) * time.Second + ac.predictedArrivalTime = ac.scheduledArrivalTime.Add(dev) + ac.predictedDepartureTime = ac.scheduledDepartureTime.Add(dev) + ac.predicted = true + } + + if ac.predicted { + ac.arrivalStatus = arrivalStatusFromDeviation(status.ScheduleDeviation) + } + + api.applyTripStatusStops(ac, state) + + if vehicle.Position != nil { + ac.distanceFromStop = api.getBlockDistanceToStop(ctx, ac.st.TripID, stopCode, vehicle, stopQueryTime) + nsa := api.getNumberOfStopsAway(ctx, ac.st.TripID, int(ac.st.StopSequence), vehicle, stopQueryTime) + if nsa != nil { + ac.numberOfStopsAway = *nsa + } else { + ac.numberOfStopsAway = -1 + } + } + + api.applyActiveTrip(ctx, ac, state) + } + ac.lastUpdateTime = api.GtfsManager.GetVehicleLastUpdateTime(vehicle) +} + +func (api *RestAPI) applyTripStatusStops(ac *arrivalContext, state *locationArrivalsState) { + if ac.tripStatus.NextStop != "" { + if nsAgency, nsID, nsErr := utils.ExtractAgencyIDAndCodeID(ac.tripStatus.NextStop); nsErr == nil { + state.stopIDSet[nsID] = true + if nsAgency != "" { + state.stopAgencyOverride[nsID] = nsAgency + } + } + } + if ac.tripStatus.ClosestStop != "" { + if csAgency, csID, csErr := utils.ExtractAgencyIDAndCodeID(ac.tripStatus.ClosestStop); csErr == nil { + state.stopIDSet[csID] = true + if csAgency != "" { + state.stopAgencyOverride[csID] = csAgency + } + } + } +} + +func (api *RestAPI) applyActiveTrip(ctx context.Context, ac *arrivalContext, state *locationArrivalsState) { + if ac.tripStatus.ActiveTripID == "" { + return + } + _, atID, atErr := utils.ExtractAgencyIDAndCodeID(ac.tripStatus.ActiveTripID) + if atErr != nil { + return + } + if activeSeq := api.calculateBlockTripSequence(ctx, atID, ac.serviceMidnight); activeSeq > 0 { + ac.tripStatus.BlockTripSequence = activeSeq + } + + if atID != ac.st.TripID { + if _, exists := state.tripIDSet[atID]; !exists { + if at, atFetchErr := api.GtfsManager.GtfsDB.Queries.GetTrip(ctx, atID); atFetchErr == nil { + atCopy := at + state.tripIDSet[at.ID] = &atCopy + if ar, arFetchErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, at.RouteID); arFetchErr == nil { + arCopy := ar + state.routeIDSet[ar.ID] = &arCopy + } + } + } + } +} + +func (api *RestAPI) applyAlerts(ctx context.Context, ac *arrivalContext, state *locationArrivalsState) { + tripAlerts := api.GtfsManager.GetAlertsForTrip(ctx, ac.st.TripID) + ac.situationIDs = make([]string, 0, len(tripAlerts)) + for _, alert := range tripAlerts { + if alert.ID == "" { + continue + } + ac.situationIDs = append(ac.situationIDs, alert.ID) + if _, seen := state.collectedAlerts[alert.ID]; !seen { + state.collectedAlerts[alert.ID] = alert + } + } +} + +func (api *RestAPI) sortLocationArrivalsByTime(arrivals []models.ArrivalAndDeparture) { + sort.Slice(arrivals, func(i, j int) bool { + ai := arrivals[i] + aj := arrivals[j] + var ti, tj time.Time + if !ai.PredictedArrivalTime.IsZero() { + ti = ai.PredictedArrivalTime.Time + } else { + ti = ai.ScheduledArrivalTime.Time + } + if !aj.PredictedArrivalTime.IsZero() { + tj = aj.PredictedArrivalTime.Time + } else { + tj = aj.ScheduledArrivalTime.Time + } + return ti.Before(tj) + }) +} + +func (api *RestAPI) buildLocationReferencesBlock(ctx context.Context, state *locationArrivalsState) (*models.ReferencesModel, []string) { + references := models.NewEmptyReferences() + addedAgencyIDs := make(map[string]bool) + + api.addTripReferences(ctx, state, references) + api.addRouteAndAgencyReferences(ctx, state, references, addedAgencyIDs) + api.addStopReferences(ctx, state, references) + + topLevelSituationIDs := make([]string, 0, len(state.collectedAlerts)) + if len(state.collectedAlerts) > 0 { + alertSlice := make([]gtfs.Alert, 0, len(state.collectedAlerts)) + for alertID, a := range state.collectedAlerts { + alertSlice = append(alertSlice, a) + topLevelSituationIDs = append(topLevelSituationIDs, alertID) + } + references.Situations = append(references.Situations, api.BuildSituationReferences(alertSlice)...) + } + + return references, topLevelSituationIDs +} + +func (api *RestAPI) addTripReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) { + for _, trip := range state.tripIDSet { + routeForTrip, ok := state.routeIDSet[trip.RouteID] + if !ok { + if fetched, fErr := api.GtfsManager.GtfsDB.Queries.GetRoute(ctx, trip.RouteID); fErr == nil { + fCopy := fetched + state.routeIDSet[fetched.ID] = &fCopy + routeForTrip = &fCopy + } else { + api.Logger.Warn("failed to fetch route for trip reference", "tripID", trip.ID, "routeID", trip.RouteID) + continue + } + } + references.Trips = append(references.Trips, *models.NewTripReference( + utils.FormCombinedID(routeForTrip.AgencyID, trip.ID), + utils.FormCombinedID(routeForTrip.AgencyID, trip.RouteID), + utils.FormCombinedID(routeForTrip.AgencyID, trip.ServiceID), + trip.TripHeadsign.String, + "", + strconv.FormatInt(trip.DirectionID.Int64, 10), + utils.FormCombinedID(routeForTrip.AgencyID, trip.BlockID.String), + utils.FormCombinedID(routeForTrip.AgencyID, trip.ShapeID.String), + )) + } +} + +func (api *RestAPI) addRouteAndAgencyReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel, addedAgencyIDs map[string]bool) { + for _, route := range state.routeIDSet { + references.Routes = append(references.Routes, models.NewRoute( + utils.FormCombinedID(route.AgencyID, route.ID), + route.AgencyID, + route.ShortName.String, + route.LongName.String, + route.Desc.String, + models.RouteType(route.Type), + route.Url.String, + route.Color.String, + route.TextColor.String, + )) + if !addedAgencyIDs[route.AgencyID] { + if ag, agErr := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, route.AgencyID); agErr == nil { + references.Agencies = append(references.Agencies, models.NewAgencyReference( + ag.ID, ag.Name, ag.Url, ag.Timezone, ag.Lang.String, + ag.Phone.String, ag.Email.String, ag.FareUrl.String, "", false, + )) + addedAgencyIDs[ag.ID] = true + } + } + } +} + +func (api *RestAPI) addStopReferences(ctx context.Context, state *locationArrivalsState, references *models.ReferencesModel) { + stopIDsSlice := stringMapKeys(state.stopIDSet) + batchStops, _ := api.GtfsManager.GtfsDB.Queries.GetStopsByIDs(ctx, stopIDsSlice) + batchRoutesForStops, _ := api.GtfsManager.GtfsDB.Queries.GetRoutesForStops(ctx, stopIDsSlice) + + stopsMap := make(map[string]gtfsdb.Stop, len(batchStops)) + for _, s := range batchStops { + stopsMap[s.ID] = s + } + routesByStop := make(map[string][]gtfsdb.GetRoutesForStopsRow) + for _, row := range batchRoutesForStops { + routesByStop[row.StopID] = append(routesByStop[row.StopID], row) + } + + for _, sid := range stopIDsSlice { + stopData, ok := stopsMap[sid] + if !ok { + continue + } + ag := state.stopAgencyMap[sid] + if ag == "" { + ag = state.stopAgencyOverride[sid] + } + if ag == "" { + ag = state.fallbackAgencyID + } + routesForStop := routesByStop[sid] + combinedRouteIDs := make([]string, len(routesForStop)) + for i, rr := range routesForStop { + combinedRouteIDs[i] = utils.FormCombinedID(rr.AgencyID, rr.ID) + if _, exists := state.routeIDSet[rr.ID]; !exists { + rc := gtfsdb.Route{ + ID: rr.ID, + AgencyID: rr.AgencyID, + ShortName: rr.ShortName, + LongName: rr.LongName, + Desc: rr.Desc, + Type: rr.Type, + Url: rr.Url, + Color: rr.Color, + TextColor: rr.TextColor, + } + state.routeIDSet[rr.ID] = &rc + } + } + references.Stops = append(references.Stops, models.Stop{ + ID: utils.FormCombinedID(ag, stopData.ID), + Name: stopData.Name.String, + Lat: stopData.Lat, + Lon: stopData.Lon, + Code: stopData.Code.String, + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), + LocationType: int(stopData.LocationType.Int64), + WheelchairBoarding: utils.MapWheelchairBoarding(nulls.WheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), + RouteIDs: combinedRouteIDs, + StaticRouteIDs: combinedRouteIDs, + }) + } +} + +func (api *RestAPI) buildLocationQueriedStopIDs(stops []gtfsdb.Stop, state *locationArrivalsState) []string { + queriedStopIDs := make([]string, 0, len(state.stopsWithArrivals)) + for _, dbStop := range stops { + if state.stopsWithArrivals[dbStop.ID] { + ag := state.stopAgencyMap[dbStop.ID] + if ag == "" { + ag = state.fallbackAgencyID + } + queriedStopIDs = append(queriedStopIDs, utils.FormCombinedID(ag, dbStop.ID)) + } + } + return queriedStopIDs +} + +// getLocationNearbyStops returns stops near the query centre together with their +// distance from the centre, sorted ascending by distance. +func getLocationNearbyStops( + api *RestAPI, + ctx context.Context, + centerLat, centerLon float64, +) []models.StopWithDistance { + + nearby, _ := api.GtfsManager.GetStopsForLocation( + ctx, + &internalgtfs.LocationParams{ + Lat: centerLat, + Lon: centerLon, + Radius: models.DefaultSearchRadiusInMeters, + }, + "", + 250, + []int{}, + ) + + if len(nearby) == 0 { + return nil + } + + candidateIDs := make([]string, len(nearby)) + for i, s := range nearby { + candidateIDs[i] = s.ID + } + + nearbyAgencyMap := make(map[string]string, len(candidateIDs)) + agencyRows, err := api.GtfsManager.GtfsDB.Queries.GetAgenciesForStops(ctx, candidateIDs) + if err != nil { + api.Logger.Warn("failed to resolve agencies for nearby stops", "error", err) + } else { + for _, row := range agencyRows { + if _, exists := nearbyAgencyMap[row.StopID]; !exists { + nearbyAgencyMap[row.StopID] = row.ID + } + } + } + + nearbyFallback := pickPrimaryAgency(nearbyAgencyMap) + + result := make([]models.StopWithDistance, 0, len(nearby)) + for _, s := range nearby { + ag := nearbyFallback + if resolved, ok := nearbyAgencyMap[s.ID]; ok { + ag = resolved + } + combinedID := utils.FormCombinedID(ag, s.ID) + + dist := utils.Distance(centerLat, centerLon, s.Lat, s.Lon) + result = append(result, models.StopWithDistance{ + StopID: combinedID, + DistanceFromQuery: dist, + }) + } + + if len(result) == 0 { + return nil + } + + sort.Slice(result, func(i, j int) bool { + return result[i].DistanceFromQuery < result[j].DistanceFromQuery + }) + + return result +} + +// stringMapKeys returns the keys of a map[string]bool as a string slice. +func stringMapKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// pickPrimaryAgency returns the agency ID that appears most frequently in the +// stopCode→agencyID map. Used only as a fallback when a stop has no resolved +// agency — never used to prefix alert IDs directly. +func pickPrimaryAgency(stopAgencyMap map[string]string) string { + counts := make(map[string]int, 4) + for _, ag := range stopAgencyMap { + counts[ag]++ + } + best := "" + bestCount := 0 + for ag, cnt := range counts { + if cnt > bestCount || (cnt == bestCount && ag < best) { + best = ag + bestCount = cnt + } + } + return best +} diff --git a/internal/restapi/arrivals_and_departures_for_location_handler_test.go b/internal/restapi/arrivals_and_departures_for_location_handler_test.go new file mode 100644 index 00000000..1f8033fa --- /dev/null +++ b/internal/restapi/arrivals_and_departures_for_location_handler_test.go @@ -0,0 +1,494 @@ +package restapi + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maglev.onebusaway.org/internal/clock" +) + +// --- Param parsing unit tests --- + +func TestParseArrivalsAndDeparturesForLocationParams_Defaults(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 47.653, params.Lat) + assert.Equal(t, -122.307, params.Lon) + assert.Equal(t, 0.008, params.LatSpan) + assert.Equal(t, 0.008, params.LonSpan) + assert.Equal(t, 5, params.MinutesBefore) + assert.Equal(t, 35, params.MinutesAfter) + assert.Equal(t, 250, params.MaxCount) + assert.WithinDuration(t, api.Clock.Now(), params.Time, time.Second) +} + +func TestParseArrivalsAndDeparturesForLocationParams_CustomValues(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&radius=500&minutesBefore=10&minutesAfter=60&maxCount=50&time=1609459200000", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 47.653, params.Lat) + assert.Equal(t, -122.307, params.Lon) + assert.Equal(t, 500.0, params.Radius) + assert.Equal(t, 10, params.MinutesBefore) + assert.Equal(t, 60, params.MinutesAfter) + assert.Equal(t, 50, params.MaxCount) + assert.False(t, params.Time.IsZero()) +} + +func TestParseArrivalsAndDeparturesForLocationParams_MissingLatLon(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", "/test", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "lat") + assert.Contains(t, errs, "lon") +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidTime(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&time=notanumber", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "time") + assert.Equal(t, "must be a valid Unix timestamp in milliseconds", errs["time"][0]) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidMinutesAfter(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesAfter=notanumber", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "minutesAfter") + assert.Equal(t, "must be a valid integer", errs["minutesAfter"][0]) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidMinutesBefore(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesBefore=notanumber", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "minutesBefore") + assert.Equal(t, "must be a valid integer", errs["minutesBefore"][0]) +} + +func TestParseArrivalsAndDeparturesForLocationParams_NegativeMinutes(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesBefore=-1&minutesAfter=-5", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "minutesBefore") + assert.Contains(t, errs, "minutesAfter") +} + +func TestParseArrivalsAndDeparturesForLocationParams_MinutesCappedAtMax(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&minutesBefore=9999&minutesAfter=9999", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 60, params.MinutesBefore) + assert.Equal(t, 240, params.MinutesAfter) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidMaxCount(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&latSpan=0.008&lonSpan=0.008&maxCount=0", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "maxCount") +} + +// --- HTTP handler integration tests --- + +func TestParseArrivalsAndDeparturesForLocationParams_FrequencyAndRouteType(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.653&lon=-122.307&radius=500&frequencyMinutesBefore=15&frequencyMinutesAfter=45&emptyReturnsNotFound=true&routeType=1,3", nil) + params, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.Nil(t, errs) + assert.Equal(t, 15, params.FrequencyMinutesBefore) + assert.Equal(t, 45, params.FrequencyMinutesAfter) + assert.True(t, params.EmptyReturnsNotFound) + assert.Equal(t, []int{1, 3}, params.RouteTypes) +} + +func TestParseArrivalsAndDeparturesForLocationParams_InvalidRouteType(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + req := httptest.NewRequest("GET", + "/test?lat=47.6&lon=-122.3&routeType=1,abc", nil) + _, errs := api.parseArrivalsAndDeparturesForLocationParams(req) + + assert.NotNil(t, errs) + assert.Contains(t, errs, "routeType") +} + +func TestArrivalsAndDeparturesForLocationRequiresValidAPIKey(t *testing.T) { + _, resp, model := serveAndRetrieveEndpoint(t, + "/api/where/arrivals-and-departures-for-location.json?key=invalid&lat=40.583321&lon=-122.426966&latSpan=0.01&lonSpan=0.01") + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, http.StatusUnauthorized, model.Code) + assert.Equal(t, "permission denied", model.Text) +} + +func TestArrivalsAndDeparturesForLocationMissingLatLon(t *testing.T) { + _, resp, _ := serveAndRetrieveEndpoint(t, + "/api/where/arrivals-and-departures-for-location.json?key=TEST") + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestArrivalsAndDeparturesForLocationInvalidTime(t *testing.T) { + _, resp, _ := serveAndRetrieveEndpoint(t, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&latSpan=0.01&lonSpan=0.01&time=notanumber") + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestArrivalsAndDeparturesForLocationEmptyAreaReturnsOK(t *testing.T) { + // Coordinates far from any test GTFS data so no stops are found. + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=0.0&lon=0.0&latSpan=0.001&lonSpan=0.001") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 200, model.Code) + assert.Equal(t, "OK", model.Text) + assert.Equal(t, 2, model.Version) + assert.NotZero(t, model.CurrentTime) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + // Entry must contain all expected keys even when empty. + assert.Contains(t, entry, "arrivalsAndDepartures") + assert.Contains(t, entry, "stopIds") + assert.Contains(t, entry, "nearbyStopIds") + assert.Contains(t, entry, "situationIds") + assert.Contains(t, entry, "limitExceeded") + + ads, ok := entry["arrivalsAndDepartures"].([]interface{}) + require.True(t, ok) + assert.Empty(t, ads) + + stopIDs, ok := entry["stopIds"].([]interface{}) + require.True(t, ok) + assert.Empty(t, stopIDs) + + assert.False(t, entry["limitExceeded"].(bool)) +} + +func TestArrivalsAndDeparturesForLocationEmptyReturnsNotFound(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=0.0&lon=0.0&radius=100&emptyReturnsNotFound=true") + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + assert.Equal(t, 404, model.Code) +} + +func TestArrivalsAndDeparturesForLocationEndToEnd(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 200, model.Code) + assert.Equal(t, "OK", model.Text) + assert.Equal(t, 2, model.Version) + assert.NotZero(t, model.CurrentTime) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok, "data should be a map") + + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok, "entry should be a map") + + // Required entry keys. + assert.Contains(t, entry, "arrivalsAndDepartures") + assert.Contains(t, entry, "stopIds") + assert.Contains(t, entry, "nearbyStopIds") + assert.Contains(t, entry, "situationIds") + assert.Contains(t, entry, "limitExceeded") + + // nearbyStopIds must be a list of objects with stopId + distanceFromQuery. + nearbyRaw, ok := entry["nearbyStopIds"].([]interface{}) + require.True(t, ok, "nearbyStopIds should be a list") + for _, item := range nearbyRaw { + nearby, ok := item.(map[string]interface{}) + require.True(t, ok, "each nearbyStopIds entry should be an object") + assert.Contains(t, nearby, "stopId") + assert.Contains(t, nearby, "distanceFromQuery") + } + + // stopIds must be a list. + stopIDs, ok := entry["stopIds"].([]interface{}) + require.True(t, ok, "stopIds should be a list") + assert.NotEmpty(t, stopIDs, "should have found stops in this area") + + // References block. + refs, ok := data["references"].(map[string]interface{}) + require.True(t, ok, "references should be a map") + assert.Contains(t, refs, "agencies") + assert.Contains(t, refs, "routes") + assert.Contains(t, refs, "stops") + assert.Contains(t, refs, "trips") + assert.Contains(t, refs, "situations") + + // Validate arrival shape if any were returned. + ads, ok := entry["arrivalsAndDepartures"].([]interface{}) + require.True(t, ok, "arrivalsAndDepartures should be a list") + + if len(ads) == 0 { + t.Skip("no arrivals in test data for this time/location") + } + + ad, ok := ads[0].(map[string]interface{}) + require.True(t, ok, "first arrival should be a map") + + // Required arrival fields. + for _, field := range []string{ + "routeId", "tripId", "stopId", "serviceDate", + "scheduledArrivalTime", "scheduledDepartureTime", + "predictedArrivalTime", "predictedDepartureTime", + "predicted", "status", "situationIds", + "routeShortName", "tripHeadsign", + "arrivalEnabled", "departureEnabled", + "numberOfStopsAway", "distanceFromStop", + "blockTripSequence", "totalStopsInTrip", + "frequency", + } { + assert.Contains(t, ad, field, "arrival must contain field %q", field) + } + + assert.Equal(t, "default", ad["status"]) + + // Every arrival's stopId must be one of the queried stopIds. + stopIDInAD, _ := ad["stopId"].(string) + assert.NotEmpty(t, stopIDInAD) + assert.Contains(t, stopIDs, stopIDInAD, + "arrival stopId should be one of the queried stopIds") +} + +func TestArrivalsAndDeparturesForLocationStopIdsOnlyContainsStopsWithArrivals(t *testing.T) { + // Java only includes a stop in stopIds when it has at least one arrival. + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + stopIDs, _ := entry["stopIds"].([]interface{}) + + if len(ads) == 0 { + // No arrivals → stopIds must also be empty. + assert.Empty(t, stopIDs, "stopIds must be empty when there are no arrivals") + return + } + + // Every stopId in the entry must appear in at least one arrival's stopId field. + arrivalStopIDs := make(map[interface{}]bool) + for _, adRaw := range ads { + if ad, ok := adRaw.(map[string]interface{}); ok { + arrivalStopIDs[ad["stopId"]] = true + } + } + for _, sid := range stopIDs { + assert.True(t, arrivalStopIDs[sid], + "stopId %v in entry.stopIds has no matching arrival", sid) + } +} + +func TestArrivalsAndDeparturesForLocationWithLatSpanLonSpan(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&latSpan=0.045&lonSpan=0.059") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 200, model.Code) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + assert.Contains(t, entry, "stopIds") + assert.Contains(t, entry, "arrivalsAndDepartures") + assert.Contains(t, entry, "limitExceeded") +} + +func TestArrivalsAndDeparturesForLocationReferencesConsistency(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + refs, ok := data["references"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + if len(ads) == 0 { + t.Skip("no arrivals in test data for this location") + } + + routeRefs, _ := refs["routes"].([]interface{}) + tripRefs, _ := refs["trips"].([]interface{}) + agencies, _ := refs["agencies"].([]interface{}) + + routeRefIDs := collectAllIdsFromObjects(t, routeRefs, "id") + tripRefIDs := collectAllIdsFromObjects(t, tripRefs, "id") + agencyRefIDs := collectAllIdsFromObjects(t, agencies, "id") + + // Every arrival's routeId and tripId must appear in references. + for _, adRaw := range ads { + ad, ok := adRaw.(map[string]interface{}) + require.True(t, ok) + + routeID, _ := ad["routeId"].(string) + assert.Contains(t, routeRefIDs, routeID, + "every arrival routeId must appear in references.routes") + + tripID, _ := ad["tripId"].(string) + assert.Contains(t, tripRefIDs, tripID, + "every arrival tripId must appear in references.trips") + } + + // Every route's agencyId must appear in references.agencies. + agencyIDsFromRoutes := collectAllIdsFromObjects(t, routeRefs, "agencyId") + for _, aid := range agencyIDsFromRoutes { + assert.Contains(t, agencyRefIDs, aid, + "every route agencyId must appear in references.agencies") + } +} + +func TestArrivalsAndDeparturesForLocationArrivalsAreSortedByTime(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + if len(ads) < 2 { + t.Skip("need at least 2 arrivals to test sort order") + } + + var prevTime float64 + for i, adRaw := range ads { + ad, ok := adRaw.(map[string]interface{}) + require.True(t, ok) + + predicted, _ := ad["predicted"].(bool) + var arrTime float64 + if predicted { + arrTime, _ = ad["predictedArrivalTime"].(float64) + } + if arrTime == 0 { + arrTime, _ = ad["scheduledArrivalTime"].(float64) + } + + if i > 0 { + assert.GreaterOrEqual(t, arrTime, prevTime, + "arrivals must be sorted ascending by arrival time (index %d)", i) + } + prevTime = arrTime + } +} + +func TestArrivalsAndDeparturesForLocationLimitExceeded(t *testing.T) { + mockClock := clock.NewMockClock(time.Date(2025, 12, 26, 14, 0, 0, 0, time.UTC)) + api := createTestApiWithClock(t, mockClock) + + // maxCount=1 forces limitExceeded=true if there is more than 1 arrival. + resp, model := serveApiAndRetrieveEndpoint(t, api, + "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=40.583321&lon=-122.426966&radius=2500&maxCount=1") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, ok := model.Data.(map[string]interface{}) + require.True(t, ok) + entry, ok := data["entry"].(map[string]interface{}) + require.True(t, ok) + + ads, _ := entry["arrivalsAndDepartures"].([]interface{}) + assert.LessOrEqual(t, len(ads), 1) + assert.Equal(t, true, entry["limitExceeded"]) +} diff --git a/internal/restapi/arrivals_and_departures_for_stop_handler.go b/internal/restapi/arrivals_and_departures_for_stop_handler.go index 19edda08..afee2785 100644 --- a/internal/restapi/arrivals_and_departures_for_stop_handler.go +++ b/internal/restapi/arrivals_and_departures_for_stop_handler.go @@ -340,7 +340,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r if vehicle != nil { // Use route.AgencyID instead of stopAgencyID for BuildTripStatus - status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, nil, serviceMidnight, params.Time) + status, statusErr := api.BuildTripStatus(ctx, route.AgencyID, st.TripID, vehicle, serviceMidnight, params.Time) if statusErr != nil { api.Logger.Warn("BuildTripStatus failed for arrival", "tripID", st.TripID, "error", statusErr) @@ -418,7 +418,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r continue } - situationIDs = append(situationIDs, utils.FormCombinedID(route.AgencyID, alert.ID)) + situationIDs = append(situationIDs, alert.ID) if _, seen := collectedAlerts[alert.ID]; !seen { collectedAlerts[alert.ID] = alert } @@ -619,7 +619,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r topLevelSituationIDSet := make(map[string]struct{}, len(collectedAlerts)) for alertID := range collectedAlerts { - topLevelSituationIDSet[utils.FormCombinedID(alertAgencyID, alertID)] = struct{}{} + topLevelSituationIDSet[alertID] = struct{}{} } topLevelSituationIDs := make([]string, 0, len(topLevelSituationIDSet)) for id := range topLevelSituationIDSet { diff --git a/internal/restapi/context_cancellation_test.go b/internal/restapi/context_cancellation_test.go index 0ce8124b..22ac6587 100644 --- a/internal/restapi/context_cancellation_test.go +++ b/internal/restapi/context_cancellation_test.go @@ -23,27 +23,32 @@ func TestContextCancellationHandling(t *testing.T) { }{ { name: "agencies with coverage should handle context cancellation", - endpoint: "/api/where/agencies-with-coverage.json?key=test", + endpoint: "/api/where/agencies-with-coverage.json?key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, // Very short timeout to trigger cancellation }, { name: "stop IDs for agency should handle context cancellation", - endpoint: "/api/where/stop-ids-for-agency/1?key=test", + endpoint: "/api/where/stop-ids-for-agency/1?key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "routes for location should handle context cancellation", - endpoint: "/api/where/routes-for-location.json?lat=38.9&lon=-77.0&key=test", + endpoint: "/api/where/routes-for-location.json?lat=38.9&lon=-77.0&key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "stops for location should handle context cancellation", - endpoint: "/api/where/stops-for-location.json?lat=38.9&lon=-77.0&key=test", + endpoint: "/api/where/stops-for-location.json?lat=38.9&lon=-77.0&key=org.onebusaway.iphone", + timeout: 1 * time.Nanosecond, + }, + { + name: "arrivals and departures for location should handle context cancellation", + endpoint: "/api/where/arrivals-and-departures-for-location.json?lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01&key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, { name: "stops for route should handle context cancellation", - endpoint: "/api/where/stops-for-route/1?key=test", + endpoint: "/api/where/stops-for-route/1?key=org.onebusaway.iphone", timeout: 1 * time.Nanosecond, }, } @@ -76,7 +81,8 @@ func TestContextCancellationHandling(t *testing.T) { // If cancelled, we expect a timeout or cancellation error response statusCode := w.Code - // Valid responses: 200 (completed), 401 (API validation), 500 (error), or timeout-related + // Valid responses: 200 (completed), 401 (API validation), 500 (error), timeout-related, + // or 404 (not found). Rate limit 429 is prevented by using an exempt key. assert.True(t, statusCode == http.StatusOK || statusCode == http.StatusUnauthorized || // API key validation happens first statusCode == http.StatusBadRequest || @@ -84,7 +90,7 @@ func TestContextCancellationHandling(t *testing.T) { statusCode == http.StatusRequestTimeout || statusCode == http.StatusGatewayTimeout || statusCode == http.StatusNotFound, - "Expected status 200, 401, 404, 500, 408, or 504, got %d", statusCode) + "Expected status 200, 401, 400, 404, 500, 408, or 504, got %d", statusCode) }) } } @@ -95,7 +101,7 @@ func TestLongerTimeoutContextHandling(t *testing.T) { // Test with a reasonable timeout that should allow completion t.Run("reasonable timeout should complete successfully", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/where/agencies-with-coverage.json?key=test", nil) + req, err := http.NewRequest("GET", "/api/where/agencies-with-coverage.json?key=org.onebusaway.iphone", nil) require.NoError(t, err) // Create context with reasonable timeout diff --git a/internal/restapi/coverage_test.go b/internal/restapi/coverage_test.go index 61effc5a..2eae4795 100644 --- a/internal/restapi/coverage_test.go +++ b/internal/restapi/coverage_test.go @@ -12,7 +12,11 @@ import ( ) func TestBuildSituationReferencesCoverage(t *testing.T) { - api := &RestAPI{} + api := &RestAPI{ + Application: &app.Application{ + Clock: clock.NewMockClock(time.Now()), + }, + } alerts := []gtfs.Alert{ { diff --git a/internal/restapi/input_validation_integration_test.go b/internal/restapi/input_validation_integration_test.go index d7d8a269..e635213e 100644 --- a/internal/restapi/input_validation_integration_test.go +++ b/internal/restapi/input_validation_integration_test.go @@ -157,6 +157,44 @@ func TestInputValidationIntegration(t *testing.T) { expectedStatus: http.StatusBadRequest, expectedError: "invalid date format", }, + + // Test arrivals-and-departures-for-location parameter validation + { + name: "arrivals-for-location: invalid latitude too high", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=91.0&lon=-77.0&radius=500", + expectedStatus: http.StatusBadRequest, + expectedError: "latitude must be between -90 and 90", + }, + { + name: "arrivals-for-location: invalid longitude too high", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.0&lon=181.0&radius=500", + expectedStatus: http.StatusBadRequest, + expectedError: "longitude must be between -180 and 180", + }, + { + name: "arrivals-for-location: missing lat and lon", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST", + expectedStatus: http.StatusBadRequest, + expectedError: "", + }, + { + name: "arrivals-for-location: invalid minutesAfter", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=500&minutesAfter=notanumber", + expectedStatus: http.StatusBadRequest, + expectedError: "must be a valid integer", + }, + { + name: "arrivals-for-location: invalid minutesBefore", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=500&minutesBefore=notanumber", + expectedStatus: http.StatusBadRequest, + expectedError: "must be a valid integer", + }, + { + name: "arrivals-for-location: invalid time", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=500&time=notanumber", + expectedStatus: http.StatusBadRequest, + expectedError: "must be a valid Unix timestamp in milliseconds", + }, } for _, tt := range tests { @@ -253,6 +291,18 @@ func TestValidInputsPassThrough(t *testing.T) { name: "Valid location with span parameters", endpoint: "/api/where/stops-for-location.json?key=TEST&lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01", }, + { + name: "Valid arrivals-for-location with radius", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=1000", + }, + { + name: "Valid arrivals-for-location with latSpan and lonSpan", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&latSpan=0.01&lonSpan=0.01", + }, + { + name: "Valid arrivals-for-location with custom time window", + endpoint: "/api/where/arrivals-and-departures-for-location.json?key=TEST&lat=38.9&lon=-77.0&radius=1000&minutesBefore=2&minutesAfter=20", + }, } for _, tt := range validTests { diff --git a/internal/restapi/reference_utils.go b/internal/restapi/reference_utils.go index 5bb2fadb..a929ceac 100644 --- a/internal/restapi/reference_utils.go +++ b/internal/restapi/reference_utils.go @@ -2,7 +2,6 @@ package restapi import ( "context" - "time" "github.com/OneBusAway/go-gtfs" "maglev.onebusaway.org/gtfsdb" @@ -79,11 +78,24 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa for _, alert := range alerts { situation := models.Situation{ ID: alert.ID, - CreationTime: models.NewModelTime(time.Time{}), + CreationTime: models.NewModelTime(api.Clock.Now()), ActiveWindows: make([]models.ActiveWindow, 0, len(alert.ActivePeriods)), AllAffects: make([]models.AffectedEntity, 0, len(alert.InformedEntities)), ConsequenceMessage: "", - Consequences: []any{}, + Consequences: []models.Consequence{ + { + Condition: "", + ConditionDetails: models.ConditionDetails{ + DiversionPath: models.DiversionPath{ + Length: 0, + Levels: "", + Points: "", + }, + // Initialized to an empty slice so it outputs [] instead of null + DiversionStopIDs: []string{}, + }, + }, + }, PublicationWindows: []any{}, Reason: mapAlertCauseToReason(alert.Cause), Severity: mapAlertEffectToSeverity(alert.Effect), @@ -101,17 +113,29 @@ func (api *RestAPI) BuildSituationReferences(alerts []gtfs.Alert) []models.Situa } for _, entity := range alert.InformedEntities { + agencyID := getStringValue(entity.AgencyID) + + rawRouteID := getStringValue(entity.RouteID) + if rawRouteID != "" { + rawRouteID = utils.FormCombinedID(agencyID, rawRouteID) + } + + rawStopID := getStringValue(entity.StopID) + if rawStopID != "" { + rawStopID = utils.FormCombinedID(agencyID, rawStopID) + } + affectedEntity := models.AffectedEntity{ - AgencyID: getStringValue(entity.AgencyID), + AgencyID: agencyID, ApplicationID: "", DirectionID: entity.DirectionID.String(), - RouteID: getStringValue(entity.RouteID), - StopID: getStringValue(entity.StopID), + RouteID: rawRouteID, + StopID: rawStopID, TripID: "", } - if entity.TripID != nil { - affectedEntity.TripID = entity.TripID.ID + if entity.TripID != nil && entity.TripID.ID != "" { + affectedEntity.TripID = utils.FormCombinedID(agencyID, entity.TripID.ID) } situation.AllAffects = append(situation.AllAffects, affectedEntity) diff --git a/internal/restapi/routes.go b/internal/restapi/routes.go index cbcb8075..578a720c 100644 --- a/internal/restapi/routes.go +++ b/internal/restapi/routes.go @@ -117,6 +117,7 @@ func (api *RestAPI) SetRoutes(mux *http.ServeMux) { mux.Handle("GET /api/where/arrival-and-departure-for-stop/{id}", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.arrivalAndDepartureForStopHandler))) mux.Handle("GET /api/where/trips-for-route/{id}", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.tripsForRouteHandler))) mux.Handle("GET /api/where/arrivals-and-departures-for-stop/{id}", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.arrivalsAndDeparturesForStopHandler))) + mux.Handle("GET /api/where/arrivals-and-departures-for-location.json", CacheControlMiddleware(models.CacheDurationShort, rateLimitAndValidateAPIKey(api, api.arrivalsAndDeparturesForLocationHandler))) } // SetupAPIRoutes creates and configures the API router with all middleware applied globally diff --git a/internal/restapi/trips_helper.go b/internal/restapi/trips_helper.go index d3eccbf0..13b9e8b1 100644 --- a/internal/restapi/trips_helper.go +++ b/internal/restapi/trips_helper.go @@ -67,7 +67,6 @@ func (api *RestAPI) BuildTripStatus( // Predicted is true because the cancellation itself is real-time information. if status.Status == "CANCELED" { status.Predicted = vehicle != nil && !defaultStaleDetector.Check(vehicle, currentTime) - status.Scheduled = !status.Predicted return status, nil } @@ -740,11 +739,7 @@ func (api *RestAPI) GetSituationIDsForTrip(ctx context.Context, tripID string) [ if alert.ID == "" { continue } - if agencyID != "" { - situationIDs = append(situationIDs, utils.FormCombinedID(agencyID, alert.ID)) - } else { - situationIDs = append(situationIDs, alert.ID) - } + situationIDs = append(situationIDs, alert.ID) } return situationIDs diff --git a/internal/restapi/trips_helper_test.go b/internal/restapi/trips_helper_test.go index df381a77..3df95655 100644 --- a/internal/restapi/trips_helper_test.go +++ b/internal/restapi/trips_helper_test.go @@ -462,7 +462,6 @@ func TestBuildTripStatus_ScheduleDeviation_SetsPredicted(t *testing.T) { require.NotZero(t, status.ScheduleDeviation) assert.Equal(t, 120, status.ScheduleDeviation, "ScheduleDeviation should reflect the trip update delay") assert.True(t, status.Predicted, "Predicted should be true when trip update exists") - assert.False(t, status.Scheduled, "Scheduled should be false when predicted is true") } func TestBuildTripStatus_NoRealtimeData_SetsScheduled(t *testing.T) { @@ -488,7 +487,6 @@ func TestBuildTripStatus_NoRealtimeData_SetsScheduled(t *testing.T) { assert.Equal(t, 0, status.ScheduleDeviation, "ScheduleDeviation should be 0 with no real-time data") assert.False(t, status.Predicted, "Predicted should be false with no real-time data") - assert.True(t, status.Scheduled, "Scheduled should be true with no real-time data") assert.Equal(t, "default", status.Status) assert.Equal(t, "scheduled", status.Phase) }