drizzle-team / drizzle-orm

Headless TypeScript ORM with a head. Runs on Node, Bun and Deno. Lives on the Edge and yes, it's a JavaScript ORM too 😅
https://orm.drizzle.team
Apache License 2.0
24.49k stars 643 forks source link

Mysql Geospatial types #523

Open SirCameron opened 1 year ago

SirCameron commented 1 year ago

Describe what you want

I'm looking at Drizzle ORM for a new project and loving it so far :)

I don't see anything about MySQL geospatial types in the docs. Is this on the roadmap?

In more detail: POINT, etc

mauriciabad commented 7 months ago

I managed to get geospatial data (Point) working in MySQL (PlanetScale) with this code:

// src/server/helpers/spatial-data/point.ts
import { DriverValueMapper, sql } from 'drizzle-orm'
import { customType } from 'drizzle-orm/mysql-core'
import {
  PointString,
  getPoint,
  pointToString,
} from '~/helpers/spatial-data/point'
import { SRID_CODE } from '.'

// TODO: This type should be MapPoint, but for some reason it is not working.
// I'll wait to fix it until this PR is merged: https://github.com/drizzle-team/drizzle-orm/pull/1423
// Until then, column values are exchanged as WKT point strings (`POINT(lng lat)`).
type WrongPointType = PointString

/**
 * Drizzle custom column for a MySQL `POINT` constrained to `SRID_CODE`.
 *
 * Writes accept anything `getPoint` understands, normalize it to WKT
 * (`POINT(lng lat)`) and insert it via `ST_PointFromText`. Reads parse the
 * driver's text value with `getPoint` and hand back WKT. `selectPoint` selects
 * the column through `ST_AsText` for that reason.
 *
 * NOTE(review): MySQL 8 interprets WKT in the axis order defined by the SRS
 * unless an `axis-order` option is passed. For geographic SRSs such as 4326
 * that is latitude-first, but this code writes `lng lat`. Confirm SRID_CODE
 * and the intended axis order.
 *
 * Both directions throw an `Error` when the value cannot be parsed as a point.
 */
export const pointType = customType<{
  data: WrongPointType
  driverData: string
}>({
  dataType: () => `POINT SRID ${SRID_CODE}`,
  toDriver: (value: WrongPointType | string) => {
    const parsed = getPoint(value)
    if (parsed === null) {
      throw new Error(`Invalid point value: ${JSON.stringify(value)}`)
    }
    const wkt = pointToString(parsed)
    return sql`ST_PointFromText(${wkt}, ${SRID_CODE})`
  },
  fromDriver: (value: string): WrongPointType => {
    const parsed = getPoint(value)
    if (parsed === null) {
      throw new Error(`Invalid point value: ${JSON.stringify(value)}`)
    }
    return pointToString(parsed)
  },
})

/**
 * Builds a relational-query `extras` entry for a POINT column.
 *
 * The entry selects `ST_AsText(column)`, decodes the result with `decoder` and
 * aliases it back to the column name. Callers in this file pass the column
 * itself as the decoder.
 *
 * @param column Column name, emitted as an unqualified SQL identifier.
 * @param decoder Value mapper applied to the selected text via `mapWith`.
 */
export const selectPoint = <
  C extends string,
  D extends DriverValueMapper<D1, D2>,
  D1 = any,
  D2 = any,
>(
  column: C,
  decoder: D
) => {
  const columnAsText = sql<WrongPointType>`ST_AsText(${sql.identifier(column)})`
  const decoded = columnAsText.mapWith(decoder)
  return decoded.as(column)
}
// src/helpers/spatial-data/point.ts
/** A point as a plain `{ lat, lng }` object. */
export type MapPoint = {
  lat: number
  lng: number
}
/**
 * WKT text for a point. The coordinate order is `lng lat` (longitude first),
 * e.g. `POINT(2.17 41.38)`.
 */
export type PointString = `POINT(${number} ${number})`

/**
 * Minimal GeoJSON Point geometry, one of the inputs accepted by `getPoint`.
 * Coordinates are `[lng, lat]` (longitude first), as in GeoJSON.
 */
type GeoJsonPointType<
  Lat extends number = number,
  Lng extends number = number,
> = {
  type: 'Point'
  coordinates: [Lng, Lat]
}

/**
 * Extracts the lat and lng from a value.
 * Notice that in the text formats the order is `lng lat`, not `lat lng`.
 * In the object/JSON formats the keys (or GeoJSON's `[lng, lat]` convention)
 * make the order explicit.
 *
 * @param value One of:
 *   - `lng lat` or `POINT(lng lat)` text (there can be a comma between the values)
 *   - `{ "lat": 1, "lng": 1 }`, as an object or a JSON string
 *   - `{ "x": 1, "y": 1 }` where `x` is lng and `y` is lat, as an object or a JSON string
 *   - GeoJSON `{ "type": "Point", "coordinates": [lng, lat] }`, as an object or a JSON string
 * @returns Object with `lat` and `lng` properties, or `null` when the value is
 *   empty or cannot be read as a point. Coordinates must be finite numbers;
 *   `0` is a valid coordinate.
 */
export function getPoint(
  value: PointString | { x: number; y: number } | MapPoint | GeoJsonPointType
): MapPoint
export function getPoint(value: null | undefined): null
export function getPoint(
  value:
    | string
    | PointString
    | { x: number; y: number }
    | MapPoint
    | GeoJsonPointType
    | null
    | undefined
): MapPoint | null
export function getPoint(
  value:
    | string
    | PointString
    | { x: number; y: number }
    | MapPoint
    | GeoJsonPointType
    | null
    | undefined
): MapPoint | null {
  // Coordinates are checked at runtime because JSON.parse results reach this
  // function untyped (e.g. `{"lat":"1"}` or `{"lat":null}`).
  const isCoordinate = (n: unknown): n is number =>
    typeof n === 'number' && Number.isFinite(n)

  if (!value) return null
  if (typeof value === 'string') {
    try {
      return getPoint(JSON.parse(value))
    } catch {
      // Not JSON, or JSON that is not an object (the `in` checks below throw on
      // primitives): fall back to the text formats.
    }

    const groups = value.match(/(?<lng>[\d.-]+) *,? +(?<lat>[\d.-]+)/)?.groups
    const rawLat = groups?.['lat']
    const rawLng = groups?.['lng']
    if (!rawLat || !rawLng) return null

    const lat = parseFloat(rawLat)
    const lng = parseFloat(rawLng)
    // `[\d.-]+` also matches fragments such as "-" or "." that parse to NaN.
    if (!isCoordinate(lat) || !isCoordinate(lng)) return null
    return { lat, lng }
  }

  // Use type checks rather than truthiness so 0 (equator / prime meridian) is accepted.
  if (
    'x' in value &&
    'y' in value &&
    isCoordinate(value.x) &&
    isCoordinate(value.y)
  ) {
    return { lat: value.y, lng: value.x }
  }

  if (
    'lat' in value &&
    'lng' in value &&
    isCoordinate(value.lat) &&
    isCoordinate(value.lng)
  ) {
    return { lat: value.lat, lng: value.lng }
  }

  if ('coordinates' in value && Array.isArray(value.coordinates)) {
    // GeoJSON order is [lng, lat].
    const [lng, lat] = value.coordinates
    if (isCoordinate(lat) && isCoordinate(lng)) {
      return { lat, lng }
    }
  }

  return null
}

/**
 * Returns a shallow copy of `place` whose `location` is parsed by `getPoint`.
 * WKT text becomes a `{ lat, lng }` object, and a missing or unparseable
 * value becomes `null`. The remaining properties are copied as-is.
 */
export function calculateLocation<
  L extends PointString | null | undefined,
  P extends { location: L },
>(place: P) {
  const location = getPoint(place.location)
  return { ...place, location }
}

/** Serializes a point to WKT text, longitude first: `POINT(lng lat)`. */
export function pointToString(value: MapPoint): PointString {
  const { lat, lng } = value
  return `POINT(${lng} ${lat})`
}
// `place` table (snippet: only the spatial column is shown).
// `location` uses the custom POINT column; `getPlace` below reads it through
// `selectPoint`, which selects it as `ST_AsText(location)`.
export const places = mysqlTable('place', {
  // Other fields...
  location: pointType('location').notNull(),
})
// Prepared relational query that fetches a single place by id.
// `location` is not listed in `columns`; it is fetched as an extra via
// `selectPoint`, which selects `ST_AsText(location)` and decodes it with
// `places.location`.
// The id is bound at execution time through `sql.placeholder('id')`.
const getPlace = db.query.places
  .findFirst({
    columns: {
      id: true,
      name: true,
    },
    extras: {
      location: selectPoint('location', places.location),
    },
    where: (place, { eq }) => eq(place.id, sql.placeholder('id')),
  })
  .prepare()

// tRPC endpoint: fetches one place by id and converts its WKT `location` into
// a `{ lat, lng }` object. Resolves to `undefined` when no place matches.
export const placesRouter = router({
  get: publicProcedure.input(getPlacesSchema).query(async ({ input }) => {
    const place = await getPlace.execute({ id: input.id })
    return place ? calculateLocation(place) : undefined
  }),
})
mauriciabad commented 7 months ago

And here is the same approach for MultiLineString:

// src/server/helpers/spatial-data/multi-line.ts
import { DriverValueMapper, sql } from 'drizzle-orm'
import { customType } from 'drizzle-orm/mysql-core'
import {
  MultiLineString,
  getMultiLine,
  multiLineToString,
} from '~/helpers/spatial-data/multi-line'
import { SRID_CODE } from '.'

// TODO: Switch to a structured type (presumably MapMultiLine) once this PR is merged:
// https://github.com/drizzle-team/drizzle-orm/pull/1423
// Until then, column values are exchanged as WKT `MultiLineString(...)` strings.
type WrongMultiLineType = MultiLineString

/**
 * Drizzle custom column for a MySQL `MULTILINESTRING` constrained to `SRID_CODE`.
 *
 * Writes accept anything `getMultiLine` understands, normalize it to WKT and
 * insert it via `ST_MultiLineStringFromText`. Reads parse the driver's text
 * value with `getMultiLine` and hand back WKT. `selectMultiLine` selects the
 * column through `ST_AsText` for that reason.
 *
 * Both directions throw an `Error` when the value cannot be parsed.
 */
export const multiLineType = customType<{
  data: WrongMultiLineType
  driverData: string
}>({
  dataType: () => `MULTILINESTRING SRID ${SRID_CODE}`,
  toDriver: (value: WrongMultiLineType | string) => {
    const parsed = getMultiLine(value)
    if (parsed === null) {
      throw new Error(`Invalid multiLine value: ${JSON.stringify(value)}`)
    }
    const wkt = multiLineToString(parsed)
    return sql`ST_MultiLineStringFromText(${wkt}, ${SRID_CODE})`
  },
  fromDriver: (value: string): WrongMultiLineType => {
    const parsed = getMultiLine(value)
    if (parsed === null) {
      throw new Error(`Invalid multiLine value: ${JSON.stringify(value)}`)
    }
    return multiLineToString(parsed)
  },
})

/**
 * Builds a relational-query `extras` entry for a MULTILINESTRING column.
 *
 * The entry selects `ST_AsText(column)`, decodes the result with `decoder` and
 * aliases it back to the column name. Callers in this file pass the column
 * itself as the decoder.
 *
 * @param column Column name, emitted as an unqualified SQL identifier.
 * @param decoder Value mapper applied to the selected text via `mapWith`.
 */
export const selectMultiLine = <
  C extends string,
  D extends DriverValueMapper<D1, D2>,
  D1 = any,
  D2 = any,
>(
  column: C,
  decoder: D
) => {
  const columnAsText = sql<WrongMultiLineType>`ST_AsText(${sql.identifier(column)})`
  const decoded = columnAsText.mapWith(decoder)
  return decoded.as(column)
}
// src/helpers/spatial-data/multi-line
import { z } from 'zod'

/**
 * A multi-line as an array of lines, each an array of `[lat, lng]` points.
 * Note this is latitude first, the opposite of the WKT text form (`lng lat`).
 */
export type MapMultiLine<
  Lat extends number = number,
  Lng extends number = number,
> = [Lat, Lng][][]

/**
 * WKT text for a multi-line. Each point inside a line is written `lng lat`.
 * @example MultiLineString((1 1,2 2,3 3),(4 4,5 5))
 */
export type MultiLineString = `MultiLineString(${string})`

/**
 * Zod schema for the array form of a multi-line: an array of lines, each an
 * array of numeric pairs. `getMultiLine` uses it to validate non-string input,
 * where the pairs are read as `[lat, lng]` (see `MapMultiLine`).
 */
export const multiLineSchema = z
  .tuple([z.number(), z.number()])
  .array()
  .array()

/**
 * Extracts a multi-line from a value.
 * In the WKT text form each point is written `lng lat`, but the returned
 * structure stores every point as a `[lat, lng]` tuple.
 *
 * @param value One of:
 *   - WKT text such as `MultiLineString((1 1,2 2,3 3),(4 4,5 5))`. Only the
 *     parenthesized point lists are read; the prefix is not checked.
 *   - A `[lat, lng][][]` array, as a value or as a JSON string.
 * @returns Array of lines, each an array of `[lat, lng]` tuples. Returns
 *   `null` when the value is empty or cannot be read as a multi-line. That
 *   includes text with no point list and any point with a missing or
 *   non-numeric coordinate.
 */
export function getMultiLine(
  value: MultiLineString | MapMultiLine
): MapMultiLine
export function getMultiLine(value: null | undefined): null
export function getMultiLine(
  value: string | MultiLineString | MapMultiLine | null | undefined
): MapMultiLine | null
export function getMultiLine(
  value: string | MultiLineString | MapMultiLine | null | undefined
): MapMultiLine | null {
  if (!value) return null
  if (typeof value === 'string') {
    try {
      return getMultiLine(JSON.parse(value))
    } catch {
      // Not JSON: fall back to parsing WKT text below.
    }

    const rawLines = Array.from(
      value.matchAll(/(\([\d.-]+ +[\d.-]+( *, *[\d.-]+ +[\d.-]+)*\))/g),
      ([rawLine]) => rawLine
    )
    // Text without any parenthesized point list is not a multi-line. Returning
    // `[]` here would later serialize to the invalid WKT `MultiLineString()`.
    if (rawLines.length === 0) return null

    return nullIfHasNull(
      rawLines.map((rawLine) =>
        rawLine
          .slice(1, -1)
          .split(',')
          .map((rawPoint): [number, number] | null => {
            const groups = rawPoint.match(
              /(?<lng>[\d.-]+) +(?<lat>[\d.-]+)/
            )?.groups
            const rawLat = groups?.['lat']
            const rawLng = groups?.['lng']
            if (!rawLat || !rawLng) return null

            const lat = parseFloat(rawLat)
            const lng = parseFloat(rawLng)
            // `[\d.-]+` also matches fragments such as "-" or "." that parse to NaN.
            if (Number.isNaN(lat) || Number.isNaN(lng)) return null
            return [lat, lng]
          })
      )
    )
  }

  try {
    return multiLineSchema.parse(value)
  } catch {
    // Not a `[lat, lng][][]` array: fall through to null.
  }

  return null
}

/**
 * Returns a copy of `route` whose `path` is parsed by `getMultiLine` from WKT
 * text into a `[lat, lng][][]` structure.
 * The other properties are copied as-is; `path` becomes the last key.
 */
export function calculatePath<
  L extends MultiLineString,
  P extends { path: L },
>(input: P) {
  const { path, ...route } = input
  const parsedPath = getMultiLine(path)
  return { ...route, path: parsedPath }
}

/**
 * Serializes a `[lat, lng][][]` multi-line to WKT text, writing each point
 * longitude first, e.g. `MultiLineString((2 1,4 3),(6 5))`.
 */
export function multiLineToString(value: MapMultiLine): MultiLineString {
  const lines = value.map((line) => {
    const points = line.map(([lat, lng]) => `${lng} ${lat}`)
    return `(${points.join(',')})`
  })
  return `MultiLineString(${lines.join(',')})`
}

/**
 * Returns the same 2-D array (narrowed to `T[][]`) when none of its inner
 * elements is `null`. Returns `null` if any inner element is `null`.
 */
function nullIfHasNull<T>(l1: (T | null)[][]): T[][] | null {
  const allPresent = l1.every((row) => row.every((item) => item !== null))
  return allPresent ? (l1 as T[][]) : null
}

/**
 * Converts a GeoJSON `FeatureCollection` of `LineString` features (validated by
 * `geoJsonSchema`) into a `[lat, lng][][]` multi-line, one line per feature.
 * GeoJSON positions are `[lng, lat]`, so each pair is swapped.
 *
 * @returns The multi-line, or `null` when `value` does not match the schema.
 */
export function multiLineFromGeoJson(value: unknown): MapMultiLine | null {
  const result = geoJsonSchema.safeParse(value)
  if (!result.success) return null

  return result.data.features.map(({ geometry }) =>
    geometry.coordinates.map(([lng, lat]): [number, number] => [lat, lng])
  )
}

/**
 * Parses a GeoJSON string and converts it with `multiLineFromGeoJson`.
 * Returns `null` when the string is not valid JSON or does not match the schema.
 */
export function multilineFromGeoJsonString(string: string) {
  let parsed: unknown
  try {
    parsed = JSON.parse(string)
  } catch {
    return null
  }
  return multiLineFromGeoJson(parsed)
}

/**
 * Converts a `[lat, lng][][]` multi-line into a GeoJSON `FeatureCollection`
 * with one `LineString` feature per line and empty `properties`.
 * Each pair is swapped to GeoJSON's `[lng, lat]` order.
 */
export function multiLineToGeoJson(multiline: MapMultiLine) {
  const features = multiline.map((line) => {
    const coordinates = line.map(([lat, lng]) => [lng, lat])
    return {
      type: 'Feature',
      properties: {},
      geometry: {
        coordinates,
        type: 'LineString',
      },
    }
  })
  return {
    type: 'FeatureCollection',
    features,
  }
}
/** Serializes `multiLineToGeoJson(multiline)` as JSON indented with 2 spaces. */
export function multiLineToGeoJsonString(multiline: MapMultiLine) {
  const geoJson = multiLineToGeoJson(multiline)
  return JSON.stringify(geoJson, null, 2)
}

/**
 * Zod schema for a GeoJSON `FeatureCollection` whose features all have
 * `LineString` geometries. This is the shape produced by `multiLineToGeoJson`
 * and read by `multiLineFromGeoJson`.
 * Coordinates are `[lng, lat]` pairs, per GeoJSON.
 *
 * `properties` is a member of the Feature, not of the geometry (RFC 7946 §3.2).
 * It may be an object, `null` or absent. `multiLineFromGeoJson` ignores it.
 */
export const geoJsonSchema = z.object({
  type: z.literal('FeatureCollection'),
  features: z.array(
    z.object({
      type: z.literal('Feature'),
      properties: z.record(z.string(), z.unknown()).nullable().optional(),
      geometry: z.object({
        type: z.literal('LineString'),
        coordinates: z.array(z.tuple([z.number(), z.number()])),
      }),
    })
  ),
})
// `route` table (snippet: only the spatial column is shown).
// `path` uses the custom MULTILINESTRING column; `getRoute` below reads it
// through `selectMultiLine`.
export const routes = mysqlTable('route', {
  // Other fields...
  path: multiLineType('path').notNull(),
})
// Prepared relational query that fetches a single route by id.
// `path` is not listed in `columns`; it is fetched as an extra via
// `selectMultiLine`, which selects `ST_AsText(path)` and decodes it with
// `routes.path`.
// The id is bound at execution time through `sql.placeholder('id')`.
const getRoute = db.query.routes
  .findFirst({
    columns: {
      id: true,
      name: true,
    },
    extras: {
      path: selectMultiLine('path', routes.path),
    },
    where: (route, { eq }) => eq(route.id, sql.placeholder('id')),
  })
  .prepare()

// tRPC endpoint: fetches one route by id and converts its WKT `path` into a
// `[lat, lng][][]` structure. Resolves to `undefined` when no route matches.
export const routesRouter = router({
  get: publicProcedure.input(getRoutesSchema).query(async ({ input }) => {
    const route = await getRoute.execute({ id: input.id })
    return route ? calculatePath(route) : undefined
  }),
})
Sparticuz commented 7 months ago

This code also seems to work fine with MariaDB.