Transitioned to office SDK

This commit is contained in:
Tevon Strand-Brown
2024-12-12 10:54:50 -08:00
parent 50e8dc3365
commit 09c04edae1
2 changed files with 140 additions and 104 deletions

5
src/helpers.ts Normal file
View File

@@ -0,0 +1,5 @@
export const formatToolCall = (obj: any, isError: boolean = false) => {
return {
content: [{ type: "text", text: JSON.stringify(obj, null, 2), isError }],
};
}

View File

@@ -1,131 +1,162 @@
import { get_hass } from "./hass/index.js"; import { get_hass } from "./hass/index.js";
import { LiteMCP } from "litemcp"; import { Server, } from "@modelcontextprotocol/sdk/server/index.js";
import { z } from "zod"; import { z } from "zod";
import { TAreaId, TFloorId, TRawDomains, TRawEntityIds } from "@digital-alchemy/hass"; import { TAreaId, TFloorId, TRawDomains, TRawEntityIds } from "@digital-alchemy/hass";
import { zodToJsonSchema } from "zod-to-json-schema"; import { zodToJsonSchema } from "zod-to-json-schema";
import { ListRequestSchema, AreaSchema, FloorSchema } from "./schemas.js"; import { ListToolsRequestSchema, CallToolRequestSchema } from "@modelcontextprotocol/sdk/types.js";
import { formatToolCall } from "./helpers.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
const server = new LiteMCP( const server = new Server({
"example-server", name: "homeassistant-mcp-server",
"1.0.0", version: "0.1.0",
); }, {
capabilities: {
tools: {}
}
});
const hass = await get_hass(); const hass = await get_hass();
server.addTool({ server.setRequestHandler(ListToolsRequestSchema, async (request) => {
name: "list_domains", return {
description: "Lists all domains in the home", tools: [
parameters: z.object({}), {
execute: async () => { name: "list_domains",
return ["light", "climate", "alarm_control_panel", "cover", "switch", "sensor", "button"]; description: "Lists all domains in the home",
inputSchema: zodToJsonSchema(z.object({})),
},
{
name: "list_areas",
description: "Lists all areas in the home",
inputSchema: zodToJsonSchema(z.object({})),
},
{
name: "list_floors",
description: "Lists all floors in the home",
inputSchema: zodToJsonSchema(z.object({})),
},
{
name: "get_entity_state",
description: "Gets the state of an entity",
inputSchema: zodToJsonSchema(z.object({
entity_id: z.string()
})),
},
{
name: "get_entities",
description: "Gets entities, filtered by domain, floor, and area as needed",
inputSchema: zodToJsonSchema(z.object({
domain: z.string().optional(),
floor: z.string().optional(),
area: z.string().optional(),
})),
},
{
name: "get_entity_state_by_ids",
description: "Gets a list of entities from a list of entity ids. Use this tool when there is more than one entity to get the state of.",
inputSchema: zodToJsonSchema(z.object({
entity_ids: z.array(z.string())
})),
}
]
} }
}); });
server.addTool({ server.setRequestHandler(CallToolRequestSchema, async (request) => {
name: "list_areas",
description: "Lists all areas in the home", switch (request.params.name) {
parameters: z.object({}), case "list_domains":
execute: async () => { return formatToolCall(listDomains());
return await areasRequestHandler(); case "list_areas":
return formatToolCall(await listAreas());
case "list_floors":
return formatToolCall(await listFloors());
case "get_entity_state":
const entity_id = request.params.entity_id as TRawEntityIds;
return formatToolCall(await getEntityState(entity_id));
case "get_entities":
const entities = await getEntities(request.params.arguments as { domain?: TRawDomains, floor?: TFloorId, area?: TAreaId });
return formatToolCall(entities);
case "get_entity_state_by_ids":
const entity_ids = request.params?.arguments?.entity_ids as TRawEntityIds[];
if (!entity_ids) {
return formatToolCall({
error: "No entity ids provided"
}, true);
}
return formatToolCall(getEntityStateByIds(entity_ids));
case "get_entity_history":
return formatToolCall(getEntityHistory(request.params.arguments as { entity_id: TRawEntityIds, start_time: string, end_time?: string }));
case "get_entity_history_by_ids":
return formatToolCall(getEntityHistoryByIds(request.params.arguments as { entity_ids: TRawEntityIds[], start_time: string, end_time?: string }));
} }
return formatToolCall({
error: "Tool not found"
}, true);
}); });
server.addTool({
name: "list_floors", async function runServer() {
description: "Lists all floors in the home", const transport = new StdioServerTransport();
parameters: z.object({}), await server.connect(transport);
execute: async () => { console.error("Home Assistant MCP Server running on stdio");
return await floorsRequestHandler(); }
}
runServer().catch((error) => {
console.error("Fatal error in runServer():", error);
process.exit(1);
}); });
server.addTool({ const listDomains = () => {
name: "get_entity_state", return ["light", "climate", "alarm_control_panel", "cover", "switch", "sensor", "button"];
description: "Gets the state of an entity", }
parameters: z.object({
entity_id: z.string()
}),
execute: async (request) => {
return await hass.hass.entity.getCurrentState(request.entity_id as TRawEntityIds);
}
});
server.addTool({ const getEntityHistoryByIds = (params: { entity_ids: TRawEntityIds[], start_time: string, end_time?: string }) => {
name: "get_entities", return hass.hass.entity.history({
description: "Gets entities, filtered by domain, floor, and area as needed", entity_ids: params.entity_ids as TRawEntityIds[],
parameters: z.object({ end_time: params.end_time ? new Date(params.end_time) : new Date(),
domain: z.string().optional(), start_time: params.start_time
floor: z.string().optional(), });
area: z.string().optional(), }
}),
execute: async (request) => {
if (request.floor) {
return hass.hass.idBy.floor(request.floor as TFloorId, request.domain as TRawDomains || undefined);
}
if (request.area) {
return hass.hass.idBy.area(request.area as TAreaId, request.domain as TRawDomains || undefined);
}
if (request.domain) {
return hass.hass.entity.listEntities(request.domain as TRawDomains);
}
return hass.hass.entity.listEntities();
}
});
server.addTool({ const getEntityHistory = async (params: { entity_id: TRawEntityIds, start_time: string, end_time?: string }) => {
name: "get_entity_state_by_ids", return await hass.hass.entity.history({
description: "Gets a list of entities from a list of entity ids. Use this tool when there is more than one entity to get the state of.", entity_ids: [params.entity_id as TRawEntityIds],
parameters: z.object({ end_time: params.end_time ? new Date(params.end_time) : new Date(),
entity_ids: z.array(z.string()) start_time: params.start_time
}), });
execute: async (request) => { }
const entities = request.entity_ids.map(entity_id => hass.hass.entity.getCurrentState(entity_id as TRawEntityIds));
return entities;
}
})
server.addTool({ const getEntityStateByIds = (entity_ids: TRawEntityIds[]) => {
name: "get_entity_history", const entities = entity_ids.map(entity_id => hass.hass.entity.getCurrentState(entity_id as TRawEntityIds));
description: "Gets the history of an entity", return entities;
parameters: z.object({ }
entity_id: z.string(),
start_time: z.string(),
end_time: z.string().optional()
}),
execute: async (request) => {
return await hass.hass.entity.history({
end_time: request.end_time ? new Date(request.end_time) : new Date(),
entity_ids: [request.entity_id as TRawEntityIds],
start_time: request.start_time
});
}
})
server.addTool({ const getEntities = async (params: { domain?: TRawDomains, floor?: TFloorId, area?: TAreaId }) => {
name: "get_entity_history_by_ids", if (params.floor) {
description: "Gets the history of a list of entities", return hass.hass.idBy.floor(params.floor as TFloorId, params.domain as TRawDomains || undefined);
parameters: z.object({
entity_ids: z.array(z.string()),
start_time: z.string(),
end_time: z.string().optional()
}),
execute: async (request) => {
return await hass.hass.entity.history({
entity_ids: request.entity_ids as TRawEntityIds[],
end_time: request.end_time ? new Date(request.end_time) : new Date(),
start_time: request.start_time
});
} }
}) if (params.area) {
return hass.hass.idBy.area(params.area as TAreaId, params.domain as TRawDomains || undefined);
const areasRequestHandler = async () => { }
if (params.domain) {
return hass.hass.entity.listEntities(params.domain as TRawDomains);
}
return hass.hass.entity.listEntities();
}
const getEntityState = async (entity_id: TRawEntityIds) => {
return await hass.hass.entity.getCurrentState(entity_id);
}
const listAreas = async () => {
const areas = await hass.hass.area.list() const areas = await hass.hass.area.list()
return areas; return areas;
} }
const floorsRequestHandler = async () => { const listFloors = async () => {
const floors = await hass.hass.floor.list() const floors = await hass.hass.floor.list()
return floors; return floors;
} }
server.start();