From 09c04edae16079e5fe2c451178896655859a56a9 Mon Sep 17 00:00:00 2001 From: Tevon Strand-Brown Date: Thu, 12 Dec 2024 10:54:50 -0800 Subject: [PATCH] Transitioned to office SDK --- src/helpers.ts | 5 ++ src/index.ts | 239 ++++++++++++++++++++++++++++--------------------- 2 files changed, 140 insertions(+), 104 deletions(-) create mode 100644 src/helpers.ts diff --git a/src/helpers.ts b/src/helpers.ts new file mode 100644 index 0000000..03f90e7 --- /dev/null +++ b/src/helpers.ts @@ -0,0 +1,5 @@ +export const formatToolCall = (obj: any, isError: boolean = false) => { + return { + content: [{ type: "text", text: JSON.stringify(obj, null, 2), isError }], + }; +} \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index e0eb657..231ed48 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,131 +1,162 @@ import { get_hass } from "./hass/index.js"; -import { LiteMCP } from "litemcp"; +import { Server, } from "@modelcontextprotocol/sdk/server/index.js"; import { z } from "zod"; import { TAreaId, TFloorId, TRawDomains, TRawEntityIds } from "@digital-alchemy/hass"; import { zodToJsonSchema } from "zod-to-json-schema"; -import { ListRequestSchema, AreaSchema, FloorSchema } from "./schemas.js"; +import { ListToolsRequestSchema, CallToolRequestSchema } from "@modelcontextprotocol/sdk/types.js"; +import { formatToolCall } from "./helpers.js"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -const server = new LiteMCP( - "example-server", - "1.0.0", -); +const server = new Server({ + name: "homeassistant-mcp-server", + version: "0.1.0", +}, { + capabilities: { + tools: {} + } +}); const hass = await get_hass(); -server.addTool({ - name: "list_domains", - description: "Lists all domains in the home", - parameters: z.object({}), - execute: async () => { - return ["light", "climate", "alarm_control_panel", "cover", "switch", "sensor", "button"]; +server.setRequestHandler(ListToolsRequestSchema, async (request) => { + return { + tools: [ + { + name: "list_domains", + description: "Lists all domains in the home", + inputSchema: zodToJsonSchema(z.object({})), + }, + { + name: "list_areas", + description: "Lists all areas in the home", + inputSchema: zodToJsonSchema(z.object({})), + }, + { + name: "list_floors", + description: "Lists all floors in the home", + inputSchema: zodToJsonSchema(z.object({})), + }, + { + name: "get_entity_state", + description: "Gets the state of an entity", + inputSchema: zodToJsonSchema(z.object({ + entity_id: z.string() + })), + }, + { + name: "get_entities", + description: "Gets entities, filtered by domain, floor, and area as needed", + inputSchema: zodToJsonSchema(z.object({ + domain: z.string().optional(), + floor: z.string().optional(), + area: z.string().optional(), + })), + }, + { + name: "get_entity_state_by_ids", + description: "Gets a list of entities from a list of entity ids. Use this tool when there is more than one entity to get the state of.", + inputSchema: zodToJsonSchema(z.object({ + entity_ids: z.array(z.string()) + })), + } + ] } }); -server.addTool({ - name: "list_areas", - description: "Lists all areas in the home", - parameters: z.object({}), - execute: async () => { - return await areasRequestHandler(); +server.setRequestHandler(CallToolRequestSchema, async (request) => { + + switch (request.params.name) { + case "list_domains": + return formatToolCall(listDomains()); + case "list_areas": + return formatToolCall(await listAreas()); + case "list_floors": + return formatToolCall(await listFloors()); + case "get_entity_state": + const entity_id = request.params.entity_id as TRawEntityIds; + return formatToolCall(await getEntityState(entity_id)); + case "get_entities": + const entities = await getEntities(request.params.arguments as { domain?: TRawDomains, floor?: TFloorId, area?: TAreaId }); + return formatToolCall(entities); + case "get_entity_state_by_ids": + const entity_ids = request.params?.arguments?.entity_ids as TRawEntityIds[]; + if (!entity_ids) { + return formatToolCall({ + error: "No entity ids provided" + }, true); + } + return formatToolCall(getEntityStateByIds(entity_ids)); + case "get_entity_history": + return formatToolCall(getEntityHistory(request.params.arguments as { entity_id: TRawEntityIds, start_time: string, end_time?: string })); + case "get_entity_history_by_ids": + return formatToolCall(getEntityHistoryByIds(request.params.arguments as { entity_ids: TRawEntityIds[], start_time: string, end_time?: string })); } + + return formatToolCall({ + error: "Tool not found" + }, true); }); -server.addTool({ - name: "list_floors", - description: "Lists all floors in the home", - parameters: z.object({}), - execute: async () => { - return await floorsRequestHandler(); - } + +async function runServer() { + const transport = new StdioServerTransport(); + await server.connect(transport); + console.error("Home Assistant MCP Server running on stdio"); +} + +runServer().catch((error) => { + console.error("Fatal error in runServer():", error); + process.exit(1); }); -server.addTool({ - name: "get_entity_state", - description: "Gets the state of an entity", - parameters: z.object({ - entity_id: z.string() - }), - execute: async (request) => { - return await hass.hass.entity.getCurrentState(request.entity_id as TRawEntityIds); - } -}); +const listDomains = () => { + return ["light", "climate", "alarm_control_panel", "cover", "switch", "sensor", "button"]; +} -server.addTool({ - name: "get_entities", - description: "Gets entities, filtered by domain, floor, and area as needed", - parameters: z.object({ - domain: z.string().optional(), - floor: z.string().optional(), - area: z.string().optional(), - }), - execute: async (request) => { - if (request.floor) { - return hass.hass.idBy.floor(request.floor as TFloorId, request.domain as TRawDomains || undefined); - } - if (request.area) { - return hass.hass.idBy.area(request.area as TAreaId, request.domain as TRawDomains || undefined); - } - if (request.domain) { - return hass.hass.entity.listEntities(request.domain as TRawDomains); - } - return hass.hass.entity.listEntities(); - } -}); +const getEntityHistoryByIds = (params: { entity_ids: TRawEntityIds[], start_time: string, end_time?: string }) => { + return hass.hass.entity.history({ + entity_ids: params.entity_ids as TRawEntityIds[], + end_time: params.end_time ? new Date(params.end_time) : new Date(), + start_time: params.start_time + }); +} -server.addTool({ - name: "get_entity_state_by_ids", - description: "Gets a list of entities from a list of entity ids. Use this tool when there is more than one entity to get the state of.", - parameters: z.object({ - entity_ids: z.array(z.string()) - }), - execute: async (request) => { - const entities = request.entity_ids.map(entity_id => hass.hass.entity.getCurrentState(entity_id as TRawEntityIds)); - return entities; - } -}) +const getEntityHistory = async (params: { entity_id: TRawEntityIds, start_time: string, end_time?: string }) => { + return await hass.hass.entity.history({ + entity_ids: [params.entity_id as TRawEntityIds], + end_time: params.end_time ? new Date(params.end_time) : new Date(), + start_time: params.start_time + }); +} -server.addTool({ - name: "get_entity_history", - description: "Gets the history of an entity", - parameters: z.object({ - entity_id: z.string(), - start_time: z.string(), - end_time: z.string().optional() - }), - execute: async (request) => { - return await hass.hass.entity.history({ - end_time: request.end_time ? new Date(request.end_time) : new Date(), - entity_ids: [request.entity_id as TRawEntityIds], - start_time: request.start_time - }); - } -}) +const getEntityStateByIds = (entity_ids: TRawEntityIds[]) => { + const entities = entity_ids.map(entity_id => hass.hass.entity.getCurrentState(entity_id as TRawEntityIds)); + return entities; +} -server.addTool({ - name: "get_entity_history_by_ids", - description: "Gets the history of a list of entities", - parameters: z.object({ - entity_ids: z.array(z.string()), - start_time: z.string(), - end_time: z.string().optional() - }), - execute: async (request) => { - return await hass.hass.entity.history({ - entity_ids: request.entity_ids as TRawEntityIds[], - end_time: request.end_time ? new Date(request.end_time) : new Date(), - start_time: request.start_time - }); +const getEntities = async (params: { domain?: TRawDomains, floor?: TFloorId, area?: TAreaId }) => { + if (params.floor) { + return hass.hass.idBy.floor(params.floor as TFloorId, params.domain as TRawDomains || undefined); } -}) - -const areasRequestHandler = async () => { + if (params.area) { + return hass.hass.idBy.area(params.area as TAreaId, params.domain as TRawDomains || undefined); + } + if (params.domain) { + return hass.hass.entity.listEntities(params.domain as TRawDomains); + } + return hass.hass.entity.listEntities(); +} + +const getEntityState = async (entity_id: TRawEntityIds) => { + return await hass.hass.entity.getCurrentState(entity_id); +} + +const listAreas = async () => { const areas = await hass.hass.area.list() return areas; } -const floorsRequestHandler = async () => { +const listFloors = async () => { const floors = await hass.hass.floor.list() return floors; -} - -server.start(); +} \ No newline at end of file